diff --git a/README.md b/README.md index 108f0c5..65b1a55 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,6 @@ ORDER BY s.from_date; ### 后续TODO +- [x] 支持通过自然语言生成SQL(大模型) - [ ] 建立字段关系,同步关联更新 -- [ ] 支持通过自然语言生成SQL(大模型) - [ ] 支持分库分表 diff --git a/pom.xml b/pom.xml index 64e7fc9..aefbfad 100644 --- a/pom.xml +++ b/pom.xml @@ -26,14 +26,42 @@ + 21 + 1.0.0-M6 + + + + + org.springframework.ai + spring-ai-bom + ${spring-ai.version} + pom + import + + + + org.springframework.boot spring-boot-starter-web + + + org.springframework.ai + spring-ai-openai-spring-boot-starter + + + + org.springframework.boot spring-boot-starter-test diff --git a/src/main/java/com/github/zavier/table/relation/service/DataQueryService.java b/src/main/java/com/github/zavier/table/relation/service/DataQueryService.java index 6c73cb0..81b4f3b 100644 --- a/src/main/java/com/github/zavier/table/relation/service/DataQueryService.java +++ b/src/main/java/com/github/zavier/table/relation/service/DataQueryService.java @@ -37,15 +37,12 @@ public Result> getAllSchema() { } public Result> getSchemaTables(String schema) { - Validate.notBlank(schema, "schema can not be blank"); - - final Optional sourceOptional = dataSourceRegistry.getDataSource(schema); - if (sourceOptional.isEmpty()) { - return Result.success(List.of()); + final Result> tableColumnInfoResult = getTableColumnInfos(schema); + if (!tableColumnInfoResult.isSuccess()) { + return Result.fail(tableColumnInfoResult.getMessage()); } - final DataSource dataSource = sourceOptional.get(); - final List tableColumnMetaInfo = mySqlTableMetaInfoQuery.getTableColumnMetaInfo(schema, dataSource); + var tableColumnMetaInfo = tableColumnInfoResult.getData(); final List tableNameList = tableColumnMetaInfo.stream().map(TableColumnInfo::tableName) .sorted() .toList(); @@ -53,7 +50,6 @@ public Result> getSchemaTables(String schema) { } public Result> getTableColumns(String schema, String tableName) { - final TableColumnInfo tableColumnMetaInfo = getTableColumnInfo(schema, tableName); // 列就不进行排序了,保留原始顺序,便于页面查看 @@ -66,6 +62,25 @@ public Result> getTableColumns(String schema, String tableName) { return Result.success(columnNames); } + public Result> getTableColumnInfos(String schema) { + Validate.notBlank(schema, "schema can not be blank"); + + final Optional sourceOptional = dataSourceRegistry.getDataSource(schema); + if (sourceOptional.isEmpty()) { + return Result.success(List.of()); + } + final DataSource dataSource = sourceOptional.get(); + + return Result.success(mySqlTableMetaInfoQuery.getTableColumnMetaInfo(schema, dataSource)); + } + + public Result>> executeSql(String schema, String sql) { + Validate.notBlank(schema, "schema can not be blank"); + Validate.notBlank(sql, "sql can not be blank"); + final List> maps = dataQuery.queryBySql(schema, sql); + return Result.success(maps); + } + public Result queryTableData(QueryCondition queryCondition) { checkParam(queryCondition); diff --git a/src/main/java/com/github/zavier/table/relation/service/SqlGenerateService.java b/src/main/java/com/github/zavier/table/relation/service/SqlGenerateService.java new file mode 100644 index 0000000..2c9727d --- /dev/null +++ b/src/main/java/com/github/zavier/table/relation/service/SqlGenerateService.java @@ -0,0 +1,80 @@ +package com.github.zavier.table.relation.service; + +import com.github.zavier.table.relation.service.domain.TableColumnInfo; +import com.github.zavier.table.relation.service.dto.Result; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.Validate; +import org.jetbrains.annotations.Nullable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Service +public class SqlGenerateService { + + private final ChatClient chatClient; + + private final DataQueryService dataQueryService; + private final RelationManagerService relationManagerService; + + public SqlGenerateService(ChatClient.Builder chatClientBuilder, + DataQueryService dataQueryService, + RelationManagerService relationManagerService) { + this.chatClient = chatClientBuilder + .defaultAdvisors(new SimpleLoggerAdvisor()) + .build(); + this.dataQueryService = dataQueryService; + this.relationManagerService = relationManagerService; + } + + public String generateSql(String schema, String demand) { + // 考虑使用多个表? + final String useTable = findRelatedUseTable(schema, demand); + + final String erDiagram = relationManagerService.getTableRelationMermaidERDiagram(schema, useTable, true); + + return generateSqlByErDiagram(demand, erDiagram); + } + + private String findRelatedUseTable(String schema, String demand) { + final Result> tableColumnInfos = dataQueryService.getTableColumnInfos(schema); + Validate.isTrue(tableColumnInfos.isSuccess(), "查询表信息失败"); + Validate.isTrue(!tableColumnInfos.getData().isEmpty(), "表信息不存在"); + + // 生成表信息字符串 + final List data = tableColumnInfos.getData(); + final String tableStr = data.stream() + .map(it -> { + final String comment = it.tableComment(); + if (StringUtils.isNotBlank(comment)) { + return it.tableName() + "(" + comment + ")"; + } + return it.tableName(); + }).collect(Collectors.joining(",")); + + // 查询最可能使用的表 TODO 这里可以进行一下表名校验&重试 + return this.chatClient.prompt() + .user(u -> u.text(""" + 根据如下用户的数据库表查询的【用户需求】,在【可选择的表名】选择如下表中最可能使用到的**一个**表是哪个?**只需要返回返回可能性最大的一个表名** + 用户需求: {demand} + 可选择的表名: {tables} + """).params(Map.of("demand", demand, "tables", tableStr))) + .call() + .entity(String.class); + } + + private @Nullable String generateSqlByErDiagram(String demand, String erDiagram) { + return this.chatClient.prompt() + .user(u -> u.text(""" + 根据如下【mermaid格式的ER图】,为这个**MySQL**数据库表查询的【用户需求】,生成对应的查询SQL语句,只需要返回SQL语句 + mermaid格式ER图: {erDiagram} + 用户需求: {demand} + """).params(Map.of("demand", demand, "erDiagram", erDiagram))) + .call() + .entity(String.class); + } +} diff --git a/src/main/java/com/github/zavier/table/relation/service/abilty/MySqlTableMetaInfoQuery.java b/src/main/java/com/github/zavier/table/relation/service/abilty/MySqlTableMetaInfoQuery.java index c0e1a84..68617b9 100644 --- a/src/main/java/com/github/zavier/table/relation/service/abilty/MySqlTableMetaInfoQuery.java +++ b/src/main/java/com/github/zavier/table/relation/service/abilty/MySqlTableMetaInfoQuery.java @@ -19,15 +19,21 @@ public class MySqlTableMetaInfoQuery { private SqlExecutor sqlExecutor; public List getTableColumnMetaInfo(String schema, DataSource dataSource) { - String sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ?"; - final List> maps = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema); - return TableInfoConverter.convert2TableColumnInfo(maps); + String columnSql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ?"; + final List> columnDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, columnSql, schema); + + String tableSql = "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?"; + final List> tableDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, tableSql, schema); + return TableInfoConverter.convert2TableColumnInfo(columnDataList, tableDataList); } public TableColumnInfo getTableColumnMetaInfo(String schema, String table, DataSource dataSource) { String sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"; - final List> maps = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema, table); - final List tableColumnInfos = TableInfoConverter.convert2TableColumnInfo(maps); + final List> columnDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema, table); + + String tableSql = "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"; + final List> tableDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, tableSql, schema, table); + final List tableColumnInfos = TableInfoConverter.convert2TableColumnInfo(columnDataList, tableDataList); Validate.notEmpty(tableColumnInfos, "tableColumnInfos is empty"); Validate.isTrue(tableColumnInfos.size() == 1, "tableColumnInfos size is not 1"); diff --git a/src/main/java/com/github/zavier/table/relation/service/converter/TableInfoConverter.java b/src/main/java/com/github/zavier/table/relation/service/converter/TableInfoConverter.java index b4de8e5..7b2d42f 100644 --- a/src/main/java/com/github/zavier/table/relation/service/converter/TableInfoConverter.java +++ b/src/main/java/com/github/zavier/table/relation/service/converter/TableInfoConverter.java @@ -2,7 +2,6 @@ import com.github.zavier.table.relation.service.domain.ColumnInfo; import com.github.zavier.table.relation.service.domain.TableColumnInfo; -import com.github.zavier.table.relation.service.domain.TableInfo; import java.util.ArrayList; import java.util.List; @@ -11,22 +10,27 @@ public class TableInfoConverter { - public static List convert2TableColumnInfo(List> dataMapList) { - if (dataMapList.isEmpty()) { + public static List convert2TableColumnInfo(List> columnDataList, + List> tableDataList) { + if (columnDataList.isEmpty() || tableDataList.isEmpty()) { return List.of(); } - List result = new ArrayList<>(); - final Map>> tableColumnsMap = dataMapList.stream() + final Map>> tableColumnsMap = columnDataList.stream() .collect(Collectors.groupingBy((dataMap) -> dataMap.get("TABLE_SCHEMA") + (String) dataMap.get("TABLE_NAME"))); + + final Map tableKeyCommentMap = tableDataList.stream() + .collect(Collectors.toMap((dataMap) -> dataMap.get("TABLE_SCHEMA") + (String) dataMap.get("TABLE_NAME"), (dataMap) -> (String) (dataMap.get("TABLE_COMMENT")), + (v1, v2) -> v1)); + tableColumnsMap.forEach((key, columnInfos) -> { final List columnInfoList = columnInfos.stream().map(TableInfoConverter::convert2ColumnInfo).toList(); final Map column = columnInfos.get(0); final TableColumnInfo tableColumnInfo = new TableColumnInfo( (String) column.get("TABLE_SCHEMA"), (String) column.get("TABLE_NAME"), - "", // 暂未设置 + tableKeyCommentMap.getOrDefault(key, ""), columnInfoList ); result.add(tableColumnInfo); @@ -34,19 +38,6 @@ public static List convert2TableColumnInfo(List convert2TableBaseInfo(List> dataMapList) { - if (dataMapList.isEmpty()) { - return List.of(); - } - return dataMapList.stream().map(dataMap -> { - return new TableInfo( - (String) dataMap.get("TABLE_SCHEMA"), - (String) dataMap.get("TABLE_NAME"), - (String) dataMap.get("TABLE_COMMENT") - ); - }).toList(); - } - public static ColumnInfo convert2ColumnInfo(Map dataMap) { return new ColumnInfo( (String) dataMap.get("COLUMN_NAME"), diff --git a/src/main/java/com/github/zavier/table/relation/service/dto/ExecuteSqlDto.java b/src/main/java/com/github/zavier/table/relation/service/dto/ExecuteSqlDto.java new file mode 100644 index 0000000..44c706a --- /dev/null +++ b/src/main/java/com/github/zavier/table/relation/service/dto/ExecuteSqlDto.java @@ -0,0 +1,22 @@ +package com.github.zavier.table.relation.service.dto; + +public class ExecuteSqlDto { + private String sql; + private String schema; + + public String getSql() { + return sql; + } + + public void setSql(String sql) { + this.sql = sql; + } + + public String getSchema() { + return schema; + } + + public void setSchema(String schema) { + this.schema = schema; + } +} diff --git a/src/main/java/com/github/zavier/table/relation/service/query/DataQuery.java b/src/main/java/com/github/zavier/table/relation/service/query/DataQuery.java index 8c82346..f32c28d 100644 --- a/src/main/java/com/github/zavier/table/relation/service/query/DataQuery.java +++ b/src/main/java/com/github/zavier/table/relation/service/query/DataQuery.java @@ -35,6 +35,15 @@ public Map>>> query(QueryCondition return queryByBfs(queryCondition); } + public List> queryBySql(String schema, String sql) { + final Optional sourceOptional = dataSourceRegistry.getDataSource(schema); + if (sourceOptional.isEmpty()) { + throw new RuntimeException("dataSource not found:" + schema); + } + final DataSource dataSource = sourceOptional.get(); + return sqlExecutor.sqlQueryWithoutLimit(dataSource, sql); + } + private Map>>> queryByBfs(QueryCondition queryCondition) { Queue queue = new LinkedList<>(); queue.add(queryCondition); diff --git a/src/main/java/com/github/zavier/table/relation/web/DataController.java b/src/main/java/com/github/zavier/table/relation/web/DataController.java index 2a55a65..aa8381e 100644 --- a/src/main/java/com/github/zavier/table/relation/web/DataController.java +++ b/src/main/java/com/github/zavier/table/relation/web/DataController.java @@ -2,6 +2,8 @@ import com.github.zavier.table.relation.service.DataQueryService; import com.github.zavier.table.relation.service.RelationManagerService; +import com.github.zavier.table.relation.service.SqlGenerateService; +import com.github.zavier.table.relation.service.dto.ExecuteSqlDto; import com.github.zavier.table.relation.service.dto.QueryCondition; import com.github.zavier.table.relation.service.dto.Result; import com.github.zavier.table.relation.service.dto.TableData; @@ -11,6 +13,7 @@ import org.springframework.web.bind.annotation.*; import java.util.List; +import java.util.Map; @RestController @RequestMapping("/api/table") @@ -21,6 +24,8 @@ public class DataController { private DataQueryService dataQueryService; @Resource private RelationManagerService relationManagerService; + @Resource + private SqlGenerateService sqlGenerateService; @GetMapping("/allSchema") public Result> getAllSchema() { @@ -37,6 +42,17 @@ public Result> getTableColumns(@RequestParam("schema") String schem return dataQueryService.getTableColumns(schema, tableName); } + @GetMapping("/generateSql") + public Result generateSql(@RequestParam("schema") String schema, @RequestParam("demand") String demand) { + return Result.success(sqlGenerateService.generateSql(schema, demand)); + } + + // executeSql + @PostMapping("/executeSql") + public Result>> executeSql(@RequestBody ExecuteSqlDto executeSqlDto) { + return dataQueryService.executeSql(executeSqlDto.getSchema(), executeSqlDto.getSql()); + } + @PostMapping("/sqlQuery") public Result queryRelaData(@RequestBody QueryCondition queryCondition) { return dataQueryService.queryTableData(queryCondition); diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index bcaa596..ff83872 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -20,4 +20,19 @@ spring.sql.init.mode=always # logic no-delete condition -logic.no.delete.condition= \ No newline at end of file +logic.no.delete.condition= + + +# deepseek +spring.ai.openai.api-key=INSERT-DEEPSEEK-API-KEY-HERE +spring.ai.openai.base-url=https://api.deepseek.com +spring.ai.openai.chat.options.model=deepseek-chat +spring.ai.openai.chat.options.temperature=1 +# The DeepSeek API doesn't support embeddings, so we need to disable it. +spring.ai.openai.embedding.enabled=false + +# qianfan +#spring.ai.qianfan.base-url= +spring.ai.qianfan.api-key= +spring.ai.qianfan.secret-key= +#spring.ai.qianfan.chat.options.model=ernie_speed \ No newline at end of file diff --git a/src/main/resources/logback-spring.xml b/src/main/resources/logback-spring.xml index e8b54f3..925fa84 100644 --- a/src/main/resources/logback-spring.xml +++ b/src/main/resources/logback-spring.xml @@ -21,6 +21,8 @@ + +