Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@ ORDER BY s.from_date;


### 后续TODO
- [x] 支持通过自然语言生成SQL(大模型)
- [ ] 建立字段关系,同步关联更新
- [ ] 支持通过自然语言生成SQL(大模型)
- [ ] 支持分库分表
28 changes: 28 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,42 @@
<tag/>
<url/>
</scm>

<properties>
<java.version>21</java.version>
<spring-ai.version>1.0.0-M6</spring-ai.version>
</properties>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>

<!--
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
</dependency>
-->

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,19 @@ public Result<List<String>> getAllSchema() {
}

public Result<List<String>> getSchemaTables(String schema) {
Validate.notBlank(schema, "schema can not be blank");

final Optional<DataSource> sourceOptional = dataSourceRegistry.getDataSource(schema);
if (sourceOptional.isEmpty()) {
return Result.success(List.of());
final Result<List<TableColumnInfo>> tableColumnInfoResult = getTableColumnInfos(schema);
if (!tableColumnInfoResult.isSuccess()) {
return Result.fail(tableColumnInfoResult.getMessage());
}
final DataSource dataSource = sourceOptional.get();

final List<TableColumnInfo> tableColumnMetaInfo = mySqlTableMetaInfoQuery.getTableColumnMetaInfo(schema, dataSource);
var tableColumnMetaInfo = tableColumnInfoResult.getData();
final List<String> tableNameList = tableColumnMetaInfo.stream().map(TableColumnInfo::tableName)
.sorted()
.toList();
return Result.success(tableNameList);
}

public Result<List<String>> getTableColumns(String schema, String tableName) {

final TableColumnInfo tableColumnMetaInfo = getTableColumnInfo(schema, tableName);

// 列就不进行排序了,保留原始顺序,便于页面查看
Expand All @@ -66,6 +62,25 @@ public Result<List<String>> getTableColumns(String schema, String tableName) {
return Result.success(columnNames);
}

public Result<List<TableColumnInfo>> getTableColumnInfos(String schema) {
Validate.notBlank(schema, "schema can not be blank");

final Optional<DataSource> sourceOptional = dataSourceRegistry.getDataSource(schema);
if (sourceOptional.isEmpty()) {
return Result.success(List.of());
}
final DataSource dataSource = sourceOptional.get();

return Result.success(mySqlTableMetaInfoQuery.getTableColumnMetaInfo(schema, dataSource));
}

public Result<List<Map<String, Object>>> executeSql(String schema, String sql) {
Validate.notBlank(schema, "schema can not be blank");
Validate.notBlank(sql, "sql can not be blank");
final List<Map<String, Object>> maps = dataQuery.queryBySql(schema, sql);
return Result.success(maps);
}


public Result<TableData> queryTableData(QueryCondition queryCondition) {
checkParam(queryCondition);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.github.zavier.table.relation.service;

import com.github.zavier.table.relation.service.domain.TableColumnInfo;
import com.github.zavier.table.relation.service.dto.Result;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.jetbrains.annotations.Nullable;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Service
public class SqlGenerateService {

private final ChatClient chatClient;

private final DataQueryService dataQueryService;
private final RelationManagerService relationManagerService;

public SqlGenerateService(ChatClient.Builder chatClientBuilder,
DataQueryService dataQueryService,
RelationManagerService relationManagerService) {
this.chatClient = chatClientBuilder
.defaultAdvisors(new SimpleLoggerAdvisor())
.build();
this.dataQueryService = dataQueryService;
this.relationManagerService = relationManagerService;
}

public String generateSql(String schema, String demand) {
// 考虑使用多个表?
final String useTable = findRelatedUseTable(schema, demand);

final String erDiagram = relationManagerService.getTableRelationMermaidERDiagram(schema, useTable, true);

return generateSqlByErDiagram(demand, erDiagram);
}

private String findRelatedUseTable(String schema, String demand) {
final Result<List<TableColumnInfo>> tableColumnInfos = dataQueryService.getTableColumnInfos(schema);
Validate.isTrue(tableColumnInfos.isSuccess(), "查询表信息失败");
Validate.isTrue(!tableColumnInfos.getData().isEmpty(), "表信息不存在");

// 生成表信息字符串
final List<TableColumnInfo> data = tableColumnInfos.getData();
final String tableStr = data.stream()
.map(it -> {
final String comment = it.tableComment();
if (StringUtils.isNotBlank(comment)) {
return it.tableName() + "(" + comment + ")";
}
return it.tableName();
}).collect(Collectors.joining(","));

// 查询最可能使用的表 TODO 这里可以进行一下表名校验&重试
return this.chatClient.prompt()
.user(u -> u.text("""
根据如下用户的数据库表查询的【用户需求】,在【可选择的表名】选择如下表中最可能使用到的**一个**表是哪个?**只需要返回返回可能性最大的一个表名**
用户需求: {demand}
可选择的表名: {tables}
""").params(Map.of("demand", demand, "tables", tableStr)))
.call()
.entity(String.class);
}

private @Nullable String generateSqlByErDiagram(String demand, String erDiagram) {
return this.chatClient.prompt()
.user(u -> u.text("""
根据如下【mermaid格式的ER图】,为这个**MySQL**数据库表查询的【用户需求】,生成对应的查询SQL语句,只需要返回SQL语句
mermaid格式ER图: {erDiagram}
用户需求: {demand}
""").params(Map.of("demand", demand, "erDiagram", erDiagram)))
.call()
.entity(String.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@ public class MySqlTableMetaInfoQuery {
private SqlExecutor sqlExecutor;

public List<TableColumnInfo> getTableColumnMetaInfo(String schema, DataSource dataSource) {
String sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ?";
final List<Map<String, Object>> maps = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema);
return TableInfoConverter.convert2TableColumnInfo(maps);
String columnSql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ?";
final List<Map<String, Object>> columnDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, columnSql, schema);

String tableSql = "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?";
final List<Map<String, Object>> tableDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, tableSql, schema);
return TableInfoConverter.convert2TableColumnInfo(columnDataList, tableDataList);
}

public TableColumnInfo getTableColumnMetaInfo(String schema, String table, DataSource dataSource) {
String sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?";
final List<Map<String, Object>> maps = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema, table);
final List<TableColumnInfo> tableColumnInfos = TableInfoConverter.convert2TableColumnInfo(maps);
final List<Map<String, Object>> columnDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, sql, schema, table);

String tableSql = "SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?";
final List<Map<String, Object>> tableDataList = sqlExecutor.sqlQueryWithoutLimit(dataSource, tableSql, schema, table);
final List<TableColumnInfo> tableColumnInfos = TableInfoConverter.convert2TableColumnInfo(columnDataList, tableDataList);

Validate.notEmpty(tableColumnInfos, "tableColumnInfos is empty");
Validate.isTrue(tableColumnInfos.size() == 1, "tableColumnInfos size is not 1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.github.zavier.table.relation.service.domain.ColumnInfo;
import com.github.zavier.table.relation.service.domain.TableColumnInfo;
import com.github.zavier.table.relation.service.domain.TableInfo;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -11,42 +10,34 @@

public class TableInfoConverter {

public static List<TableColumnInfo> convert2TableColumnInfo(List<Map<String, Object>> dataMapList) {
if (dataMapList.isEmpty()) {
public static List<TableColumnInfo> convert2TableColumnInfo(List<Map<String, Object>> columnDataList,
List<Map<String, Object>> tableDataList) {
if (columnDataList.isEmpty() || tableDataList.isEmpty()) {
return List.of();
}

List<TableColumnInfo> result = new ArrayList<>();

final Map<String, List<Map<String, Object>>> tableColumnsMap = dataMapList.stream()
final Map<String, List<Map<String, Object>>> tableColumnsMap = columnDataList.stream()
.collect(Collectors.groupingBy((dataMap) -> dataMap.get("TABLE_SCHEMA") + (String) dataMap.get("TABLE_NAME")));

final Map<String, String> tableKeyCommentMap = tableDataList.stream()
.collect(Collectors.toMap((dataMap) -> dataMap.get("TABLE_SCHEMA") + (String) dataMap.get("TABLE_NAME"), (dataMap) -> (String) (dataMap.get("TABLE_COMMENT")),
(v1, v2) -> v1));

tableColumnsMap.forEach((key, columnInfos) -> {
final List<ColumnInfo> columnInfoList = columnInfos.stream().map(TableInfoConverter::convert2ColumnInfo).toList();
final Map<String, Object> column = columnInfos.get(0);
final TableColumnInfo tableColumnInfo = new TableColumnInfo(
(String) column.get("TABLE_SCHEMA"),
(String) column.get("TABLE_NAME"),
"", // 暂未设置
tableKeyCommentMap.getOrDefault(key, ""),
columnInfoList
);
result.add(tableColumnInfo);
});
return result;
}

public static List<TableInfo> convert2TableBaseInfo(List<Map<String, Object>> dataMapList) {
if (dataMapList.isEmpty()) {
return List.of();
}
return dataMapList.stream().map(dataMap -> {
return new TableInfo(
(String) dataMap.get("TABLE_SCHEMA"),
(String) dataMap.get("TABLE_NAME"),
(String) dataMap.get("TABLE_COMMENT")
);
}).toList();
}

public static ColumnInfo convert2ColumnInfo(Map<String, Object> dataMap) {
return new ColumnInfo(
(String) dataMap.get("COLUMN_NAME"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.github.zavier.table.relation.service.dto;

public class ExecuteSqlDto {
private String sql;
private String schema;

public String getSql() {
return sql;
}

public void setSql(String sql) {
this.sql = sql;
}

public String getSchema() {
return schema;
}

public void setSchema(String schema) {
this.schema = schema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ public Map<String, Map<String, List<Map<String, Object>>>> query(QueryCondition
return queryByBfs(queryCondition);
}

public List<Map<String, Object>> queryBySql(String schema, String sql) {
final Optional<DataSource> sourceOptional = dataSourceRegistry.getDataSource(schema);
if (sourceOptional.isEmpty()) {
throw new RuntimeException("dataSource not found:" + schema);
}
final DataSource dataSource = sourceOptional.get();
return sqlExecutor.sqlQueryWithoutLimit(dataSource, sql);
}

private Map<String, Map<String, List<Map<String, Object>>>> queryByBfs(QueryCondition queryCondition) {
Queue<QueryCondition> queue = new LinkedList<>();
queue.add(queryCondition);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.github.zavier.table.relation.service.DataQueryService;
import com.github.zavier.table.relation.service.RelationManagerService;
import com.github.zavier.table.relation.service.SqlGenerateService;
import com.github.zavier.table.relation.service.dto.ExecuteSqlDto;
import com.github.zavier.table.relation.service.dto.QueryCondition;
import com.github.zavier.table.relation.service.dto.Result;
import com.github.zavier.table.relation.service.dto.TableData;
Expand All @@ -11,6 +13,7 @@
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.Map;

@RestController
@RequestMapping("/api/table")
Expand All @@ -21,6 +24,8 @@ public class DataController {
private DataQueryService dataQueryService;
@Resource
private RelationManagerService relationManagerService;
@Resource
private SqlGenerateService sqlGenerateService;

@GetMapping("/allSchema")
public Result<List<String>> getAllSchema() {
Expand All @@ -37,6 +42,17 @@ public Result<List<String>> getTableColumns(@RequestParam("schema") String schem
return dataQueryService.getTableColumns(schema, tableName);
}

@GetMapping("/generateSql")
public Result<String> generateSql(@RequestParam("schema") String schema, @RequestParam("demand") String demand) {
return Result.success(sqlGenerateService.generateSql(schema, demand));
}

// executeSql
@PostMapping("/executeSql")
public Result<List<Map<String, Object>>> executeSql(@RequestBody ExecuteSqlDto executeSqlDto) {
return dataQueryService.executeSql(executeSqlDto.getSchema(), executeSqlDto.getSql());
}

@PostMapping("/sqlQuery")
public Result<TableData> queryRelaData(@RequestBody QueryCondition queryCondition) {
return dataQueryService.queryTableData(queryCondition);
Expand Down
17 changes: 16 additions & 1 deletion src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,19 @@ spring.sql.init.mode=always


# logic no-delete condition
logic.no.delete.condition=
logic.no.delete.condition=


# deepseek
spring.ai.openai.api-key=INSERT-DEEPSEEK-API-KEY-HERE
spring.ai.openai.base-url=https://api.deepseek.com
spring.ai.openai.chat.options.model=deepseek-chat
spring.ai.openai.chat.options.temperature=1
# The DeepSeek API doesn't support embeddings, so we need to disable it.
spring.ai.openai.embedding.enabled=false

# qianfan
#spring.ai.qianfan.base-url=
spring.ai.qianfan.api-key=
spring.ai.qianfan.secret-key=
#spring.ai.qianfan.chat.options.model=ernie_speed
2 changes: 2 additions & 0 deletions src/main/resources/logback-spring.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

<!-- <logger name="org.springframework.jdbc.core.JdbcTemplate" level="DEBUG"/>-->

<logger name="org.springframework.ai.chat.client.advisor" level="DEBUG"/>

<root level="INFO">
<appender-ref ref="FILE"/>
<appender-ref ref="CONSOLE"/>
Expand Down