From e670dee18d717f5e2cc01a61764bce36a40e8ca3 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 22 Apr 2026 17:55:20 +0800 Subject: [PATCH 01/22] feat: add sqltool localrelations --- .../datasource/DatasourceExplorerService.java | 176 ++- .../DatasourceExplorerToolProvider.java | 9 +- .../ai/dataagent/aop/ExceptionAdvice.java | 34 - .../ai/dataagent/bo/schema/ResultBO.java | 33 - .../dataagent/controller/EchoController.java | 38 - .../cloud/ai/dataagent/util/DatabaseUtil.java | 60 - .../cloud/ai/dataagent/util/DateTimeUtil.java | 1118 ----------------- 7 files changed, 156 insertions(+), 1312 deletions(-) delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DatabaseUtil.java delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DateTimeUtil.java diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index b207f0dac..07f8013a0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -17,6 +17,7 @@ import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ForeignKeyInfoBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ResultSetBO; import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; @@ -101,7 +102,8 @@ private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceE .stream() .sorted(String.CASE_INSENSITIVE_ORDER) .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables())) + context.explicitSelectedTables(), + context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()))) .limit(limit) .toList(); return baseResult(context, DatasourceExplorerAction.LIST_TABLES, tables.size() + " tables available") @@ -118,7 +120,8 @@ private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceE List> matchedTables = context.visibleTables() .stream() .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables())) + context.explicitSelectedTables(), + context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()))) .filter(table -> query.isEmpty() || containsQuery(table, query)) .limit(limit) .toList(); @@ -141,10 +144,12 @@ private DatasourceExplorerResult getTableSchema(ExplorerContext context, Datasou .setTable(tableName)); Document tableDocument = loadTableDocumentMap(context, List.of(tableName)).get(normalizeTableName(tableName)); List> columnEntries = columns.stream().map(this::toColumnEntry).toList(); - List> relationEntries = filterRelations(context.logicalRelations(), tableName).stream() + List relations = filterRelations(context, tableName); + List> relationEntries = relations.stream() .map(this::toRelationEntry) .toList(); - Map tableEntry = toTableEntry(tableName, tableDocument, context.explicitSelectedTables()); + Map tableEntry = toTableEntry(tableName, tableDocument, context.explicitSelectedTables(), + relations); return baseResult(context, DatasourceExplorerAction.GET_TABLE_SCHEMA, "Loaded schema for table '%s'".formatted(tableName)) .tables(List.of(tableEntry)) @@ -157,18 +162,18 @@ private DatasourceExplorerResult getTableSchema(ExplorerContext context, Datasou private DatasourceExplorerResult getRelatedTables(ExplorerContext context, DatasourceExplorerRequest request) { String tableName = requireSingleTableName(request); assertVisibleTable(context, tableName); - List relations = filterRelations(context.logicalRelations(), tableName); + List relations = filterRelations(context, tableName); List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); Set relatedTables = relations.stream() - .flatMap(relation -> Arrays - .stream(new String[] { relation.getSourceTableName(), relation.getTargetTableName() })) + .flatMap(relation -> Arrays.stream(new String[] { relation.sourceTable(), relation.targetTable() })) .filter(candidate -> !normalizeTableName(candidate).equals(normalizeTableName(tableName))) .filter(candidate -> context.visibleTableNameSet().contains(normalizeTableName(candidate))) .collect(Collectors.toCollection(LinkedHashSet::new)); Map tableDocumentMap = loadTableDocumentMap(context, new ArrayList<>(relatedTables)); List> tableEntries = relatedTables.stream() .map(relatedTable -> toTableEntry(relatedTable, tableDocumentMap.get(normalizeTableName(relatedTable)), - context.explicitSelectedTables())) + context.explicitSelectedTables(), + context.relationsByTable().getOrDefault(normalizeTableName(relatedTable), List.of()))) .toList(); return baseResult(context, DatasourceExplorerAction.GET_RELATED_TABLES, "Found %d related tables for '%s'".formatted(tableEntries.size(), tableName)) @@ -233,9 +238,23 @@ private ExplorerContext resolveContext(String agentId) throws Exception { .map(this::normalizeTableName) .collect(Collectors.toCollection(LinkedHashSet::new)); List logicalRelations = datasourceService.getLogicalRelations(datasource.getId()); + List physicalRelations = loadPhysicalRelations(accessor, dbConfig); + List unifiedRelations = buildUnifiedRelations(visibleTableNameSet, physicalRelations, + logicalRelations == null ? List.of() : logicalRelations); return new ExplorerContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), Set.copyOf(visibleTableNameSet), List.copyOf(explicitSelectedTables), - logicalRelations == null ? List.of() : List.copyOf(logicalRelations)); + List.copyOf(unifiedRelations), indexRelationsByTable(unifiedRelations)); + } + + private List loadPhysicalRelations(Accessor accessor, DbConfigBO dbConfig) { + try { + List foreignKeys = accessor.showForeignKeys(dbConfig, + DbQueryParameter.from(dbConfig).setSchema(dbConfig.getSchema())); + return foreignKeys == null ? List.of() : foreignKeys; + } + catch (Exception ex) { + return List.of(); + } } private Long parseAgentId(String agentId) { @@ -381,16 +400,22 @@ private Map loadTableDocumentMap(ExplorerContext context, List } private Map toTableEntry(String tableName, Document tableDocument, - List explicitSelectedTables) { + List explicitSelectedTables, List relations) { Map tableEntry = new LinkedHashMap<>(); tableEntry.put("name", tableName); tableEntry.put("selected", explicitSelectedTables.isEmpty() || explicitSelectedTables.stream() .anyMatch(candidate -> normalizeTableName(candidate).equals(normalizeTableName(tableName)))); + String unifiedForeignKeys = summarizeRelations(relations); if (tableDocument != null) { tableEntry.put("schema", tableDocument.getMetadata().getOrDefault("schema", "")); tableEntry.put("description", tableDocument.getMetadata().getOrDefault("description", "")); tableEntry.put("primaryKeys", tableDocument.getMetadata().getOrDefault("primaryKey", List.of())); - tableEntry.put("foreignKeys", tableDocument.getMetadata().getOrDefault("foreignKey", "")); + tableEntry.put("foreignKeys", + StringUtils.defaultIfBlank(unifiedForeignKeys, + String.valueOf(tableDocument.getMetadata().getOrDefault("foreignKey", "")))); + } + else if (StringUtils.isNotBlank(unifiedForeignKeys)) { + tableEntry.put("foreignKeys", unifiedForeignKeys); } return tableEntry; } @@ -418,26 +443,119 @@ private List parseSamples(String samples) { } } - private List filterRelations(List logicalRelations, String tableName) { - String normalizedTableName = normalizeTableName(tableName); - return logicalRelations.stream() - .filter(relation -> normalizedTableName.equals(normalizeTableName(relation.getSourceTableName())) - || normalizedTableName.equals(normalizeTableName(relation.getTargetTableName()))) - .sorted(Comparator.comparing(relation -> StringUtils.defaultString(relation.getSourceTableName()))) - .toList(); + private List filterRelations(ExplorerContext context, String tableName) { + return context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()); } - private Map toRelationEntry(LogicalRelation relation) { + private Map toRelationEntry(UnifiedRelation relation) { Map relationEntry = new LinkedHashMap<>(); - relationEntry.put("sourceTable", relation.getSourceTableName()); - relationEntry.put("sourceColumn", relation.getSourceColumnName()); - relationEntry.put("targetTable", relation.getTargetTableName()); - relationEntry.put("targetColumn", relation.getTargetColumnName()); - relationEntry.put("relationType", relation.getRelationType()); - relationEntry.put("description", relation.getDescription()); + relationEntry.put("sourceTable", relation.sourceTable()); + relationEntry.put("sourceColumn", relation.sourceColumn()); + relationEntry.put("targetTable", relation.targetTable()); + relationEntry.put("targetColumn", relation.targetColumn()); + relationEntry.put("relationType", relation.relationType()); + relationEntry.put("description", relation.description()); + relationEntry.put("sourceType", relation.sourceType()); + relationEntry.put("virtual", relation.virtual()); + relationEntry.put("declaredInDatabase", relation.declaredInDatabase()); return relationEntry; } + private List buildUnifiedRelations(Set visibleTableNameSet, + List physicalRelations, List logicalRelations) { + Map relationMap = new LinkedHashMap<>(); + for (ForeignKeyInfoBO physicalRelation : physicalRelations) { + UnifiedRelation relation = toUnifiedRelation(physicalRelation); + if (isVisibleRelation(visibleTableNameSet, relation)) { + mergeRelation(relationMap, relation); + } + } + for (LogicalRelation logicalRelation : logicalRelations) { + UnifiedRelation relation = toUnifiedRelation(logicalRelation); + if (isVisibleRelation(visibleTableNameSet, relation)) { + mergeRelation(relationMap, relation); + } + } + return relationMap.values() + .stream() + .sorted(Comparator.comparing((UnifiedRelation relation) -> normalizeTableName(relation.sourceTable())) + .thenComparing(relation -> StringUtils.defaultString(relation.sourceColumn())) + .thenComparing(relation -> normalizeTableName(relation.targetTable())) + .thenComparing(relation -> StringUtils.defaultString(relation.targetColumn()))) + .toList(); + } + + private UnifiedRelation toUnifiedRelation(ForeignKeyInfoBO relation) { + return new UnifiedRelation(relation.getTable(), relation.getColumn(), relation.getReferencedTable(), + relation.getReferencedColumn(), StringUtils.EMPTY, StringUtils.EMPTY, "physical", false, true); + } + + private UnifiedRelation toUnifiedRelation(LogicalRelation relation) { + return new UnifiedRelation(relation.getSourceTableName(), relation.getSourceColumnName(), + relation.getTargetTableName(), relation.getTargetColumnName(), + StringUtils.defaultString(relation.getRelationType()), StringUtils.defaultString(relation.getDescription()), + "logical", true, false); + } + + private boolean isVisibleRelation(Set visibleTableNameSet, UnifiedRelation relation) { + return visibleTableNameSet.contains(normalizeTableName(relation.sourceTable())) + && visibleTableNameSet.contains(normalizeTableName(relation.targetTable())); + } + + private void mergeRelation(Map relationMap, UnifiedRelation incoming) { + String relationKey = buildRelationKey(incoming); + UnifiedRelation existing = relationMap.get(relationKey); + if (existing == null) { + relationMap.put(relationKey, incoming); + return; + } + if (existing.declaredInDatabase() && !incoming.declaredInDatabase()) { + relationMap.put(relationKey, mergeRelation(existing, incoming)); + return; + } + if (!existing.declaredInDatabase() && incoming.declaredInDatabase()) { + relationMap.put(relationKey, mergeRelation(incoming, existing)); + return; + } + relationMap.put(relationKey, mergeRelation(existing, incoming)); + } + + private UnifiedRelation mergeRelation(UnifiedRelation preferred, UnifiedRelation supplement) { + return new UnifiedRelation(preferred.sourceTable(), preferred.sourceColumn(), preferred.targetTable(), + preferred.targetColumn(), StringUtils.firstNonBlank(preferred.relationType(), supplement.relationType()), + StringUtils.firstNonBlank(preferred.description(), supplement.description()), preferred.sourceType(), + preferred.virtual(), preferred.declaredInDatabase()); + } + + private String buildRelationKey(UnifiedRelation relation) { + return normalizeTableName(relation.sourceTable()) + "|" + StringUtils.defaultString(relation.sourceColumn()) + + "|" + normalizeTableName(relation.targetTable()) + "|" + + StringUtils.defaultString(relation.targetColumn()); + } + + private Map> indexRelationsByTable(List relations) { + Map> relationIndex = new LinkedHashMap<>(); + for (UnifiedRelation relation : relations) { + relationIndex.computeIfAbsent(normalizeTableName(relation.sourceTable()), key -> new ArrayList<>()) + .add(relation); + String targetKey = normalizeTableName(relation.targetTable()); + if (!targetKey.equals(normalizeTableName(relation.sourceTable()))) { + relationIndex.computeIfAbsent(targetKey, key -> new ArrayList<>()).add(relation); + } + } + Map> immutableIndex = new LinkedHashMap<>(); + relationIndex.forEach((tableName, tableRelations) -> immutableIndex.put(tableName, List.copyOf(tableRelations))); + return Map.copyOf(immutableIndex); + } + + private String summarizeRelations(List relations) { + return relations.stream() + .map(relation -> relation.sourceTable() + "." + relation.sourceColumn() + "=" + relation.targetTable() + + "." + relation.targetColumn()) + .distinct() + .collect(Collectors.joining("、")); + } + private List> toColumnHeaders(ResultSetBO resultSet) { return resultSet.getColumn().stream().map(column -> Map.of("name", column)).toList(); } @@ -483,7 +601,13 @@ private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(Expl private record ExplorerContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, Accessor accessor, List visibleTables, Set visibleTableNameSet, - List explicitSelectedTables, List logicalRelations) { + List explicitSelectedTables, List unifiedRelations, + Map> relationsByTable) { + } + + private record UnifiedRelation(String sourceTable, String sourceColumn, String targetTable, String targetColumn, + String relationType, String description, String sourceType, boolean virtual, + boolean declaredInDatabase) { } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index c150727ad..a470267da 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -133,12 +133,15 @@ private String buildDescription(Datasource datasource, AgentDatasource agentData .formatted(selectedTables.size(), String.join(", ", selectedTables.stream().limit(8).toList())); return """ Unified explorer for datasource '%s' (%s). - Use this tool to inspect tables, inspect schema, preview rows, and execute readonly SQL search. + Use this tool to inspect tables, inspect schema, inspect unified table relations, preview rows, and execute readonly SQL search. Constraints: 1. Only the current agent's active datasource is visible. 2. Only readonly SQL is allowed for SEARCH. - 3. Recommended call order: LIST_TABLES -> GET_TABLE_SCHEMA -> PREVIEW_ROWS -> SEARCH. - 4. %s + 3. GET_TABLE_SCHEMA and GET_RELATED_TABLES return a unified relations field that combines physical foreign keys discovered from the database and configured logical relations. + 4. Treat the unified relations field as the primary source for table-to-table relationship reasoning and join planning. + 5. The foreignKeys field inside table metadata is kept only for compatibility; prefer relations for agent reasoning. + 6. Recommended call order: LIST_TABLES -> GET_TABLE_SCHEMA -> GET_RELATED_TABLES -> PREVIEW_ROWS -> SEARCH. + 7. %s """.formatted(datasource.getName(), datasource.getType(), visibleTables); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java deleted file mode 100644 index 13823979a..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.aop; - -import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.ExceptionHandler; -import org.springframework.web.bind.annotation.RestControllerAdvice; - -@Slf4j -@RestControllerAdvice -public class ExceptionAdvice { - - @ExceptionHandler(Exception.class) - public ResponseEntity handleException(Exception e) { - log.error("An error occurred: ", e); - return ResponseEntity.internalServerError().body(ApiResponse.error("An error occurred: " + e.getMessage())); - } - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java deleted file mode 100644 index 5137496b5..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.bo.schema; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class ResultBO { - - private ResultSetBO resultSet; - - private DisplayStyleBO displayStyle; - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java deleted file mode 100644 index b5e3258fb..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.controller; - -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -/** - * @author yingzi - * @since 2025/9/16 - */ -@RestController -@RequestMapping("/echo") -public class EchoController { - - /** - * 心跳检测 - */ - @GetMapping("ok") - public String ok() { - return "ok"; - } - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DatabaseUtil.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DatabaseUtil.java deleted file mode 100644 index 98481ea36..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DatabaseUtil.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.util; - -import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; -import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; -import com.alibaba.cloud.ai.dataagent.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; -import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; -import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Component; - -/** - * Utility class for processing database. - */ -@Slf4j -@Component -@AllArgsConstructor -public class DatabaseUtil { - - private final AccessorFactory accessorFactory; - - private final AgentDatasourceService agentDatasourceService; - - private final DatasourceService datasourceService; - - public DbConfigBO getAgentDbConfig(Long agentId) { - log.info("Getting datasource config for agent: {}", agentId); - - // Get the enabled data source for the agent - AgentDatasource activeDatasource = agentDatasourceService.getCurrentAgentDatasource(agentId); - // Convert to DbConfig - DbConfigBO dbConfig = datasourceService.getDbConfig(activeDatasource.getDatasource()); - log.info("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId, dbConfig.getUrl(), - dbConfig.getSchema(), dbConfig.getDialectType()); - - return dbConfig; - } - - public Accessor getAgentAccessor(Long agentId) { - DbConfigBO dbConfig = getAgentDbConfig(agentId); - return accessorFactory.getAccessorByDbConfig(dbConfig); - } - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DateTimeUtil.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DateTimeUtil.java deleted file mode 100644 index d269b3d9e..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DateTimeUtil.java +++ /dev/null @@ -1,1118 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.util; - -import org.apache.commons.lang3.StringUtils; - -import java.time.DayOfWeek; -import java.time.LocalDate; -import java.time.YearMonth; -import java.time.format.DateTimeFormatter; -import java.time.temporal.IsoFields; -import java.time.temporal.TemporalAdjusters; -import java.util.ArrayList; -import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -public class DateTimeUtil { - - public static final Pattern SPECIFIC_YEAR_MONTH_DAY_PATTERN = Pattern.compile("\\d{4}年\\d{2}月\\d{2}日"); - - public static final Pattern GENERAL_YEAR_MONTH_DAY_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)(\\d{2}月\\d{2}日)"); - - public static final Pattern GENERAL_MONTH_DAY_PATTERN = Pattern.compile("(本月|上月|上上月|下月)(\\d{2}日)"); - - public static final Pattern GENERAL_DAY_PATTERN = Pattern.compile("(今天|昨天|前天|明天|后天|上月今天|上上月今天)"); - - public static final Pattern WEEK_DAY_PATTERN = Pattern.compile("本周第(\\d)天"); - - public static final Pattern GENERAL_MONTH_LAST_DAY_PATTERN = Pattern.compile("(本月|上月)最后一天"); - - public static final Pattern GENERAL_YEAR_MONTH_LAST_DAY_PATTERN = Pattern.compile("(今年)(\\d{2})月最后一天"); - - public static final Pattern GENERAL_WEEK_SPECIFIC_DAY_PATTERN = Pattern.compile("(本周|上周|上上周|下周|下下周)星期(\\d)"); - - public static final Pattern SPECIFIC_YEAR_MONTH_PATTERN = Pattern.compile("\\d{4}年\\d{2}月"); - - public static final Pattern GENERAL_YEAR_MONTH_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)(\\d{2}月)"); - - public static final Pattern GENERAL_MONTH_PATTERN = Pattern.compile("(本月|上月|上上月|下月|去年本月)"); - - public static final Pattern SPECIFIC_YEAR_PATTERN = Pattern.compile("(\\d{4})年"); - - public static final Pattern GENERAL_YEAR_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)"); - - public static final Pattern SPECIFIC_YEAR_QUARTER_PATTERN = Pattern.compile("\\d{4}年第\\d季度"); - - public static final Pattern GENERAL_YEAR_QUARTER_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)(第\\d季度)"); - - public static final Pattern GENERAL_QUARTER_PATTERN = Pattern.compile("(本季度|上季度|下季度|去年本季度)"); - - public static final Pattern GENERAL_WEEK_PATTERN = Pattern.compile("(本周|上周|上上周|下周|下下周)"); - - public static final Pattern SPECIFIC_YEAR_WEEK_PATTERN = Pattern.compile("(\\d{4})年第(\\d{2})周"); - - public static final Pattern SPECIFIC_YEAR_MONTH_WEEK_PATTERN = Pattern.compile("(\\d{4})年(\\d{2})月第(\\d)周"); - - public static final Pattern GENERAL_YEAR_WEEK_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)第(\\d{2})周"); - - public static final Pattern GENERAL_MONTH_WEEK_PATTERN = Pattern.compile("(本月|上月)第(\\d)周"); - - public static final Pattern GENERAL_YEAR_MONTH_WEEK_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)(\\d{2})月第(\\d)周"); - - public static final Pattern SPECIFIC_YEAR_MONTH_LAST_WEEK_PATTERN = Pattern.compile("(\\d{4})年(\\d{2})月最后一周"); - - public static final Pattern GENERAL_MONTH_LAST_WEEK_PATTERN = Pattern.compile("(本月|上月|上上月)最后一周"); - - public static final Pattern SPECIFIC_YEAR_MONTH_COMPLETE_WEEK_PATTERN = Pattern - .compile("(\\d{4})年(\\d{2})月第(\\d)个完整周"); - - public static final Pattern GENERAL_YEAR_COMPLETE_WEEK_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)第(\\d{2})个完整周"); - - public static final Pattern SPECIFIC_YEAR_COMPLETE_WEEK_PATTERN = Pattern.compile("(\\d{4})年第(\\d{2})个完整周"); - - public static final Pattern GENERAL_YEAR_MONTH_COMPLETE_WEEK_PATTERN = Pattern - .compile("(今年|去年|前年|明年|后年)(\\d{2})月第(\\d)个完整周"); - - public static final Pattern GENERAL_MONTH_COMPLETE_WEEK_PATTERN = Pattern.compile("(本月|上月)第(\\d)个完整周"); - - public static final Pattern GENERAL_MONTH_LAST_COMPLETE_WEEK_PATTERN = Pattern.compile("(本月|上月|上上月)最后一个完整周"); - - public static final Pattern RECENT_N_YEAR_PATTERN = Pattern.compile("近(\\d+)年"); - - public static final Pattern RECENT_N_MONTH_PATTERN = Pattern.compile("近(\\d+)个月"); - - public static final Pattern RECENT_N_WEEK_PATTERN = Pattern.compile("近(\\d+)周"); - - public static final Pattern RECENT_N_DAY_PATTERN = Pattern.compile("近(\\d+)天"); - - public static final Pattern RECENT_N_COMPLETE_YEAR_PATTERN = Pattern.compile("近(\\d+)个完整年"); - - public static final Pattern RECENT_N_COMPLETE_QUARTER_PATTERN = Pattern.compile("近(\\d+)个完整季度"); - - public static final Pattern RECENT_N_COMPLETE_MONTH_PATTERN = Pattern.compile("近(\\d+)个完整月"); - - public static final Pattern RECENT_N_COMPLETE_WEEK_PATTERN = Pattern.compile("近(\\d+)个完整周"); - - public static final Pattern RECENT_N_DAY_WITHOUT_TODAY_PATTERN = Pattern.compile("不包含今天的近(\\d+)天"); - - public static final Pattern RECENT_N_QUARTER_WITH_CURRENT_PATTERN = Pattern.compile("包含当前季度的近(\\d+)个季度"); - - public static final Pattern SPECIFIC_YEAR_HALF_YEAR_PATTERN = Pattern.compile("(\\d{4})年(上|下)半年"); - - public static final Pattern GENERAL_YEAR_HALF_YEAR_PATTERN = Pattern.compile("(今年|去年|前年|明年|后年)(上|下)半年"); - - public static final Pattern HALF_YEAR_PATTERN = Pattern.compile("(上|下)半年"); - - public static String buildDateTimeComment(List expressions) { - LocalDate now = LocalDate.now(); - // Get year, month, day - int year = now.getYear(); - int month = now.getMonthValue(); - int day = now.getDayOfMonth(); - - // Get current year's quarter - int quarter = now.get(IsoFields.QUARTER_OF_YEAR); - - String todayComment = String.format("今天是%d年%02d月%02d日,是%d年的第%d季度", year, month, day, year, quarter); - - List dateTimeCommentList = buildDateExpressions(expressions, now); - - StringBuilder finalExpression = new StringBuilder(); - finalExpression.append(todayComment).append("\n"); - finalExpression.append("需要计算的时间是:\n"); - dateTimeCommentList.forEach(comment -> finalExpression.append(comment).append("\n")); - return finalExpression.toString(); - } - - public static List buildDateExpressions(List expressions, LocalDate now) { - List dateTimeCommentList = new ArrayList<>(); - for (String expression : expressions) { - Matcher specificYearMonthDayMatcher = SPECIFIC_YEAR_MONTH_DAY_PATTERN.matcher(expression); - if (specificYearMonthDayMatcher.matches()) { - dateTimeCommentList.add(expression + "=" + expression); - continue; - } - - Matcher generalYearMonthDayMatcher = GENERAL_YEAR_MONTH_DAY_PATTERN.matcher(expression); - if (generalYearMonthDayMatcher.matches()) { - String yearEx = generalYearMonthDayMatcher.group(1); - String comment = getYearEx(now, yearEx, false) + generalYearMonthDayMatcher.group(2); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthDayMatcher = GENERAL_MONTH_DAY_PATTERN.matcher(expression); - if (generalMonthDayMatcher.matches()) { - String monthEx = generalMonthDayMatcher.group(1); - String comment = getMonthEx(now, monthEx) + generalMonthDayMatcher.group(2); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher yearMonthLastDayMatcher = GENERAL_YEAR_MONTH_LAST_DAY_PATTERN.matcher(expression); - if (yearMonthLastDayMatcher.matches()) { - String yearEx = yearMonthLastDayMatcher.group(1); - String monthEx = yearMonthLastDayMatcher.group(2); - String comment = getGeneralYearMonthLastDayEx(now, yearEx, Integer.valueOf(monthEx)); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher monthLastDayMatcher = GENERAL_MONTH_LAST_DAY_PATTERN.matcher(expression); - if (monthLastDayMatcher.matches()) { - String monthEx = monthLastDayMatcher.group(1); - String comment = getMonthLastDayEx(now, monthEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher weekDayMatcher = WEEK_DAY_PATTERN.matcher(expression); - if (weekDayMatcher.matches()) { - int weekDay = Integer.parseInt(weekDayMatcher.group(1)); - String comment = getWeekDayEx(now, weekDay); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalWeekDayMatcher = GENERAL_WEEK_SPECIFIC_DAY_PATTERN.matcher(expression); - if (generalWeekDayMatcher.matches()) { - String weekEx = generalWeekDayMatcher.group(1); - int day = Integer.parseInt(generalWeekDayMatcher.group(2)); - String comment = getGeneralWeekDayEx(now, weekEx, day); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearQuarterMatcher = SPECIFIC_YEAR_QUARTER_PATTERN.matcher(expression); - if (specificYearQuarterMatcher.matches()) { - dateTimeCommentList.add(expression + "=" + expression); - continue; - } - - Matcher generalYearQuarterMatcher = GENERAL_YEAR_QUARTER_PATTERN.matcher(expression); - if (generalYearQuarterMatcher.matches()) { - String yearEx = generalYearQuarterMatcher.group(1); - String quarterEx = generalYearQuarterMatcher.group(2); - String comment = getYearEx(now, yearEx, false) + quarterEx; - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalQuarterMatcher = GENERAL_QUARTER_PATTERN.matcher(expression); - if (generalQuarterMatcher.matches()) { - String quarterEx = generalQuarterMatcher.group(1); - String comment = getQuarterEx(now, quarterEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalWeekMatcher = GENERAL_WEEK_PATTERN.matcher(expression); - if (generalWeekMatcher.matches()) { - String weekEx = generalWeekMatcher.group(1); - String comment = getWeekEx(now, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearWeekMatcher = SPECIFIC_YEAR_WEEK_PATTERN.matcher(expression); - if (specificYearWeekMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearWeekMatcher.group(1)); - int weekEx = Integer.parseInt(specificYearWeekMatcher.group(2)); - String comment = getSpecificYearWeekEx(now, yearEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearWeekMatcher = GENERAL_YEAR_WEEK_PATTERN.matcher(expression); - if (generalYearWeekMatcher.matches()) { - String yearEx = generalYearWeekMatcher.group(1); - int weekEx = Integer.parseInt(generalYearWeekMatcher.group(2)); - String comment = getGeneralYearWeekEx(now, yearEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthWeekMatcher = GENERAL_MONTH_WEEK_PATTERN.matcher(expression); - if (generalMonthWeekMatcher.matches()) { - String monthEx = generalMonthWeekMatcher.group(1); - int weekEx = Integer.parseInt(generalMonthWeekMatcher.group(2)); - String comment = getGeneralMonthWeekEx(now, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearMonthLastWeekMatcher = SPECIFIC_YEAR_MONTH_LAST_WEEK_PATTERN.matcher(expression); - if (specificYearMonthLastWeekMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearMonthLastWeekMatcher.group(1)); - int monthEx = Integer.parseInt(specificYearMonthLastWeekMatcher.group(2)); - String comment = getSpecificYearMonthLastWeek(now, yearEx, monthEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthLastWeekMatcher = GENERAL_MONTH_LAST_WEEK_PATTERN.matcher(expression); - if (generalMonthLastWeekMatcher.matches()) { - String monthEx = generalMonthLastWeekMatcher.group(1); - String comment = getGeneralMonthLastWeek(now, monthEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthLastCompleteWeekMatcher = GENERAL_MONTH_LAST_COMPLETE_WEEK_PATTERN.matcher(expression); - if (generalMonthLastCompleteWeekMatcher.matches()) { - String monthEx = generalMonthLastCompleteWeekMatcher.group(1); - String comment = getGeneralMonthLastCompleteWeekEx(now, monthEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNYearMatcher = RECENT_N_YEAR_PATTERN.matcher(expression); - if (recentNYearMatcher.matches()) { - int n = Integer.parseInt(recentNYearMatcher.group(1)); - String comment = getRecentNYear(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNMonthMatcher = RECENT_N_MONTH_PATTERN.matcher(expression); - if (recentNMonthMatcher.matches()) { - int n = Integer.parseInt(recentNMonthMatcher.group(1)); - String comment = getRecentNMonth(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNWeekMatcher = RECENT_N_WEEK_PATTERN.matcher(expression); - if (recentNWeekMatcher.matches()) { - int n = Integer.parseInt(recentNWeekMatcher.group(1)); - String comment = getRecentNWeek(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNDayWithoutTodayMatcher = RECENT_N_DAY_WITHOUT_TODAY_PATTERN.matcher(expression); - if (recentNDayWithoutTodayMatcher.matches()) { - int n = Integer.parseInt(recentNDayWithoutTodayMatcher.group(1)); - String comment = getRecentNDayWithoutToday(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNDayMatcher = RECENT_N_DAY_PATTERN.matcher(expression); - if (recentNDayMatcher.matches()) { - int n = Integer.parseInt(recentNDayMatcher.group(1)); - String comment = getRecentNDay(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNCompleteYearMatcher = RECENT_N_COMPLETE_YEAR_PATTERN.matcher(expression); - if (recentNCompleteYearMatcher.matches()) { - int n = Integer.parseInt(recentNCompleteYearMatcher.group(1)); - String comment = getRecentNCompleteYear(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNCompleteQuarterMatcher = RECENT_N_COMPLETE_QUARTER_PATTERN.matcher(expression); - if (recentNCompleteQuarterMatcher.matches()) { - int n = Integer.parseInt(recentNCompleteQuarterMatcher.group(1)); - String comment = getRecentNCompleteQuarter(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNCompleteMonthMatcher = RECENT_N_COMPLETE_MONTH_PATTERN.matcher(expression); - if (recentNCompleteMonthMatcher.matches()) { - int n = Integer.parseInt(recentNCompleteMonthMatcher.group(1)); - String comment = getRecentNCompleteMonth(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNCompleteWeekMatcher = RECENT_N_COMPLETE_WEEK_PATTERN.matcher(expression); - if (recentNCompleteWeekMatcher.matches()) { - int n = Integer.parseInt(recentNCompleteWeekMatcher.group(1)); - String comment = getRecentNCompleteWeek(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher recentNQuarterWithCurrentMatcher = RECENT_N_QUARTER_WITH_CURRENT_PATTERN.matcher(expression); - if (recentNQuarterWithCurrentMatcher.matches()) { - int n = Integer.parseInt(recentNQuarterWithCurrentMatcher.group(1)); - String comment = getRecentNQuarterWithCurrent(now, n); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearMonthMatcher = SPECIFIC_YEAR_MONTH_PATTERN.matcher(expression); - if (specificYearMonthMatcher.matches()) { - dateTimeCommentList.add(expression + "=" + expression); - continue; - } - - Matcher generalYearMonthMatcher = GENERAL_YEAR_MONTH_PATTERN.matcher(expression); - if (generalYearMonthMatcher.matches()) { - String yearEx = generalYearMonthMatcher.group(1); - String comment = getYearEx(now, yearEx, false) + generalYearMonthMatcher.group(2); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalDayMatcher = GENERAL_DAY_PATTERN.matcher(expression); - if (generalDayMatcher.matches()) { - String dayEx = generalDayMatcher.group(1); - String comment = getDayEx(now, dayEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthMatcher = GENERAL_MONTH_PATTERN.matcher(expression); - if (generalMonthMatcher.matches()) { - String monthEx = generalMonthMatcher.group(1); - String comment = getMonthEx(now, monthEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearMatcher = SPECIFIC_YEAR_PATTERN.matcher(expression); - if (specificYearMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearMatcher.group(1)); - String comment = String.valueOf(yearEx) + "年"; - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearMatcher = GENERAL_YEAR_PATTERN.matcher(expression); - if (generalYearMatcher.matches()) { - String yearEx = generalYearMatcher.group(1); - String comment = getYearEx(now, yearEx, true); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearMonthWeekMatcher = SPECIFIC_YEAR_MONTH_WEEK_PATTERN.matcher(expression); - if (specificYearMonthWeekMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearMonthWeekMatcher.group(1)); - int monthEx = Integer.parseInt(specificYearMonthWeekMatcher.group(2)); - int weekEx = Integer.parseInt(specificYearMonthWeekMatcher.group(3)); - String comment = getSpecificYearMonthWeekEx(now, yearEx, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearMonthWeekMatcher = GENERAL_YEAR_MONTH_WEEK_PATTERN.matcher(expression); - if (generalYearMonthWeekMatcher.matches()) { - String yearEx = generalYearMonthWeekMatcher.group(1); - int monthEx = Integer.parseInt(generalYearMonthWeekMatcher.group(2)); - int weekEx = Integer.parseInt(generalYearMonthWeekMatcher.group(3)); - String comment = getGeneralYearMonthWeekEx(now, yearEx, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearMonthCompleteWeekMatcher = SPECIFIC_YEAR_MONTH_COMPLETE_WEEK_PATTERN - .matcher(expression); - if (specificYearMonthCompleteWeekMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearMonthCompleteWeekMatcher.group(1)); - int monthEx = Integer.parseInt(specificYearMonthCompleteWeekMatcher.group(2)); - int weekEx = Integer.parseInt(specificYearMonthCompleteWeekMatcher.group(3)); - String comment = getSpecificYearMonthCompleteWeekEx(now, yearEx, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearMonthCompleteWeekMatcher = GENERAL_YEAR_MONTH_COMPLETE_WEEK_PATTERN.matcher(expression); - if (generalYearMonthCompleteWeekMatcher.matches()) { - String yearEx = generalYearMonthCompleteWeekMatcher.group(1); - int monthEx = Integer.parseInt(generalYearMonthCompleteWeekMatcher.group(2)); - int weekEx = Integer.parseInt(generalYearMonthCompleteWeekMatcher.group(3)); - String comment = getGeneralYearMonthCompleteWeekEx(now, yearEx, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalMonthCompleteWeekMatcher = GENERAL_MONTH_COMPLETE_WEEK_PATTERN.matcher(expression); - if (generalMonthCompleteWeekMatcher.matches()) { - String monthEx = generalMonthCompleteWeekMatcher.group(1); - int weekEx = Integer.parseInt(generalMonthCompleteWeekMatcher.group(2)); - String comment = getGeneralMonthCompleteWeekEx(now, monthEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearCompleteWeekMatcher = SPECIFIC_YEAR_COMPLETE_WEEK_PATTERN.matcher(expression); - if (specificYearCompleteWeekMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearCompleteWeekMatcher.group(1)); - int weekEx = Integer.parseInt(specificYearCompleteWeekMatcher.group(2)); - String comment = getSpecificYearCompleteWeekEx(now, yearEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearCompleteWeekMatcher = GENERAL_YEAR_COMPLETE_WEEK_PATTERN.matcher(expression); - if (generalYearCompleteWeekMatcher.matches()) { - String yearEx = generalYearCompleteWeekMatcher.group(1); - int weekEx = Integer.parseInt(generalYearCompleteWeekMatcher.group(2)); - String comment = getGeneralYearCompleteWeekEx(now, yearEx, weekEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher specificYearHalfYearMatcher = SPECIFIC_YEAR_HALF_YEAR_PATTERN.matcher(expression); - if (specificYearHalfYearMatcher.matches()) { - int yearEx = Integer.parseInt(specificYearHalfYearMatcher.group(1)); - String halfYearEx = specificYearHalfYearMatcher.group(2); - String comment = getSpecificYearHalfYearEx(now, yearEx, halfYearEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher generalYearHalfYearMatcher = GENERAL_YEAR_HALF_YEAR_PATTERN.matcher(expression); - if (generalYearHalfYearMatcher.matches()) { - String yearEx = generalYearHalfYearMatcher.group(1); - String halfYearEx = generalYearHalfYearMatcher.group(2); - String comment = getGeneralYearHalfYearEx(now, yearEx, halfYearEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - Matcher halfYearMatcher = HALF_YEAR_PATTERN.matcher(expression); - if (halfYearMatcher.matches()) { - String halfYearEx = halfYearMatcher.group(1); - String comment = getSpecificYearHalfYearEx(now, now.getYear(), halfYearEx); - dateTimeCommentList.add(expression + "=" + comment); - continue; - } - - } - - return dateTimeCommentList; - } - - public static String getYearEx(LocalDate now, String yearEx, boolean applyDomainLogic) { - String comment = ""; - int year = 0; - if (yearEx.equals("今年")) { - year = now.getYear(); - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - - comment = String.valueOf(year) + "年"; - - return comment; - } - - public static String getMonthEx(LocalDate now, String monthEx) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月"); - String comment = ""; - if (monthEx.equals("本月")) { - comment = formatter.format(YearMonth.from(now)); - } - else if (monthEx.equals("上月")) { - comment = formatter.format(YearMonth.from(now).minusMonths(1)); - } - else if (monthEx.equals("上上月")) { - comment = formatter.format(YearMonth.from(now).minusMonths(2)); - } - else if (monthEx.equals("下月")) { - comment = formatter.format(YearMonth.from(now).plusMonths(1)); - } - else if (monthEx.equals("去年本月")) { - comment = formatter.format(YearMonth.from(now).minusYears(1)); - } - return comment; - } - - public static String getDayEx(LocalDate now, String dayEx) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - String comment = ""; - try { - if (dayEx.equals("今天")) { - comment = formatter.format(now); - } - else if (dayEx.equals("昨天")) { - comment = formatter.format(now.minusDays(1)); - } - else if (dayEx.equals("前天")) { - comment = formatter.format(now.minusDays(2)); - } - else if (dayEx.equals("明天")) { - comment = formatter.format(now.plusDays(1)); - } - else if (dayEx.equals("后天")) { - comment = formatter.format(now.plusDays(2)); - } - else if (dayEx.equals("上月今天")) { - comment = formatter.format(YearMonth.from(now).minusMonths(1).atDay(now.getDayOfMonth())); - } - else if (dayEx.equals("上上月今天")) { - comment = formatter.format(YearMonth.from(now).minusMonths(2).atDay(now.getDayOfMonth())); - } - } - catch (Exception e) { - e.printStackTrace(); - } - return comment; - } - - public static final String getWeekDayEx(LocalDate now, int x) { - - // Calculate date of first day of week (Monday) - LocalDate monday = now.with(TemporalAdjusters.previousOrSame(DayOfWeek.MONDAY)); - // Get date of xth day of week by adding (x - 1) days - LocalDate desiredDay = monday.plusDays(x - 1); - - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(desiredDay); - } - - public static final String getGeneralWeekDayEx(LocalDate now, String weekEx, int day) { - LocalDate thisMonday = now.with(TemporalAdjusters.previousOrSame(DayOfWeek.MONDAY)); - LocalDate desiredDay = thisMonday.plusDays(day - 1); - if (weekEx.equals("本周")) { - - } - else if (weekEx.equals("上周")) { - desiredDay = desiredDay.minusWeeks(1); - } - else if (weekEx.equals("上上周")) { - desiredDay = desiredDay.minusWeeks(2); - } - else if (weekEx.equals("下周")) { - desiredDay = desiredDay.plusWeeks(1); - } - else if (weekEx.equals("下下周")) { - desiredDay = desiredDay.plusWeeks(2); - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(desiredDay); - } - - public static final String getWeekEx(LocalDate now, String weekEx) { - LocalDate desireMonday = now.with(TemporalAdjusters.previousOrSame(DayOfWeek.MONDAY)); - LocalDate desireSunday = now.with(TemporalAdjusters.nextOrSame(DayOfWeek.SUNDAY)); - if (weekEx.equals("本周")) { - - } - else if (weekEx.equals("上周")) { - desireMonday = desireMonday.minusWeeks(1); - desireSunday = desireSunday.minusWeeks(1); - } - else if (weekEx.equals("上上周")) { - desireMonday = desireMonday.minusWeeks(2); - desireSunday = desireSunday.minusWeeks(2); - } - else if (weekEx.equals("下周")) { - desireMonday = desireMonday.plusWeeks(1); - desireSunday = desireSunday.plusWeeks(1); - } - else if (weekEx.equals("下下周")) { - desireMonday = desireMonday.plusWeeks(2); - desireSunday = desireSunday.plusWeeks(2); - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(desireMonday) + "至" + formatter.format(desireSunday); - } - - public static String getSpecificYearWeekEx(LocalDate now, int year, int week) { - LocalDate firstDayOfYear = LocalDate.of(year, 1, 1); - LocalDate targetWeekFirstDay = firstDayOfYear.plusWeeks(week - 1); - LocalDate targetWeekLastDay = targetWeekFirstDay.plusDays(6); - LocalDate lastDayOfYear = firstDayOfYear.with(TemporalAdjusters.lastDayOfYear()); - if (lastDayOfYear.isBefore(targetWeekLastDay)) { - targetWeekLastDay = lastDayOfYear; - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(targetWeekFirstDay) + "至" + formatter.format(targetWeekLastDay); - } - - public static String getGeneralYearWeekEx(LocalDate now, String yearEx, int week) { - int year = now.getYear(); - if (yearEx.equals("今年")) { - - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - return getSpecificYearWeekEx(now, year, week); - } - - public static String getSpecificYearMonthWeekEx(LocalDate now, int year, int month, int week) { - LocalDate firstDayOfMonth = LocalDate.of(year, month, 1); - LocalDate targetWeekFirstDay = firstDayOfMonth.plusWeeks(week - 1); - LocalDate targetWeekLastDay = targetWeekFirstDay.plusDays(6); - LocalDate lastDayOfMonth = firstDayOfMonth.with(TemporalAdjusters.lastDayOfMonth()); - if (lastDayOfMonth.isBefore(targetWeekLastDay)) { - targetWeekLastDay = lastDayOfMonth; - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(targetWeekFirstDay) + "至" + formatter.format(targetWeekLastDay); - } - - public static String getGeneralYearMonthWeekEx(LocalDate now, String yearEx, int month, int week) { - int year = now.getYear(); - if (yearEx.equals("今年")) { - - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - return getSpecificYearMonthWeekEx(now, year, month, week); - } - - public static String getGeneralMonthWeekEx(LocalDate now, String monthEx, int week) { - int year = now.getYear(); - int month = now.getMonthValue(); - if (monthEx.equals("本月")) { - - } - else if (monthEx.equals("上月")) { - month = now.getMonthValue() - 1; - if (month <= 0) { - year--; - month = 12 + month; - } - } - LocalDate firstDayOfMonth = LocalDate.of(year, month, 1); - LocalDate targetWeekFirstDay = firstDayOfMonth.plusWeeks(week - 1); - LocalDate targetWeekLastDay = targetWeekFirstDay.plusDays(6); - LocalDate lastDayOfMonth = firstDayOfMonth.with(TemporalAdjusters.lastDayOfMonth()); - if (lastDayOfMonth.isBefore(targetWeekLastDay)) { - targetWeekLastDay = lastDayOfMonth; - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(targetWeekFirstDay) + "至" + formatter.format(targetWeekLastDay); - } - - public static String getSpecificYearMonthCompleteWeekEx(LocalDate now, int year, int month, int week) { - LocalDate firstDayOfMonth = LocalDate.of(year, month, 1); - LocalDate firstMonday = firstDayOfMonth.with(TemporalAdjusters.firstInMonth(DayOfWeek.MONDAY)); - LocalDate targetStartDate = firstMonday.plusWeeks(week - 1); - LocalDate targetEndDate = targetStartDate.plusDays(6); - LocalDate lastDayOfMonth = firstDayOfMonth.with(TemporalAdjusters.lastDayOfMonth()); - if (lastDayOfMonth.isBefore(targetEndDate)) { - return StringUtils.EMPTY; - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(targetStartDate) + "至" + formatter.format(targetEndDate); - } - - public static String getGeneralYearMonthCompleteWeekEx(LocalDate now, String yearEx, int month, int week) { - int year = now.getYear(); - if (yearEx.equals("今年")) { - - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - return getSpecificYearMonthCompleteWeekEx(now, year, month, week); - } - - public static String getGeneralMonthCompleteWeekEx(LocalDate now, String monthEx, int week) { - int year = now.getYear(); - int month = now.getMonthValue(); - if (monthEx.equals("本月")) { - - } - else if (monthEx.equals("上月")) { - month = now.getMonthValue() - 1; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("上上月")) { - month = now.getMonthValue() - 2; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("下月")) { - month = now.getMonthValue() + 1; - if (month > 12) { - year++; - month = month - 12; - } - } - return getSpecificYearMonthCompleteWeekEx(now, year, month, week); - } - - public static String getSpecificYearCompleteWeekEx(LocalDate now, int year, int week) { - LocalDate firstDayOfYear = LocalDate.of(year, 1, 1); - LocalDate firstMonday = firstDayOfYear.with(TemporalAdjusters.firstInMonth(DayOfWeek.MONDAY)); - LocalDate targetStartDate = firstMonday.plusWeeks(week - 1); - LocalDate targetEndDate = targetStartDate.plusDays(6); - LocalDate lastDayOfYear = firstDayOfYear.with(TemporalAdjusters.lastDayOfYear()); - if (lastDayOfYear.isBefore(targetEndDate)) { - return StringUtils.EMPTY; - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(targetStartDate) + "至" + formatter.format(targetEndDate); - } - - public static String getGeneralYearCompleteWeekEx(LocalDate now, String yearEx, int week) { - int year = now.getYear(); - if (yearEx.equals("今年")) { - - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - return getSpecificYearCompleteWeekEx(now, year, week); - } - - public static String getSpecificYearMonthLastWeek(LocalDate now, int year, int month) { - LocalDate firstDayOfMonth = LocalDate.of(year, month, 1); - LocalDate lastDayOfMonth = firstDayOfMonth.with(TemporalAdjusters.lastDayOfMonth()); - LocalDate previousMonday = lastDayOfMonth.with(TemporalAdjusters.previousOrSame(DayOfWeek.MONDAY)); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(previousMonday) + "至" + formatter.format(lastDayOfMonth); - } - - public static String getGeneralMonthLastWeek(LocalDate now, String monthEx) { - int year = now.getYear(); - int month = now.getMonthValue(); - if (monthEx.equals("本月")) { - - } - else if (monthEx.equals("上月")) { - month = now.getMonthValue() - 1; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("上上月")) { - month = now.getMonthValue() - 2; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("下月")) { - month = now.getMonthValue() + 1; - if (month > 12) { - year++; - month = month - 12; - } - } - return getSpecificYearMonthLastWeek(now, year, month); - } - - public static String getSpecificYearMonthLastCompleteWeekEx(LocalDate now, int year, int month) { - LocalDate firstDayOfMonth = LocalDate.of(year, month, 1); - LocalDate lastSunday = firstDayOfMonth.with(TemporalAdjusters.lastInMonth(DayOfWeek.SUNDAY)); - LocalDate lastMonday = lastSunday.minusDays(6); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(lastMonday) + "至" + formatter.format(lastSunday); - } - - public static String getGeneralMonthLastCompleteWeekEx(LocalDate now, String monthEx) { - int year = now.getYear(); - int month = now.getMonthValue(); - if (monthEx.equals("本月")) { - - } - else if (monthEx.equals("上月")) { - month = now.getMonthValue() - 1; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("上上月")) { - month = now.getMonthValue() - 2; - if (month <= 0) { - year--; - month = 12 + month; - } - } - else if (monthEx.equals("下月")) { - month = now.getMonthValue() + 1; - if (month > 12) { - year++; - month = month - 12; - } - } - return getSpecificYearMonthLastCompleteWeekEx(now, year, month); - } - - public static String getQuarterEx(LocalDate now, String quarterEx) { - int currentQuarter = now.get(IsoFields.QUARTER_OF_YEAR); - // Calculate previous and next quarters - int lastQuarter = currentQuarter == 1 ? 4 : currentQuarter - 1; - int nextQuarter = currentQuarter == 4 ? 1 : currentQuarter + 1; - int currentYear = now.getYear(); - int yearOfLastQuarter = (currentQuarter == 1) ? currentYear - 1 : currentYear; - int yearOfNextQuarter = (currentQuarter == 4) ? currentYear + 1 : currentYear; - int yearOfSameQuarterLastYear = currentYear - 1; - - String comment = ""; - if (quarterEx.equals("本季度")) { - comment = currentYear + "年第" + currentQuarter + "季度"; - } - else if (quarterEx.equals("上季度")) { - comment = yearOfLastQuarter + "年第" + lastQuarter + "季度"; - } - else if (quarterEx.equals("下季度")) { - comment = yearOfNextQuarter + "年第" + nextQuarter + "季度"; - } - else if (quarterEx.equals("去年本季度")) { - comment = yearOfSameQuarterLastYear + "年第" + currentQuarter + "季度"; - } - return comment; - } - - public static String getRecentNYear(LocalDate now, int n) { - LocalDate startDate = now.minusYears(n); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(now); - } - - public static String getRecentNMonth(LocalDate now, int n) { - LocalDate startDate = now.minusMonths(n); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(now); - } - - public static String getRecentNWeek(LocalDate now, int n) { - LocalDate startDate = now.minusWeeks(n); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(now); - } - - public static String getRecentNDay(LocalDate now, int n) { - LocalDate startDate = now.minusDays(n); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(now); - } - - public static String getRecentNCompleteYear(LocalDate now, int n) { - LocalDate endDate; - if (now.getMonthValue() == 12 && now.getDayOfMonth() == 31) { - endDate = now; - } - else { - endDate = now.with(TemporalAdjusters.lastDayOfYear()).minusYears(1); - } - LocalDate startDate = endDate.minusYears(n).plusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getRecentNCompleteMonth(LocalDate now, int n) { - LocalDate endDate; - if (now.equals(now.with(TemporalAdjusters.lastDayOfMonth()))) { - endDate = now; - } - else { - endDate = now.with(TemporalAdjusters.firstDayOfMonth()) - .minusMonths(1) - .with(TemporalAdjusters.lastDayOfMonth()); - } - LocalDate startDate = endDate.minusMonths(n).plusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getRecentNCompleteQuarter(LocalDate now, int n) { - LocalDate endDate; - int currentMonth = now.getMonthValue(); - if (currentMonth % 4 == 0 && now.getDayOfMonth() == 31) { - endDate = now; - } - else { - if (currentMonth < 4) { - endDate = LocalDate.of(now.getYear() - 1, 12, 31); - } - else if (currentMonth < 7) { - endDate = LocalDate.of(now.getYear(), 3, 31); - } - else if (currentMonth < 10) { - endDate = LocalDate.of(now.getYear(), 6, 30); - } - else { - endDate = LocalDate.of(now.getYear(), 9, 30); - } - } - LocalDate startDate = endDate.minusMonths(n * 3).plusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getRecentNCompleteWeek(LocalDate now, int n) { - LocalDate endDate; - if (now.getDayOfWeek().getValue() == 7) { - endDate = now; - } - else { - endDate = now.minusWeeks(1).with(TemporalAdjusters.nextOrSame(DayOfWeek.SUNDAY)); - } - LocalDate startDate = endDate.minusWeeks(n).plusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getRecentNDayWithoutToday(LocalDate now, int n) { - LocalDate startDate = now.minusDays(n); - LocalDate endDate = now.minusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getRecentNQuarterWithCurrent(LocalDate now, int n) { - LocalDate endDate; - int currentMonth = now.getMonthValue(); - if (currentMonth < 4) { - endDate = LocalDate.of(now.getYear(), 3, 31); - } - else if (currentMonth < 7) { - endDate = LocalDate.of(now.getYear(), 6, 30); - } - else if (currentMonth < 10) { - endDate = LocalDate.of(now.getYear(), 9, 30); - } - else { - endDate = LocalDate.of(now.getYear(), 12, 31); - } - LocalDate startDate = endDate.minusMonths(n * 3).plusDays(1); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getMonthLastDayEx(LocalDate now, String monthEx) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - String comment = ""; - if (monthEx.equals("本月")) { - comment = formatter.format(YearMonth.from(now).atEndOfMonth()); - } - else if (monthEx.equals("上月")) { - comment = formatter.format(YearMonth.from(now).minusMonths(1).atEndOfMonth()); - } - return comment; - } - - public static String getGeneralYearMonthLastDayEx(LocalDate now, String yearEx, int month) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - String comment = ""; - int year = 0; - if (yearEx.equals("今年")) { - year = now.getYear(); - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - comment = formatter.format(YearMonth.of(year, month).atEndOfMonth()); - return comment; - } - - public static String getSpecificYearHalfYearEx(LocalDate now, int year, String halfYearEx) { - LocalDate startDate; - LocalDate endDate; - if (halfYearEx.equals("上")) { - startDate = LocalDate.of(year, 1, 1); - endDate = LocalDate.of(year, 6, 30); - } - else { - startDate = LocalDate.of(year, 7, 1); - endDate = LocalDate.of(year, 12, 31); - } - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy年MM月dd日"); - return formatter.format(startDate) + "至" + formatter.format(endDate); - } - - public static String getGeneralYearHalfYearEx(LocalDate now, String yearEx, String halfYearEx) { - int year = 0; - if (yearEx.equals("今年")) { - year = now.getYear(); - } - else if (yearEx.equals("去年")) { - year = now.getYear() - 1; - } - else if (yearEx.equals("前年")) { - year = now.getYear() - 2; - } - else if (yearEx.equals("明年")) { - year = now.getYear() + 1; - } - else if (yearEx.equals("后年")) { - year = now.getYear() + 2; - } - return getSpecificYearHalfYearEx(now, year, halfYearEx); - } - -} From 00249e35c49a1e5812ce6a93b2803c12a8c1263d Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 22 Apr 2026 22:09:49 +0800 Subject: [PATCH 02/22] feat: add sql check tool --- .../tool/sqlguard/SqlGuardCheckRequest.java | 34 + .../tool/sqlguard/SqlGuardCheckResult.java | 57 ++ .../tool/sqlguard/SqlGuardProblem.java | 43 ++ .../tool/sqlguard/SqlGuardRuleCheck.java | 35 + .../tool/sqlguard/SqlGuardToolProvider.java | 127 ++++ .../sqlguard/SqlVerifyExplainService.java | 711 ++++++++++++++++++ .../src/main/resources/prompts/commonagent.md | 11 + 7 files changed, 1018 insertions(+) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java new file mode 100644 index 000000000..078c4f8c5 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.fasterxml.jackson.databind.JsonNode; +import lombok.Data; + +@Data +class SqlGuardCheckRequest { + + private String query; + + private String sql; + + private JsonNode tableSchemas; + + private JsonNode semanticHits; + + private JsonNode businessKnowledgeHits; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java new file mode 100644 index 000000000..36a530b05 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java @@ -0,0 +1,57 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Builder; +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +@Data +@Builder +@JsonInclude(JsonInclude.Include.NON_NULL) +class SqlGuardCheckResult { + + private String query; + + private String sql; + + private String summary; + + private String explainedIntent; + + @JsonProperty("isAligned") + private boolean isAligned; + + @Builder.Default + private List problems = new ArrayList<>(); + + @Builder.Default + private List fixSuggestions = new ArrayList<>(); + + @Builder.Default + private List usedTables = new ArrayList<>(); + + @Builder.Default + private List usedMetrics = new ArrayList<>(); + + @Builder.Default + private List ruleChecks = new ArrayList<>(); + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java new file mode 100644 index 000000000..64dfa5b6b --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java @@ -0,0 +1,43 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +class SqlGuardProblem { + + private String code; + + private String title; + + private String severity; + + private String message; + + private String why; + + private String expected; + + private String actual; + + private String evidence; + + private String repairHint; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java new file mode 100644 index 000000000..e625aba7c --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +class SqlGuardRuleCheck { + + private String code; + + private String title; + + private String status; + + private String detail; + + private String evidence; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java new file mode 100644 index 000000000..0c51e029d --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -0,0 +1,127 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Map; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Component +public class SqlGuardToolProvider implements AgentScopedToolProvider { + + private static final String TOOL_NAME = "sql_guard.check"; + + private static final String INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "必填。用户原始问题。" + }, + "sql": { + "type": "string", + "description": "必填。当前准备执行或准备返回给用户的候选 SQL。" + }, + "tableSchemas": { + "type": "object", + "description": "可选。把 datasource explorer 的 schema 结果原样传入,帮助识别时间列、维度列与表关系。" + }, + "semanticHits": { + "type": "object", + "description": "可选。把 semantic_model.search 的结果原样传入。" + }, + "businessKnowledgeHits": { + "type": "object", + "description": "可选。把 domain_business_knowledge.search 的结果原样传入。" + } + }, + "required": ["query", "sql"] + } + """; + + private static final String DESCRIPTION = """ + Single SQL verification tool for SQL-backed answers. + Check whether the candidate SQL really matches the user's intent before execution or final answer. + If verification fails, read isAligned=false plus problems, ruleChecks and fixSuggestions, then rewrite SQL yourself and call sql_guard.check again. + Each problem explains why it is wrong, what was expected, what was actually detected and how to repair it. + Always pass a fresh top-level query and sql. Do not pass previous sql_guard.check output back into the tool. + """; + + private final ObjectMapper objectMapper; + + private final SqlVerifyExplainService sqlVerifyExplainService; + + public SqlGuardToolProvider(ObjectMapper objectMapper, SqlVerifyExplainService sqlVerifyExplainService) { + this.objectMapper = objectMapper; + this.sqlVerifyExplainService = sqlVerifyExplainService; + } + + @Override + public Map getToolCallbacks(String agentId) { + ToolDefinition toolDefinition = ToolDefinition.builder() + .name(TOOL_NAME) + .description(DESCRIPTION) + .inputSchema(INPUT_SCHEMA) + .build(); + return Map.of(TOOL_NAME, new SqlGuardToolCallback(toolDefinition, objectMapper, sqlVerifyExplainService)); + } + + private static final class SqlGuardToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + private final ObjectMapper objectMapper; + + private final SqlVerifyExplainService sqlVerifyExplainService; + + private SqlGuardToolCallback(ToolDefinition toolDefinition, ObjectMapper objectMapper, + SqlVerifyExplainService sqlVerifyExplainService) { + this.toolDefinition = toolDefinition; + this.objectMapper = objectMapper; + this.sqlVerifyExplainService = sqlVerifyExplainService; + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + try { + SqlGuardCheckRequest request = StringUtils.hasText(toolInput) + ? objectMapper.readValue(toolInput, SqlGuardCheckRequest.class) : new SqlGuardCheckRequest(); + return objectMapper.writeValueAsString(sqlVerifyExplainService.explain(request)); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to execute sql_guard.check: " + ex.getMessage(), ex); + } + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return call(toolInput); + } + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java new file mode 100644 index 000000000..b21ce7276 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -0,0 +1,711 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.fasterxml.jackson.databind.JsonNode; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.util.TablesNamesFinder; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; + +@Service +public class SqlVerifyExplainService { + + private static final Pattern AGGREGATE_PATTERN = Pattern + .compile("(?i)\\b(count|sum|avg|average|min|max)\\s*\\(([^)]*)\\)\\s*(?:as\\s+([a-zA-Z0-9_]+))?"); + + private static final Pattern GROUP_BY_PATTERN = Pattern.compile("(?is)\\bgroup\\s+by\\b"); + + private static final Pattern ORDER_BY_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b"); + + private static final Pattern LIMIT_PATTERN = Pattern.compile("(?is)\\blimit\\s+\\d+\\b"); + + private static final Pattern TOP_PATTERN = Pattern.compile("(?is)\\bselect\\s+top\\s+\\d+\\b"); + + private static final Pattern FETCH_FIRST_PATTERN = Pattern + .compile("(?is)\\bfetch\\s+first\\s+\\d+\\s+rows\\s+only\\b"); + + private static final Pattern DISTINCT_PATTERN = Pattern.compile("(?is)\\bselect\\s+distinct\\b|count\\s*\\(\\s*distinct\\b"); + + private static final Pattern DATE_LITERAL_PATTERN = Pattern + .compile("\\b\\d{4}[-/]\\d{1,2}[-/]\\d{1,2}\\b|\\b\\d{6,8}\\b"); + + private static final Pattern TIME_FUNCTION_PATTERN = Pattern.compile( + "(?is)\\b(current_date|current_timestamp|now\\s*\\(|curdate\\s*\\(|date\\s*\\(|date_trunc\\s*\\(|strftime\\s*\\(|to_date\\s*\\(|to_timestamp\\s*\\(|datediff\\s*\\(|dateadd\\s*\\(|timestampdiff\\s*\\(|interval\\b)"); + + private static final Pattern WHERE_PATTERN = Pattern.compile("(?is)\\bwhere\\b"); + + private static final Pattern DESC_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b.+?\\bdesc\\b"); + + private static final Pattern ASC_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b.+?\\basc\\b"); + + private static final Pattern TOP_N_QUERY_PATTERN = Pattern + .compile("(?i)(?:\\btop\\s*(\\d+)\\b|前\\s*(\\d+)\\s*(?:个|名|条)?|排名前\\s*(\\d+)\\s*(?:个|名)?)"); + + private static final Pattern SQL_LIMIT_VALUE_PATTERN = Pattern.compile("(?is)\\blimit\\s+(\\d+)\\b"); + + private static final Pattern SQL_TOP_VALUE_PATTERN = Pattern.compile("(?is)\\bselect\\s+top\\s+(\\d+)\\b"); + + private static final Pattern SQL_FETCH_FIRST_VALUE_PATTERN = Pattern + .compile("(?is)\\bfetch\\s+first\\s+(\\d+)\\s+rows\\s+only\\b"); + + public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { + String query = StringUtils.trimToEmpty(request == null ? null : request.getQuery()); + String sql = StringUtils.trimToEmpty(request == null ? null : request.getSql()); + if (StringUtils.isBlank(query)) { + throw new IllegalArgumentException("sql_guard.check 需要 query"); + } + if (StringUtils.isBlank(sql)) { + throw new IllegalArgumentException("sql_guard.check 需要 sql"); + } + + Statement statement; + try { + statement = parseSingleSelectStatement(sql); + } + catch (IllegalArgumentException ex) { + return SqlGuardCheckResult.builder() + .query(query) + .sql(sql) + .isAligned(false) + .summary("SQL 无法通过语法解析,无法继续做结构和意图一致性校验。") + .explainedIntent(buildIntentExplanation(analyzeQueryIntent(query))) + .problems(List.of(SqlGuardProblem.builder() + .code("SQL_PARSE_ERROR") + .title("SQL 语法解析失败") + .severity("high") + .message("SQL 无法解析,当前结果不能视为已校验通过。") + .why("校验器必须先把 SQL 解析成合法的 SELECT / WITH 语法树,才能继续检查聚合、分组、排序和时间窗口。") + .expected("输入应为单条可解析的只读 SELECT / WITH 查询。") + .actual("当前 SQL 在语法层面未通过解析。") + .evidence(ex.getMessage()) + .repairHint("先修复括号、关键字顺序、逗号、别名或多语句拼接问题,再重新调用 sql_guard.check。") + .build())) + .fixSuggestions(List.of("先修复 SQL 语法错误,再重新调用 sql_guard.check。")) + .ruleChecks(List.of(SqlGuardRuleCheck.builder() + .code("SQL_PARSE") + .title("SQL 语法解析") + .status("FAILED") + .detail("当前 SQL 未通过语法解析,后续结构规则无法继续执行。") + .evidence(ex.getMessage()) + .build())) + .build(); + } + + QueryIntent intent = analyzeQueryIntent(query); + SqlShape shape = analyzeSqlShape(statement, sql, request); + List problems = new ArrayList<>(); + Set fixSuggestions = new LinkedHashSet<>(); + List ruleChecks = new ArrayList<>(); + + evaluateAggregationRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateGroupingRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeFilterRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeBucketRule(sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeOrderRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateOrderingRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateLimitRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateDistinctRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateOrderDirectionRule(sql, intent, shape, problems, fixSuggestions, ruleChecks); + + boolean aligned = problems.stream().noneMatch(problem -> isBlockingSeverity(problem.getSeverity())); + String summary = aligned ? "SQL 通过了当前规则版意图一致性校验。" + : "检测到 %d 个可能影响答案正确性的意图一致性问题。".formatted(problems.size()); + if (aligned) { + fixSuggestions.add("当前规则校验通过;如要进一步提高置信度,可继续核对执行结果与最终答案解释。"); + } + return SqlGuardCheckResult.builder() + .query(query) + .sql(sql) + .isAligned(aligned) + .summary(summary) + .explainedIntent(buildIntentExplanation(intent)) + .problems(problems) + .fixSuggestions(List.copyOf(fixSuggestions)) + .usedTables(shape.usedTables()) + .usedMetrics(shape.usedMetrics()) + .ruleChecks(ruleChecks) + .build(); + } + + private void evaluateAggregationRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresAggregation()) { + return; + } + if (!shape.hasAggregation()) { + addProblem(problems, fixSuggestions, "MISSING_AGGREGATION", "缺少聚合指标", "high", + "问题看起来要求聚合指标,但 SQL 更像明细查询。", + "用户问题带有数量、金额、总数、平均值等聚合口径,但 SQL 没有检测到 count/sum/avg/min/max 等聚合函数。", + "SELECT 中应包含与题目口径匹配的聚合表达式。", + "当前 SQL 未检测到聚合函数。", + "query=" + query + "; sql=" + sql, + "把 count/sum/avg/min/max 等聚合逻辑补齐到 SELECT 中。"); + recordRuleCheck(ruleChecks, "AGGREGATION_REQUIRED", "聚合指标校验", "FAILED", + "问题要求聚合指标,但 SQL 未检测到聚合函数。", "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "AGGREGATION_REQUIRED", "聚合指标校验", "PASSED", + "问题要求聚合指标,SQL 已检测到聚合函数。", "usedMetrics=" + shape.usedMetrics()); + } + + private void evaluateGroupingRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresGrouping()) { + return; + } + if (!shape.hasGroupBy()) { + addProblem(problems, fixSuggestions, "MISSING_GROUP_BY", "缺少 GROUP BY", "high", + "问题要求按维度拆分,但 SQL 缺少 GROUP BY。", + "用户问题包含按地区、按用户、各品类、每月等拆分意图;没有 GROUP BY 时,要么结果被压成总计,要么数据库直接报错。", + "SQL 应按题目中的维度列做 GROUP BY。", + "当前 SQL 未检测到 GROUP BY。", + "query=" + query + "; sql=" + sql, + "把用户要求的维度列加入 GROUP BY,并检查 SELECT 中的非聚合列。"); + recordRuleCheck(ruleChecks, "GROUP_BY_REQUIRED", "分组维度校验", "FAILED", + "问题要求分维度拆分,但 SQL 未检测到 GROUP BY。", "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "GROUP_BY_REQUIRED", "分组维度校验", "PASSED", + "问题要求分维度拆分,SQL 已检测到 GROUP BY。", "sql=" + sql); + } + + private void evaluateTimeFilterRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTimeFilter()) { + return; + } + if (!shape.hasTimePredicate()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_FILTER", "缺少时间过滤", "high", + "问题包含明确时间窗口,但 SQL 没有可靠的时间过滤信号。", + "题目提到了今天、本月、最近30天、某年某月等时间范围,但 SQL 没有看到明确时间过滤,结果很可能回成全量数据。", + "WHERE 中应包含与题目对应的时间范围约束。", + "当前 SQL 未检测到可靠的时间过滤表达式。", + "query=" + query + "; sql=" + sql, + "在 WHERE 中补齐精确时间范围,不要按默认全量数据查询。"); + recordRuleCheck(ruleChecks, "TIME_FILTER_REQUIRED", "时间窗口校验", "FAILED", + "问题包含明确时间窗口,但 SQL 未检测到可靠时间过滤。", "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_FILTER_REQUIRED", "时间窗口校验", "PASSED", + "问题包含明确时间窗口,SQL 已检测到时间过滤。", "sql=" + sql); + } + + private void evaluateTimeBucketRule(String sql, QueryIntent intent, SqlShape shape, List problems, + Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTrend()) { + return; + } + if (!shape.hasTimeBucket()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_BUCKET", "缺少时间分桶", "high", + "趋势类问题通常需要时间分桶,但 SQL 没看到明确的时间粒度表达。", + "趋势分析需要先按天、周、月、年等粒度汇总;没有时间分桶,返回的往往只是总数,不是趋势。", + "SQL 应包含 DATE/DATE_TRUNC/DATE_FORMAT 等时间分桶表达式,并按同粒度分组。", + "当前 SQL 未检测到明确的时间分桶表达式。", + "sql=" + sql, + "用 DATE/DATE_TRUNC/DATE_FORMAT 等时间分桶表达式,并对同一时间粒度做 GROUP BY。"); + recordRuleCheck(ruleChecks, "TIME_BUCKET_REQUIRED", "趋势时间粒度校验", "FAILED", + "趋势问题需要时间分桶,但 SQL 未检测到时间粒度表达。", "sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_BUCKET_REQUIRED", "趋势时间粒度校验", "PASSED", + "趋势问题所需的时间分桶表达已检测到。", "sql=" + sql); + } + + private void evaluateTimeOrderRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTrend()) { + return; + } + if (!shape.hasOrderBy()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_ORDER", "缺少时间排序", "medium", + "趋势类问题通常需要按时间排序,但 SQL 缺少 ORDER BY。", + "趋势结果如果不按时间排序,输出顺序可能是乱的,后续回答和可视化都容易误导。", + "趋势 SQL 应按时间字段或时间分桶字段排序。", + "当前 SQL 未检测到 ORDER BY。", + "query=" + query + "; sql=" + sql, + "按时间字段或时间分桶字段补齐 ORDER BY。"); + recordRuleCheck(ruleChecks, "TIME_ORDER_REQUIRED", "趋势时间排序校验", "FAILED", + "趋势问题需要时间排序,但 SQL 未检测到 ORDER BY。", "sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_ORDER_REQUIRED", "趋势时间排序校验", "PASSED", + "趋势问题所需的时间排序已检测到。", "sql=" + sql); + } + + private void evaluateOrderingRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresOrdering()) { + return; + } + if (!shape.hasOrderBy()) { + addProblem(problems, fixSuggestions, "MISSING_ORDER_BY", "缺少排序", "medium", + "问题要求排序或排名,但 SQL 缺少 ORDER BY。", + "题目要求最高、最低、TopN、排名等比较关系;没有 ORDER BY 时,即使有限制行数,返回的也不一定是目标对象。", + "SQL 应明确按目标指标排序。", + "当前 SQL 未检测到 ORDER BY。", + "query=" + query + "; sql=" + sql, + "根据问题要求补齐 ORDER BY,并明确升序还是降序。"); + recordRuleCheck(ruleChecks, "ORDER_REQUIRED", "排序要求校验", "FAILED", + "问题包含排序或排名诉求,但 SQL 未检测到 ORDER BY。", "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "ORDER_REQUIRED", "排序要求校验", "PASSED", + "问题包含排序或排名诉求,SQL 已检测到 ORDER BY。", "sql=" + sql); + } + + private void evaluateLimitRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresLimit()) { + return; + } + if (!shape.hasLimit()) { + addProblem(problems, fixSuggestions, "MISSING_LIMIT", "缺少返回行数限制", "medium", + "问题要求 TopN / 前N / 单个最值对象,但 SQL 没有限制返回行数。", + "题目明确只需要前几名或唯一最值对象;如果不限制行数,结果会混入多余记录。", + "SQL 应通过 LIMIT / TOP / FETCH FIRST 控制返回行数。", + "当前 SQL 未检测到返回行数限制。", + "query=" + query + "; sql=" + sql, + "补齐 LIMIT / TOP / FETCH FIRST,避免把全量结果当成 TopN。"); + recordRuleCheck(ruleChecks, "LIMIT_REQUIRED", "TopN 行数限制校验", "FAILED", + "问题要求限制返回行数,但 SQL 未检测到 LIMIT/TOP/FETCH FIRST。", "query=" + query + "; sql=" + sql); + return; + } + if (shape.limitValueKnown() && intent.expectedLimit() != null && !intent.expectedLimit().equals(shape.limitValue())) { + String mismatchCode = shape.limitValue() > intent.expectedLimit() ? "LIMIT_TOO_LARGE" : "LIMIT_TOO_SMALL"; + addProblem(problems, fixSuggestions, mismatchCode, "TopN 数量不匹配", "medium", + "问题要求的返回条数与 SQL 实际限制条数不一致。", + "题目明确给了 TopN 数量或单个最值对象,限制条数不一致会直接改变结果口径。", + "返回条数应与题目要求一致。", + "题目期望 " + intent.expectedLimit() + " 条,但 SQL 当前限制为 " + shape.limitValue() + " 条。", + "query=" + query + "; sql=" + sql, + "把 LIMIT/TOP/FETCH FIRST 改成与题目一致的条数。"); + recordRuleCheck(ruleChecks, "LIMIT_MATCH", "TopN 数量匹配校验", "FAILED", + "题目期望 " + intent.expectedLimit() + " 条,但 SQL 实际限制为 " + shape.limitValue() + " 条。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "LIMIT_MATCH", "TopN 数量匹配校验", "PASSED", + intent.expectedLimit() == null ? "题目要求限制返回行数,SQL 已检测到限制。" + : "题目期望 " + intent.expectedLimit() + " 条,SQL 返回条数限制一致。", + "sql=" + sql); + } + + private void evaluateDistinctRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresDistinct()) { + return; + } + if (!shape.hasDistinct()) { + addProblem(problems, fixSuggestions, "MISSING_DISTINCT", "缺少 DISTINCT 去重", "high", + "问题要求去重口径,但 SQL 没看到 DISTINCT。", + "题目要求独立用户、去重人数、唯一值等口径;不去重会重复计算。", + "SQL 应使用 SELECT DISTINCT 或 COUNT(DISTINCT ...)。", + "当前 SQL 未检测到 DISTINCT。", + "query=" + query + "; sql=" + sql, + "把口径改成 SELECT DISTINCT 或 COUNT(DISTINCT ...)。"); + recordRuleCheck(ruleChecks, "DISTINCT_REQUIRED", "去重口径校验", "FAILED", + "问题要求去重口径,但 SQL 未检测到 DISTINCT。", "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "DISTINCT_REQUIRED", "去重口径校验", "PASSED", + "问题要求去重口径,SQL 已检测到 DISTINCT。", "sql=" + sql); + } + + private void evaluateOrderDirectionRule(String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!shape.hasOrderBy()) { + return; + } + if (intent.prefersDescending() && shape.orderDirectionKnown() && !shape.orderDescending()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_MISMATCH", "排序方向不匹配", "high", + "问题要求最高 / Top / 最多,但 SQL 排序方向不像降序。", + "题目要的是最大值或靠前排名,若排序方向写成 ASC,返回的会是最小值或反向结果。", + "这类问题通常应按目标指标 DESC 排序。", + "当前 SQL 的排序方向与题目诉求不一致。", + "sql=" + sql, + "把排序方向改成 DESC,并确认排序指标是否正确。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION", "排序方向校验", "FAILED", + "题目要求高到低/最多/Top,但 SQL 当前排序方向不是 DESC。", "sql=" + sql); + return; + } + if (intent.prefersAscending() && shape.orderDirectionKnown() && shape.orderDescending()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_MISMATCH", "排序方向不匹配", "high", + "问题要求最低 / 最少 / 最小,但 SQL 排序方向不像升序。", + "题目要的是最小值或最低排名,若排序方向写成 DESC,返回的会是最大值或反向结果。", + "这类问题通常应按目标指标 ASC 排序。", + "当前 SQL 的排序方向与题目诉求不一致。", + "sql=" + sql, + "把排序方向改成 ASC,并确认排序指标是否正确。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION", "排序方向校验", "FAILED", + "题目要求低到高/最少/最小,但 SQL 当前排序方向不是 ASC。", "sql=" + sql); + return; + } + if ((intent.prefersDescending() || intent.prefersAscending()) && !shape.orderDirectionKnown()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_AMBIGUOUS", "排序方向不明确", "medium", + "问题对排序方向有明确诉求,但 SQL 的 ORDER BY 没有写明 ASC / DESC。", + "有些数据库默认升序,但不应该依赖默认行为承载业务口径,否则最值问题很容易答反。", + "ORDER BY 应显式写明 ASC 或 DESC。", + "当前 SQL 虽然有 ORDER BY,但没有检测到明确排序方向。", + "sql=" + sql, + "显式补上 ASC 或 DESC,不要依赖数据库默认排序方向。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION_EXPLICIT", "显式排序方向校验", "FAILED", + "题目存在明确最值方向,但 SQL 的 ORDER BY 未显式写出 ASC/DESC。", "sql=" + sql); + return; + } + if (intent.prefersDescending() || intent.prefersAscending()) { + recordRuleCheck(ruleChecks, "ORDER_DIRECTION_EXPLICIT", "显式排序方向校验", "PASSED", + "ORDER BY 已显式声明排序方向。", "sql=" + sql); + } + } + + private void addProblem(List problems, Set fixSuggestions, String code, String title, + String severity, String message, String why, String expected, String actual, String evidence, + String fixSuggestion) { + problems.add(SqlGuardProblem.builder() + .code(code) + .title(title) + .severity(severity) + .message(message) + .why(why) + .expected(expected) + .actual(actual) + .evidence(evidence) + .repairHint(fixSuggestion) + .build()); + fixSuggestions.add(fixSuggestion); + } + + private void recordRuleCheck(List ruleChecks, String code, String title, String status, + String detail, String evidence) { + ruleChecks.add(SqlGuardRuleCheck.builder() + .code(code) + .title(title) + .status(status) + .detail(detail) + .evidence(evidence) + .build()); + } + + private boolean isBlockingSeverity(String severity) { + return "high".equalsIgnoreCase(severity) || "medium".equalsIgnoreCase(severity); + } + + private QueryIntent analyzeQueryIntent(String query) { + String normalized = query.toLowerCase(Locale.ROOT); + boolean requiresTrend = containsAny(query, "趋势", "走势图", "按天", "按周", "按月", "按年", "daily", "weekly", "monthly", + "trend", "over time", "环比", "同比"); + boolean requiresGrouping = requiresTrend + || containsAny(query, "按", "按照", "每个", "各", "分组", "group by", "维度", "分类", "分城市", "分地区", "分品类", "分渠道"); + boolean requiresTimeFilter = containsAny(query, "今天", "昨日", "昨天", "本周", "上周", "本月", "上月", "本季度", "上季度", "今年", + "去年", "近", "最近", "最近的", "latest", "recent", "last ", "past ", "today", "yesterday", "this month", + "this year", "202", "2024", "2025", "2026", "Q1", "Q2", "Q3", "Q4"); + boolean requiresOrdering = requiresTrend + || containsAny(normalized, "top ", "rank", "ranking", "highest", "lowest", "best", "worst", "most", "least") + || containsAny(query, "排名", "排行", "最高", "最低", "最多", "最少", "前", "后"); + boolean prefersDescending = containsAny(normalized, "top ", "highest", "best", "most", "largest") + || containsAny(query, "最高", "最多", "最大", "前"); + boolean prefersAscending = containsAny(normalized, "lowest", "least", "smallest", "worst") + || containsAny(query, "最低", "最少", "最小"); + Integer expectedLimit = extractExpectedLimit(query, prefersDescending, prefersAscending); + boolean requiresLimit = expectedLimit != null; + boolean requiresDistinct = containsAny(normalized, "distinct", "deduplicate", "unique", "uv") + || containsAny(query, "去重", "独立用户", "唯一"); + boolean requiresAggregation = requiresTrend || requiresDistinct + || containsAny(normalized, "count", "sum", "avg", "average", "total", "amount", "sales", "revenue") + || containsAny(query, "数量", "总数", "总额", "金额", "销量", "销售额", "订单数", "人数", "平均", "占比", "比例", "贡献", "多少"); + return new QueryIntent(requiresAggregation, requiresGrouping, requiresTimeFilter, requiresOrdering, requiresLimit, + requiresDistinct, requiresTrend, prefersDescending, prefersAscending, expectedLimit); + } + + private SqlShape analyzeSqlShape(Statement statement, String sql, SqlGuardCheckRequest request) { + String normalizedSql = StringUtils.trimToEmpty(sql).toLowerCase(Locale.ROOT); + Set knownTimeColumns = extractKnownTimeColumns(request); + List usedTables = extractReferencedTables(statement); + List usedMetrics = extractUsedMetrics(sql); + boolean hasAggregation = AGGREGATE_PATTERN.matcher(sql).find(); + boolean hasGroupBy = GROUP_BY_PATTERN.matcher(normalizedSql).find(); + boolean hasOrderBy = ORDER_BY_PATTERN.matcher(normalizedSql).find(); + boolean hasLimit = LIMIT_PATTERN.matcher(normalizedSql).find() || TOP_PATTERN.matcher(normalizedSql).find() + || FETCH_FIRST_PATTERN.matcher(normalizedSql).find(); + boolean hasDistinct = DISTINCT_PATTERN.matcher(normalizedSql).find(); + boolean hasTimePredicate = detectTimePredicate(normalizedSql, knownTimeColumns); + boolean hasTimeBucket = detectTimeBucket(normalizedSql, knownTimeColumns); + boolean orderDescending = DESC_PATTERN.matcher(normalizedSql).find(); + boolean orderDirectionKnown = DESC_PATTERN.matcher(normalizedSql).find() || ASC_PATTERN.matcher(normalizedSql).find(); + Integer limitValue = extractSqlLimitValue(normalizedSql); + boolean limitValueKnown = limitValue != null; + return new SqlShape(List.copyOf(usedTables), List.copyOf(usedMetrics), hasAggregation, hasGroupBy, hasOrderBy, + hasLimit, hasDistinct, hasTimePredicate, hasTimeBucket, orderDescending, orderDirectionKnown, + limitValueKnown, limitValue); + } + + private Integer extractSqlLimitValue(String normalizedSql) { + Integer limitValue = extractFirstInt(SQL_LIMIT_VALUE_PATTERN, normalizedSql); + if (limitValue != null) { + return limitValue; + } + limitValue = extractFirstInt(SQL_TOP_VALUE_PATTERN, normalizedSql); + if (limitValue != null) { + return limitValue; + } + return extractFirstInt(SQL_FETCH_FIRST_VALUE_PATTERN, normalizedSql); + } + + private boolean detectTimePredicate(String normalizedSql, Set knownTimeColumns) { + if (!WHERE_PATTERN.matcher(normalizedSql).find()) { + return false; + } + if (DATE_LITERAL_PATTERN.matcher(normalizedSql).find() || TIME_FUNCTION_PATTERN.matcher(normalizedSql).find()) { + return true; + } + if (containsAny(normalizedSql, " created_at ", " create_time ", " updated_at ", " update_time ", " order_date ", + " biz_date ", " stat_date ", " date ", " time ", " month ", " year ", " day ")) { + return true; + } + return knownTimeColumns.stream().anyMatch(column -> normalizedSql.contains(column.toLowerCase(Locale.ROOT))); + } + + private boolean detectTimeBucket(String normalizedSql, Set knownTimeColumns) { + if (!GROUP_BY_PATTERN.matcher(normalizedSql).find()) { + return false; + } + if (containsAny(normalizedSql, "date_trunc(", "date(", "strftime(", "to_date(", "extract(", "year(", "month(", "day(")) { + return true; + } + if (containsAny(normalizedSql, " by day", " by month", " by week", " by year")) { + return true; + } + return knownTimeColumns.stream().anyMatch(column -> normalizedSql.contains(column.toLowerCase(Locale.ROOT))); + } + + private Set extractKnownTimeColumns(SqlGuardCheckRequest request) { + Set columns = new LinkedHashSet<>(); + if (request == null) { + return columns; + } + collectTimeColumns(request.getTableSchemas(), columns); + collectTimeColumns(request.getSemanticHits(), columns); + collectTimeColumns(request.getBusinessKnowledgeHits(), columns); + return columns; + } + + private void collectTimeColumns(JsonNode node, Set columns) { + if (node == null || node.isNull()) { + return; + } + if (node.isTextual()) { + String value = node.asText(); + if (isLikelyTimeColumn(value)) { + columns.add(value.toLowerCase(Locale.ROOT)); + } + return; + } + if (node.isArray()) { + for (JsonNode item : node) { + collectTimeColumns(item, columns); + } + return; + } + if (!node.isObject()) { + return; + } + node.fields().forEachRemaining(entry -> { + String fieldName = entry.getKey(); + JsonNode value = entry.getValue(); + if (value != null && value.isTextual() + && ("name".equalsIgnoreCase(fieldName) || "columnName".equalsIgnoreCase(fieldName) + || "fieldName".equalsIgnoreCase(fieldName) || "column".equalsIgnoreCase(fieldName)) + && isLikelyTimeColumn(value.asText())) { + columns.add(value.asText().toLowerCase(Locale.ROOT)); + } + collectTimeColumns(value, columns); + }); + } + + private boolean isLikelyTimeColumn(String value) { + String normalized = StringUtils.trimToEmpty(value).toLowerCase(Locale.ROOT); + return containsAny(normalized, "date", "time", "day", "week", "month", "year", "created", "updated", "dt", + "biz_date", "stat_date", "order_date"); + } + + private List extractReferencedTables(Statement statement) { + try { + return new TablesNamesFinder().getTableList(statement) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + } + catch (Exception ex) { + return List.of(); + } + } + + private List extractUsedMetrics(String sql) { + Set metrics = new LinkedHashSet<>(); + Matcher matcher = AGGREGATE_PATTERN.matcher(StringUtils.defaultString(sql)); + while (matcher.find()) { + String alias = matcher.group(3); + if (StringUtils.isNotBlank(alias)) { + metrics.add(alias.trim()); + continue; + } + String functionName = Objects.toString(matcher.group(1), "").toUpperCase(Locale.ROOT); + String argument = Objects.toString(matcher.group(2), "").trim(); + metrics.add(functionName + "(" + argument + ")"); + } + return List.copyOf(metrics); + } + + private Statement parseSingleSelectStatement(String sql) { + String normalizedSql = stripTrailingSemicolons(sql); + if (normalizedSql.isEmpty()) { + throw new IllegalArgumentException("SQL 不能为空"); + } + try { + List statements = CCJSqlParserUtil.parseStatements(normalizedSql).getStatements(); + if (statements == null || statements.isEmpty()) { + throw new IllegalArgumentException("SQL 不能为空"); + } + if (statements.size() > 1) { + throw new IllegalArgumentException("仅支持单条 SELECT / WITH 查询"); + } + Statement statement = statements.get(0); + if (!(statement instanceof Select)) { + throw new IllegalArgumentException("sql_guard.check 仅校验 SELECT / WITH 查询"); + } + return statement; + } + catch (IllegalArgumentException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalArgumentException("SQL 解析失败,请检查语法后重试", ex); + } + } + + private String stripTrailingSemicolons(String sql) { + String trimmed = StringUtils.trimToEmpty(sql); + while (trimmed.endsWith(";")) { + trimmed = trimmed.substring(0, trimmed.length() - 1).trim(); + } + return trimmed; + } + + private String buildIntentExplanation(QueryIntent intent) { + List fragments = new ArrayList<>(); + if (intent.requiresAggregation()) { + fragments.add("问题包含聚合指标诉求"); + } + if (intent.requiresGrouping()) { + fragments.add("问题包含按维度拆分诉求"); + } + if (intent.requiresTimeFilter()) { + fragments.add("问题包含明确时间窗口"); + } + if (intent.requiresTrend()) { + fragments.add("问题包含趋势或时间序列分析"); + } + if (intent.requiresOrdering()) { + fragments.add("问题包含排序或排名要求"); + } + if (intent.requiresLimit()) { + fragments.add("问题包含 TopN / Top1 行数限制"); + } + if (intent.requiresDistinct()) { + fragments.add("问题包含去重口径"); + } + if (fragments.isEmpty()) { + return "当前规则没有识别到强约束口径,主要执行基础 SQL 结构检查。"; + } + return String.join(";", fragments) + "。"; + } + + private boolean containsAny(String value, String... needles) { + if (value == null) { + return false; + } + for (String needle : needles) { + if (needle != null && value.contains(needle)) { + return true; + } + } + return false; + } + + private Integer extractExpectedLimit(String query, boolean prefersDescending, boolean prefersAscending) { + String safeQuery = StringUtils.defaultString(query); + Matcher matcher = TOP_N_QUERY_PATTERN.matcher(safeQuery); + if (matcher.find()) { + for (int index = 1; index <= matcher.groupCount(); index++) { + String group = matcher.group(index); + if (StringUtils.isNotBlank(group)) { + try { + return Integer.parseInt(group.trim()); + } + catch (NumberFormatException ex) { + return null; + } + } + } + } + boolean asksSingleTarget = containsAny(safeQuery, "哪个", "哪位", "哪一个", "谁"); + if (asksSingleTarget && (prefersDescending || prefersAscending)) { + return 1; + } + return null; + } + + private Integer extractFirstInt(Pattern pattern, String value) { + Matcher matcher = pattern.matcher(StringUtils.defaultString(value)); + if (!matcher.find()) { + return null; + } + String group = matcher.group(1); + if (StringUtils.isBlank(group)) { + return null; + } + try { + return Integer.parseInt(group.trim()); + } + catch (NumberFormatException ex) { + return null; + } + } + + private record QueryIntent(boolean requiresAggregation, boolean requiresGrouping, boolean requiresTimeFilter, + boolean requiresOrdering, boolean requiresLimit, boolean requiresDistinct, boolean requiresTrend, + boolean prefersDescending, boolean prefersAscending, Integer expectedLimit) { + } + + private record SqlShape(List usedTables, List usedMetrics, boolean hasAggregation, boolean hasGroupBy, + boolean hasOrderBy, boolean hasLimit, boolean hasDistinct, boolean hasTimePredicate, boolean hasTimeBucket, + boolean orderDescending, boolean orderDirectionKnown, boolean limitValueKnown, Integer limitValue) { + } + +} diff --git a/data-agent-management/src/main/resources/prompts/commonagent.md b/data-agent-management/src/main/resources/prompts/commonagent.md index 9c04d9c64..df781219e 100644 --- a/data-agent-management/src/main/resources/prompts/commonagent.md +++ b/data-agent-management/src/main/resources/prompts/commonagent.md @@ -7,3 +7,14 @@ 3. 如果问题属于业务定义、指标口径、SOP、FAQ、历史案例、领域术语,而不是表结构本身,就使用 `domain_business_knowledge.search`。 4. 如果用户问的是表名、列名、字段类型、枚举值、表关系、字段关系,不要先调用 `domain_business_knowledge.search`。 先检查 datasource explorer;如果 datasource explorer 还不够,再考虑 `semantic_model.search`。 +5. 如果你已经准备了候选 SQL,且答案将基于 SQL 返回给用户,在执行 SQL 前先调用一次 `sql_guard.check`。 + 必须传顶层字段:`query`、`sql`;可选再传 `tableSchemas`、`semanticHits`、`businessKnowledgeHits`。 +6. `sql_guard.check` 只做结构与意图校验,不负责自动修复、不负责执行报错修复、也不负责结果回看。 + 重点校验:指标是否对题、是否缺少 `GROUP BY`、时间窗口是否完整、排序 / TopN 是否正确、是否遗漏 `DISTINCT`。 +7. 读取 `sql_guard.check` 结果时,直接看顶层字段:`isAligned`、`problems`、`ruleChecks`、`fixSuggestions`、`summary`。 + `problems` 里会给出为什么错、期望是什么、实际检测到什么、建议怎么修;`ruleChecks` 用来解释这次到底检查了哪些规则、每条规则是通过还是失败。 +8. 如果 `sql_guard.check` 返回 `isAligned=false`,必须根据 `problems` 和 `fixSuggestions` 自己改写 SQL,然后把新的候选 SQL 再次传给 `sql_guard.check`。 + 不要把上一次 `sql_guard.check` 的输出对象原样回传给工具;每次都要重新传顶层 `query` 和新的 `sql`。 +9. 只有当 `sql_guard.check` 返回 `isAligned=true` 后,才能执行 datasource explorer 的 `SEARCH`。 +10. 如果 SQL 执行报错,或者执行结果看起来不合理,由 agent 根据数据库报错或结果样例自行分析并重写 SQL,再重新走 `sql_guard.check`。 + 不要调用额外的 SQL 自动修复工具。 From dea0549bd78dd3b02548692e40f323a577640c48 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 22 Apr 2026 22:24:05 +0800 Subject: [PATCH 03/22] fix: stream build bug --- .../impl/AiAgentRuntimeServiceImpl.java | 41 +++++++++++++++++-- .../dataagent/controller/ChatController.java | 16 +++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 43f7e707c..3dbb24896 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -151,7 +151,7 @@ private void emitSuccess(Sinks.Many> sink, Gr } private boolean shouldEmitFinalResponse(String result, StreamTextTracker streamTextTracker) { - return StringUtils.hasText(result) && !streamTextTracker.matchesAnyNodeAccumulation(result); + return StringUtils.hasText(result) && !streamTextTracker.containsFinalAnswer(result); } private void emitError(Sinks.Many> sink, GraphRequest request, Throwable error) { @@ -314,25 +314,58 @@ private static final class StreamTextTracker { private final Map accumulatedByNode = new LinkedHashMap<>(); + private final Map lastTextByNode = new LinkedHashMap<>(); + synchronized void record(String nodeName, String text) { if (!StringUtils.hasText(text)) { return; } - accumulatedByNode.computeIfAbsent(nodeName, key -> new StringBuilder()).append(text); + String normalizedNodeName = StringUtils.hasText(nodeName) ? nodeName : ""; + accumulatedByNode.computeIfAbsent(normalizedNodeName, key -> new StringBuilder()).append(text); + lastTextByNode.put(normalizedNodeName, text); } - synchronized boolean matchesAnyNodeAccumulation(String candidate) { + synchronized boolean containsFinalAnswer(String candidate) { if (!StringUtils.hasText(candidate)) { return false; } + String normalizedCandidate = normalize(candidate); + if (!StringUtils.hasText(normalizedCandidate)) { + return false; + } for (StringBuilder accumulated : accumulatedByNode.values()) { - if (candidate.equals(accumulated.toString())) { + String normalizedAccumulated = normalize(accumulated.toString()); + if (matchesFinalAnswer(normalizedAccumulated, normalizedCandidate)) { + return true; + } + } + for (String lastText : lastTextByNode.values()) { + String normalizedLastText = normalize(lastText); + if (matchesFinalAnswer(normalizedLastText, normalizedCandidate)) { return true; } } return false; } + private boolean matchesFinalAnswer(String existingText, String candidate) { + if (!StringUtils.hasText(existingText) || !StringUtils.hasText(candidate)) { + return false; + } + return existingText.equals(candidate) || existingText.endsWith(candidate) || candidate.endsWith(existingText); + } + + private String normalize(String text) { + if (!StringUtils.hasText(text)) { + return ""; + } + return text.replace("\r\n", "\n") + .replace('\r', '\n') + .replaceAll("[ \\t\\x0B\\f]+", " ") + .replaceAll(" *\\n *", "\n") + .trim(); + } + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java index c02a9800b..82b931800 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java @@ -141,7 +141,7 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" // Update session activity time chatSessionService.updateSessionTime(sessionId); - if (request.isTitleNeeded()) { + if (shouldGenerateTitle(request, savedMessage)) { sessionTitleService.scheduleTitleGeneration(sessionId, message.getContent()); } @@ -233,4 +233,18 @@ public ResponseEntity convertAndDownloadHtml(@PathVariable(value = "sess } } + private boolean shouldGenerateTitle(ChatMessageDTO request, ChatMessage savedMessage) { + if (request == null || savedMessage == null) { + return false; + } + if (request.isTitleNeeded()) { + return true; + } + if (!"user".equalsIgnoreCase(request.getRole())) { + return false; + } + List sessionMessages = chatMessageService.findBySessionId(savedMessage.getSessionId()); + return sessionMessages.size() == 1; + } + } From ea3dd5e2abc01ea8bd7d575debf3cf8d19595631 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Thu, 23 Apr 2026 17:57:53 +0800 Subject: [PATCH 04/22] fix: add trace --- data-agent-frontend/knip.json | 3 +- data-agent-frontend/src/services/chat.ts | 31 ++ data-agent-frontend/src/views/AgentRun.vue | 286 +++++++++++++++- data-agent-management/pom.xml | 88 +---- .../impl/AiAgentRuntimeServiceImpl.java | 99 ++++-- .../datasource/DatasourceExplorerService.java | 6 +- .../AgentScopeTracingConfiguration.java | 77 +++++ .../dataagent/config/OpenTelemetryConfig.java | 83 +++-- .../connector/accessor/AbstractAccessor.java | 20 +- .../dataagent/controller/ChatController.java | 10 + .../observability/SessionTraceStore.java | 305 ++++++++++++++++++ .../AgentScopeObservabilityProperties.java | 40 +++ .../service/langfuse/LangfuseService.java | 4 +- .../src/main/resources/application.yml | 4 + pom.xml | 24 +- 15 files changed, 910 insertions(+), 170 deletions(-) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/properties/AgentScopeObservabilityProperties.java diff --git a/data-agent-frontend/knip.json b/data-agent-frontend/knip.json index afbe789ed..d06c3bc8e 100644 --- a/data-agent-frontend/knip.json +++ b/data-agent-frontend/knip.json @@ -2,6 +2,5 @@ "$schema": "https://unpkg.com/knip@4/schema.json", "entry": ["src/main.js", "src/**/*.vue", "vite.config.js"], "project": ["src/**/*.{js,ts,vue}"], - "ignore": ["**/*.d.ts"], - "ignoreDependencies": ["@vitejs/plugin-vue", "vite"] + "ignore": ["**/*.d.ts"] } diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index 66317b88a..b3bf8c14c 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -39,6 +39,32 @@ export interface ChatMessage { titleNeeded?: boolean; } +export interface TraceSpan { + name: string; + spanId: string; + parentSpanId: string; + kind: string; + status: string; + startEpochMs: number; + endEpochMs: number; + durationMs: number; + attributes: Record; + children: TraceSpan[]; +} + +export interface SessionTrace { + sessionId: string; + traceId: string; + runtimeRequestId: string; + agentId: string; + startEpochMs: number; + endEpochMs: number; + durationMs: number; + spanCount: number; + rootSpan: TraceSpan | null; + rootSpans: TraceSpan[]; +} + const API_BASE_URL = '/api'; class ChatService { @@ -90,6 +116,11 @@ class ChatService { return response.data; } + async getSessionTrace(sessionId: string): Promise { + const response = await axios.get(`${API_BASE_URL}/sessions/${sessionId}/trace`); + return response.data; + } + /** * 保存消息到会话 * @param sessionId 会话ID diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 02d66893d..0a4d8bea2 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -281,6 +281,16 @@
+ + Trace +
+ + +
+
+ Trace ID: {{ sessionTrace.traceId }} + Span 数: {{ sessionTrace.spanCount }} + 耗时: {{ formatTraceDuration(sessionTrace.durationMs) }} + 开始时间: {{ formatTraceTime(sessionTrace.startEpochMs) }} +
+ 刷新 +
+ + + + +
+
+
+ {{ row.span.name }} + {{ row.span.kind }} + + {{ row.span.status }} + + {{ formatTraceDuration(row.span.durationMs) }} +
+
+ spanId: {{ row.span.spanId }} + parent: {{ row.span.parentSpanId || '-' }} + 开始: {{ formatTraceTime(row.span.startEpochMs) }} +
+
+ 属性 +
+
+ {{ entry.key }} + {{ entry.value }} +
+
+
+
+
+
@@ -382,7 +460,12 @@ hljs.registerLanguage('json', json); import BaseLayout from '@/layouts/BaseLayout.vue'; import AgentService from '@/services/agent'; - import ChatService, { type ChatSession, type ChatMessage } from '@/services/chat'; + import ChatService, { + type ChatSession, + type ChatMessage, + type SessionTrace, + type TraceSpan, + } from '@/services/chat'; import GraphService, { type GraphRequest, type GraphNodeResponse, @@ -519,6 +602,10 @@ const showReportFullscreen = ref(false); const fullscreenReportContent = ref(''); const inputControlsCollapsed = ref(false); + const traceDialogVisible = ref(false); + const traceLoading = ref(false); + const traceError = ref(''); + const sessionTrace = ref(null); // 监听NL2SQL开关变化 const handleNl2sqlOnlyChange = (value: boolean) => { @@ -562,6 +649,9 @@ saveViewToState(currentSession.value.id, { isStreaming, nodeBlocks }); } currentSession.value = session; + sessionTrace.value = null; + traceError.value = ''; + traceDialogVisible.value = false; try { if (session === null) { @@ -1153,6 +1243,79 @@ sessionState.markdownReportContent = ''; }; + const flattenTraceSpans = ( + spans: TraceSpan[], + depth = 0, + ): Array<{ span: TraceSpan; depth: number; attributeEntries: Array<{ key: string; value: string }> }> => { + return spans.flatMap(span => { + const attributeEntries = Object.entries(span.attributes ?? {}).map(([key, value]) => ({ + key, + value, + })); + return [ + { + span, + depth, + attributeEntries, + }, + ...flattenTraceSpans(span.children ?? [], depth + 1), + ]; + }); + }; + + const flattenedTraceSpans = computed(() => + sessionTrace.value ? flattenTraceSpans(sessionTrace.value.rootSpans ?? []) : [], + ); + + const formatTraceDuration = (durationMs: number) => { + if (durationMs < 1000) { + return `${durationMs} ms`; + } + if (durationMs < 10000) { + return `${(durationMs / 1000).toFixed(2)} s`; + } + return `${(durationMs / 1000).toFixed(1)} s`; + }; + + const formatTraceTime = (epochMs: number) => { + if (!epochMs) { + return '--'; + } + return new Date(epochMs).toLocaleString(); + }; + + const loadLatestTrace = async () => { + if (!currentSession.value) { + sessionTrace.value = null; + traceError.value = '当前没有可查看 trace 的会话'; + return; + } + traceLoading.value = true; + traceError.value = ''; + try { + sessionTrace.value = await ChatService.getSessionTrace(currentSession.value.id); + } catch (error: any) { + sessionTrace.value = null; + if (error?.response?.status === 404) { + traceError.value = '当前会话还没有最近一次 trace,请先执行一轮对话。'; + } else { + traceError.value = '加载 trace 失败,请稍后重试。'; + } + console.error('加载 trace 失败:', error); + } finally { + traceLoading.value = false; + } + }; + + const openTraceDialog = async () => { + traceDialogVisible.value = true; + await loadLatestTrace(); + }; + + const refreshTrace = async () => { + await loadLatestTrace(); + }; + const scrollToBottom = () => { nextTick(() => { if (chatContainer.value) { @@ -1398,6 +1561,10 @@ showReportFullscreen, fullscreenReportContent, inputControlsCollapsed, + traceDialogVisible, + traceLoading, + traceError, + sessionTrace, autoScroll, chatContainer, nodeBlocks, @@ -1406,6 +1573,9 @@ lastRequest, resultSetDisplayConfig, options, + flattenedTraceSpans, + formatTraceDuration, + formatTraceTime, getMarkdownContentFromNode, selectSession, sendMessage, @@ -1422,6 +1592,8 @@ handleHumanFeedback, handlePresetQuestionClick, stopStreaming, + openTraceDialog, + refreshTrace, deleteSessionState, }; }, @@ -1828,6 +2000,109 @@ align-items: flex-end; } + .trace-button { + height: 40px; + padding: 0 14px; + border-radius: 999px; + flex-shrink: 0; + } + + .trace-toolbar { + display: flex; + justify-content: space-between; + align-items: flex-start; + gap: 16px; + margin-bottom: 16px; + } + + .trace-summary { + display: flex; + flex-wrap: wrap; + gap: 12px; + color: #606266; + font-size: 13px; + } + + .trace-alert { + margin-bottom: 16px; + } + + .trace-list { + display: flex; + flex-direction: column; + gap: 12px; + max-height: 65vh; + overflow: auto; + padding-right: 6px; + } + + .trace-row { + border: 1px solid #ebeef5; + border-radius: 12px; + padding: 12px; + background: #fff; + } + + .trace-row-main { + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; + } + + .trace-row-name { + font-weight: 600; + color: #303133; + } + + .trace-row-duration { + color: #409eff; + font-size: 12px; + } + + .trace-row-meta { + display: flex; + flex-wrap: wrap; + gap: 12px; + margin-top: 8px; + color: #909399; + font-size: 12px; + word-break: break-all; + } + + .trace-attributes { + margin-top: 10px; + } + + .trace-attributes summary { + cursor: pointer; + color: #606266; + font-size: 13px; + } + + .trace-attributes-grid { + display: grid; + grid-template-columns: minmax(180px, 220px) minmax(0, 1fr); + gap: 8px 12px; + margin-top: 10px; + } + + .trace-attribute-item { + display: contents; + } + + .trace-attribute-key { + color: #606266; + font-size: 12px; + word-break: break-all; + } + + .trace-attribute-value { + color: #303133; + font-size: 12px; + word-break: break-all; + } + @keyframes spin { from { transform: rotate(0deg); @@ -1850,6 +2125,15 @@ .input-container { flex-direction: column; } + + .trace-toolbar { + flex-direction: column; + align-items: stretch; + } + + .trace-attributes-grid { + grid-template-columns: 1fr; + } } diff --git a/data-agent-management/pom.xml b/data-agent-management/pom.xml index a9b2223b1..b6aeb27be 100644 --- a/data-agent-management/pom.xml +++ b/data-agent-management/pom.xml @@ -19,15 +19,6 @@ - - org.springframework.boot - spring-boot-starter - - - org.mybatis.spring.boot - mybatis-spring-boot-starter-test - test - org.mybatis.spring.boot mybatis-spring-boot-starter @@ -46,10 +37,7 @@ org.springframework.boot spring-boot-starter-test - - - com.aliyun - gpdb20160503 + test org.springframework.boot @@ -62,19 +50,6 @@ spring-boot-starter-validation - - - org.springframework.ai - spring-ai-transformers - true - - - - io.micrometer - micrometer-observation-test - test - - org.junit.jupiter junit-jupiter-api @@ -179,21 +154,6 @@ ${springdoc-openapi.version} - - org.springframework.boot - spring-boot-testcontainers - test - - - org.testcontainers - testcontainers - test - - - org.testcontainers - junit-jupiter - test - io.opentelemetry @@ -208,30 +168,11 @@ opentelemetry-exporter-otlp - io.opentelemetry - opentelemetry-sdk-extension-autoconfigure - - - - - org.testcontainers - mysql - test - - - - - io.projectreactor - reactor-test - test + io.opentelemetry.instrumentation + opentelemetry-reactor-3.1 + ${opentelemetry-instrumentation.version} - - - org.awaitility - awaitility - test - io.netty @@ -275,18 +216,6 @@ spring-ai-alibaba-dashscope ${spring-ai-alibaba.version} - - com.github.victools - jsonschema-generator - compile - - - - com.github.victools - jsonschema-module-jackson - compile - - org.springframework.boot spring-boot-starter-webflux @@ -298,11 +227,6 @@ lombok - - com.atlassian.commonmark - commonmark - - org.springframework.boot spring-boot-starter-jdbc @@ -330,10 +254,6 @@ ${elasticsearch-client.version} - - jakarta.validation - jakarta.validation-api - org.springframework.ai spring-ai-tika-document-reader diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 3dbb24896..530653df3 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -32,10 +32,15 @@ import com.alibaba.cloud.ai.dataagent.enums.TextType; import com.alibaba.cloud.ai.dataagent.dto.ModelConfigDTO; import com.alibaba.cloud.ai.dataagent.entity.Agent; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import com.alibaba.cloud.ai.dataagent.service.aimodelconfig.DynamicModelFactory; import com.alibaba.cloud.ai.dataagent.service.aimodelconfig.ModelConfigDataService; import io.agentscope.core.message.Msg; import io.agentscope.core.model.Model; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; import java.util.LinkedHashMap; import java.util.Map; import java.util.UUID; @@ -43,6 +48,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.tool.ToolCallback; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.http.codec.ServerSentEvent; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; @@ -63,6 +69,8 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private static final String STREAM_EVENT_MESSAGE = "message"; + private static final String ROOT_SPAN_NAME = "data-agent.agent.run"; + private static final String AGENT_STATUS_PUBLISHED = "published"; private static final String AGENT_STATUS_OFFLINE = "offline"; @@ -83,6 +91,9 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private final com.alibaba.cloud.ai.dataagent.service.agent.AgentService agentService; + @Qualifier("agentScopeTracer") + private final Tracer tracer; + @Override public String nl2sql(String naturalQuery, String agentId) { log.info("NL2SQL runtime invoked for agentId={}", agentId); @@ -187,45 +198,75 @@ private void initializeRuntimeRequest(GraphRequest request) { private String executeAgent(GraphRequest request, AgentRuntimeEventPublisher eventPublisher) { sessionRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); + Span rootSpan = startRuntimeSpan(request); try { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { - return ""; - } - Agent managedAgentConfig = resolveManagedAgent(request.getAgentId()); - ModelConfigDTO modelConfig = modelConfigDataService.getActiveConfigByType(ModelType.CHAT); - validateModelConfig(modelConfig); - Map toolCallbacks = agentScopeToolkitFactory.getToolCallbacks(request.getAgentId()); - Model model = agentScopeModelFactory.create(dynamicModelFactory.createChatModel(modelConfig), - modelConfig.getModelName(), toolCallbacks); - ManagedAgent managedAgent = managedAgentRegistry.getRequired(); - AgentRuntimeExtensions runtimeExtensions = agentRuntimeExtensionFactory.create(request, eventPublisher, - toolCallbacks); - Msg response; - try { - response = managedAgent.run(new AgentRunContext(request.getAgentId(), request.getThreadId(), model, - resolveManagedSystemPrompt(managedAgentConfig, request.getAgentId()), buildUserPrompt(request), - AgentRuntimeConstant.AGENT_CALL_TIMEOUT, runtimeExtensions)); - } - catch (RuntimeException ex) { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) - && isInterruptedCancellation(ex)) { - Thread.interrupted(); - log.info("Agent execution interrupted by cancellation, threadId={}, runtimeRequestId={}", - request.getThreadId(), request.getRuntimeRequestId()); + try (Scope ignored = rootSpan.makeCurrent()) { + if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + rootSpan.setStatus(StatusCode.OK, "cancelled"); return ""; } - throw ex; - } - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { - return ""; + Agent managedAgentConfig = resolveManagedAgent(request.getAgentId()); + ModelConfigDTO modelConfig = modelConfigDataService.getActiveConfigByType(ModelType.CHAT); + validateModelConfig(modelConfig); + Map toolCallbacks = agentScopeToolkitFactory.getToolCallbacks(request.getAgentId()); + Model model = agentScopeModelFactory.create(dynamicModelFactory.createChatModel(modelConfig), + modelConfig.getModelName(), toolCallbacks); + ManagedAgent managedAgent = managedAgentRegistry.getRequired(); + AgentRuntimeExtensions runtimeExtensions = agentRuntimeExtensionFactory.create(request, eventPublisher, + toolCallbacks); + Msg response; + try { + response = managedAgent.run(new AgentRunContext(request.getAgentId(), request.getThreadId(), model, + resolveManagedSystemPrompt(managedAgentConfig, request.getAgentId()), buildUserPrompt(request), + AgentRuntimeConstant.AGENT_CALL_TIMEOUT, runtimeExtensions)); + } + catch (RuntimeException ex) { + if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) + && isInterruptedCancellation(ex)) { + Thread.interrupted(); + rootSpan.setStatus(StatusCode.OK, "cancelled"); + log.info("Agent execution interrupted by cancellation, threadId={}, runtimeRequestId={}", + request.getThreadId(), request.getRuntimeRequestId()); + return ""; + } + throw ex; + } + if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + rootSpan.setStatus(StatusCode.OK, "cancelled"); + return ""; + } + rootSpan.setStatus(StatusCode.OK); + return extractText(response); } - return extractText(response); + } + catch (RuntimeException ex) { + recordRuntimeFailure(rootSpan, ex); + throw ex; } finally { + rootSpan.end(); sessionRegistry.clearRunning(request.getThreadId(), request.getRuntimeRequestId()); } } + private Span startRuntimeSpan(GraphRequest request) { + Span span = tracer.spanBuilder(ROOT_SPAN_NAME).startSpan(); + span.setAttribute(SessionTraceStore.ATTR_THREAD_ID, request.getThreadId()); + span.setAttribute(SessionTraceStore.ATTR_RUNTIME_REQUEST_ID, request.getRuntimeRequestId()); + span.setAttribute(SessionTraceStore.ATTR_AGENT_ID, request.getAgentId() == null ? "" : request.getAgentId()); + span.setAttribute("dataagent.runtime.human_feedback", request.isHumanFeedback()); + span.setAttribute("dataagent.runtime.nl2sql_only", request.isNl2sqlOnly()); + return span; + } + + private void recordRuntimeFailure(Span rootSpan, Throwable throwable) { + if (rootSpan == null) { + return; + } + rootSpan.setStatus(StatusCode.ERROR, throwable.getMessage() == null ? "runtime failed" : throwable.getMessage()); + rootSpan.recordException(throwable); + } + private void validateModelConfig(ModelConfigDTO modelConfig) { if (modelConfig == null) { throw new IllegalStateException("No active CHAT model configured. Please configure it in the dashboard."); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index 07f8013a0..099b9d1b9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -238,7 +238,7 @@ private ExplorerContext resolveContext(String agentId) throws Exception { .map(this::normalizeTableName) .collect(Collectors.toCollection(LinkedHashSet::new)); List logicalRelations = datasourceService.getLogicalRelations(datasource.getId()); - List physicalRelations = loadPhysicalRelations(accessor, dbConfig); + List physicalRelations = loadPhysicalRelations(accessor, dbConfig, visibleTables); List unifiedRelations = buildUnifiedRelations(visibleTableNameSet, physicalRelations, logicalRelations == null ? List.of() : logicalRelations); return new ExplorerContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), @@ -246,10 +246,10 @@ private ExplorerContext resolveContext(String agentId) throws Exception { List.copyOf(unifiedRelations), indexRelationsByTable(unifiedRelations)); } - private List loadPhysicalRelations(Accessor accessor, DbConfigBO dbConfig) { + private List loadPhysicalRelations(Accessor accessor, DbConfigBO dbConfig, List tables) { try { List foreignKeys = accessor.showForeignKeys(dbConfig, - DbQueryParameter.from(dbConfig).setSchema(dbConfig.getSchema())); + DbQueryParameter.from(dbConfig).setSchema(dbConfig.getSchema()).setTables(tables)); return foreignKeys == null ? List.of() : foreignKeys; } catch (Exception ex) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java new file mode 100644 index 000000000..da2d5a3de --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.config; + +import com.alibaba.cloud.ai.dataagent.properties.AgentScopeObservabilityProperties; +import io.agentscope.core.tracing.TracerRegistry; +import io.agentscope.core.tracing.telemetry.TelemetryTracer; +import io.opentelemetry.api.trace.Tracer; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; + +@Slf4j +@Configuration +@RequiredArgsConstructor +@EnableConfigurationProperties(AgentScopeObservabilityProperties.class) +public class AgentScopeTracingConfiguration implements SmartInitializingSingleton { + + private final AgentScopeObservabilityProperties properties; + + private final OpenTelemetryConfig openTelemetryConfig; + + @Qualifier("langfuseTracer") + private final Tracer langfuseTracer; + + @Qualifier("agentScopeLocalTracer") + private final Tracer agentScopeLocalTracer; + + @Bean("agentScopeTracer") + @Primary + public Tracer agentScopeTracer() { + return selectTracer(); + } + + @Override + public void afterSingletonsInstantiated() { + if (!properties.isEnabled()) { + log.info("AgentScope native tracing is disabled by configuration."); + return; + } + + TracerRegistry.register(TelemetryTracer.builder().tracer(selectTracer()).build()); + if (properties.isUseLangfuseTracer() && openTelemetryConfig.isEnabled()) { + log.info( + "AgentScope native tracing initialized with local trace cache and Langfuse OpenTelemetry exporter."); + return; + } + + log.info("AgentScope native tracing initialized with local trace cache only."); + } + + private Tracer selectTracer() { + if (properties.isUseLangfuseTracer() && openTelemetryConfig.isEnabled()) { + return langfuseTracer; + } + return agentScopeLocalTracer; + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java index 828bb7ad8..44ad452fb 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.dataagent.config; import com.alibaba.cloud.ai.dataagent.constant.Constant; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; @@ -24,10 +25,13 @@ import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.resources.Resource; import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.trace.SdkTracerProviderBuilder; import io.opentelemetry.sdk.trace.export.BatchSpanProcessor; +import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; import jakarta.annotation.PreDestroy; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -57,50 +61,77 @@ public class OpenTelemetryConfig { private String secretKey; - private SdkTracerProvider tracerProvider; + private SdkTracerProvider langfuseTracerProvider; - @Bean - public OpenTelemetry openTelemetry() { - if (!enabled) { - return OpenTelemetry.noop(); + private SdkTracerProvider localTraceTracerProvider; + + @Bean("langfuseOpenTelemetry") + public OpenTelemetry langfuseOpenTelemetry(SessionTraceStore sessionTraceStore) { + langfuseTracerProvider = buildTracerProvider(sessionTraceStore, enabled); + OpenTelemetrySdk openTelemetrySdk = OpenTelemetrySdk.builder() + .setTracerProvider(langfuseTracerProvider) + .build(); + + if (enabled) { + log.info("OpenTelemetry initialized with local trace cache and Langfuse OTLP HTTP exporter."); + } + else { + log.info("OpenTelemetry initialized with local trace cache only."); } - String auth = publicKey + ":" + secretKey; - String encodedAuth = Base64.getEncoder().encodeToString(auth.getBytes(StandardCharsets.UTF_8)); + return openTelemetrySdk; + } - OtlpHttpSpanExporter spanExporter = OtlpHttpSpanExporter.builder() - .setEndpoint(host + "/api/public/otel/v1/traces") - .addHeader("Authorization", "Basic " + encodedAuth) - .setTimeout(10, TimeUnit.SECONDS) - .build(); + @Bean("agentScopeLocalOpenTelemetry") + public OpenTelemetry agentScopeLocalOpenTelemetry(SessionTraceStore sessionTraceStore) { + localTraceTracerProvider = buildTracerProvider(sessionTraceStore, false); + return OpenTelemetrySdk.builder().setTracerProvider(localTraceTracerProvider).build(); + } + private SdkTracerProvider buildTracerProvider(SessionTraceStore sessionTraceStore, boolean withLangfuseExporter) { Resource resource = Resource.getDefault() .merge(Resource.create(Attributes.of(AttributeKey.stringKey("service.name"), SERVICE_NAME))); - tracerProvider = SdkTracerProvider.builder() - .addSpanProcessor(BatchSpanProcessor.builder(spanExporter) - .setScheduleDelay(1, TimeUnit.SECONDS) - .setMaxExportBatchSize(100) - .build()) - .setResource(resource) - .build(); + SdkTracerProviderBuilder builder = SdkTracerProvider.builder() + .addSpanProcessor(SimpleSpanProcessor.create(sessionTraceStore)) + .setResource(resource); - OpenTelemetrySdk openTelemetrySdk = OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + if (withLangfuseExporter) { + String auth = publicKey + ":" + secretKey; + String encodedAuth = Base64.getEncoder().encodeToString(auth.getBytes(StandardCharsets.UTF_8)); - log.info("OpenTelemetry initialized with Langfuse OTLP HTTP exporter"); + OtlpHttpSpanExporter spanExporter = OtlpHttpSpanExporter.builder() + .setEndpoint(host + "/api/public/otel/v1/traces") + .addHeader("Authorization", "Basic " + encodedAuth) + .setTimeout(10, TimeUnit.SECONDS) + .build(); - return openTelemetrySdk; + builder.addSpanProcessor(BatchSpanProcessor.builder(spanExporter) + .setScheduleDelay(1, TimeUnit.SECONDS) + .setMaxExportBatchSize(100) + .build()); + } + + return builder.build(); } - @Bean - public Tracer langfuseTracer(OpenTelemetry openTelemetry) { + @Bean("langfuseTracer") + public Tracer langfuseTracer(@Qualifier("langfuseOpenTelemetry") OpenTelemetry openTelemetry) { return openTelemetry.getTracer(SERVICE_NAME); } + @Bean("agentScopeLocalTracer") + public Tracer agentScopeLocalTracer(@Qualifier("agentScopeLocalOpenTelemetry") OpenTelemetry openTelemetry) { + return openTelemetry.getTracer(SERVICE_NAME + ".agentscope.local"); + } + @PreDestroy public void shutdown() { - if (tracerProvider != null) { - tracerProvider.close(); + if (langfuseTracerProvider != null) { + langfuseTracerProvider.close(); + } + if (localTraceTracerProvider != null) { + localTraceTracerProvider.close(); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java index 2ef0ebe66..9f0f16ab6 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java @@ -29,6 +29,8 @@ import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import java.sql.Connection; import java.util.List; @@ -63,7 +65,8 @@ public T accessDb(DbConfigBO dbConfig, String method, DbQueryParameter param case "showColumns": return (T) ddlExecutor.showColumns(connection, param.getSchema(), param.getTable()); case "showForeignKeys": - return (T) ddlExecutor.showForeignKeys(connection, param.getSchema(), param.getTables()); + return (T) ddlExecutor.showForeignKeys(connection, param.getSchema(), + resolveTablesForForeignKeys(connection, ddlExecutor, param)); case "sampleColumn": return (T) ddlExecutor.sampleColumn(connection, param.getSchema(), param.getTable(), param.getColumn()); @@ -122,4 +125,19 @@ public Connection getConnection(DbConfigBO config) { return this.dbConnectionPool.getConnection(config); } + private List resolveTablesForForeignKeys(Connection connection, AbstractJdbcDdl ddlExecutor, + DbQueryParameter param) { + if (param == null) { + return List.of(); + } + if (!CollectionUtils.isEmpty(param.getTables())) { + return param.getTables(); + } + List tables = ddlExecutor.showTables(connection, param.getSchema(), param.getTablePattern()); + if (CollectionUtils.isEmpty(tables)) { + return List.of(); + } + return tables.stream().map(TableInfoBO::getName).filter(StringUtils::hasText).distinct().toList(); + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java index 82b931800..a8d70de17 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java @@ -19,6 +19,7 @@ import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; import com.alibaba.cloud.ai.dataagent.entity.ChatSession; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; import com.alibaba.cloud.ai.dataagent.service.chat.SessionTitleService; @@ -56,6 +57,8 @@ public class ChatController { private final ReportTemplateUtil reportTemplateUtil; + private final SessionTraceStore sessionTraceStore; + /** * Get session list for an agent */ @@ -118,6 +121,13 @@ public ResponseEntity> getSessionMessages(@PathVariable(value return ResponseEntity.ok(messages); } + @GetMapping("/sessions/{sessionId}/trace") + public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId") String sessionId) { + return sessionTraceStore.getLatestTrace(sessionId) + .>map(ResponseEntity::ok) + .orElseGet(() -> ResponseEntity.notFound().build()); + } + /** * Save message to session */ diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java new file mode 100644 index 000000000..96dfa13ae --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java @@ -0,0 +1,305 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.observability; + +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.sdk.common.CompletableResultCode; +import io.opentelemetry.sdk.trace.data.SpanData; +import io.opentelemetry.sdk.trace.export.SpanExporter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import lombok.Getter; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +/** + * Cache the latest completed trace for each chat session so the frontend can inspect + * the most recent AgentScope span chain without querying an external tracing backend. + */ +@Component +public class SessionTraceStore implements SpanExporter { + + public static final String ATTR_THREAD_ID = "dataagent.thread.id"; + + public static final String ATTR_RUNTIME_REQUEST_ID = "dataagent.runtime.request.id"; + + public static final String ATTR_AGENT_ID = "dataagent.agent.id"; + + private static final String ROOT_PARENT_SPAN_ID = "0000000000000000"; + + private static final int MAX_TRACE_ASSEMBLIES = 256; + + private static final int MAX_SESSION_TRACES = 128; + + private static final int MAX_ATTRIBUTE_VALUE_LENGTH = 256; + + private static final String META_OMITTED_ATTRIBUTE_COUNT = "_meta.omitted_attribute_count"; + + private static final Set SAFE_ATTRIBUTE_KEYS = Set.of(ATTR_THREAD_ID, ATTR_RUNTIME_REQUEST_ID, ATTR_AGENT_ID, + "dataagent.runtime.human_feedback", "dataagent.runtime.nl2sql_only", "data_agent.agent_id", + "data_agent.thread_id", "data_agent.nl2sql_only", "data_agent.human_feedback", + "gen_ai.usage.prompt_tokens", "gen_ai.usage.completion_tokens", "gen_ai.usage.total_tokens", "error.type"); + + private final Object monitor = new Object(); + + private final LinkedHashMap traceAssemblies = new LinkedHashMap<>(32, 0.75f, true); + + private final LinkedHashMap latestTraceBySessionId = new LinkedHashMap<>(32, 0.75f, true); + + @Override + public CompletableResultCode export(Collection spans) { + synchronized (monitor) { + for (SpanData span : spans) { + captureSpan(span); + } + } + return CompletableResultCode.ofSuccess(); + } + + @Override + public CompletableResultCode flush() { + return CompletableResultCode.ofSuccess(); + } + + @Override + public CompletableResultCode shutdown() { + synchronized (monitor) { + traceAssemblies.clear(); + latestTraceBySessionId.clear(); + } + return CompletableResultCode.ofSuccess(); + } + + public Optional getLatestTrace(String sessionId) { + synchronized (monitor) { + return Optional.ofNullable(latestTraceBySessionId.get(sessionId)); + } + } + + private void captureSpan(SpanData span) { + String traceId = span.getSpanContext().getTraceId(); + TraceAssembly assembly = traceAssemblies.computeIfAbsent(traceId, ignored -> new TraceAssembly()); + assembly.traceId = traceId; + assembly.spansBySpanId.put(span.getSpanContext().getSpanId(), toSpanView(span)); + + String sessionId = readAttribute(span, ATTR_THREAD_ID); + if (StringUtils.hasText(sessionId)) { + assembly.sessionId = sessionId; + assembly.runtimeRequestId = readAttribute(span, ATTR_RUNTIME_REQUEST_ID); + assembly.agentId = readAttribute(span, ATTR_AGENT_ID); + } + + if (StringUtils.hasText(assembly.sessionId)) { + latestTraceBySessionId.put(assembly.sessionId, assembly.toTraceView()); + evictOldSessionTraces(); + } + + evictOldAssemblies(); + } + + private String readAttribute(SpanData span, String key) { + String value = span.getAttributes().get(AttributeKey.stringKey(key)); + return StringUtils.hasText(value) ? value : null; + } + + private void evictOldAssemblies() { + while (traceAssemblies.size() > MAX_TRACE_ASSEMBLIES) { + String eldestTraceId = traceAssemblies.keySet().iterator().next(); + traceAssemblies.remove(eldestTraceId); + removeSessionTraceByTraceId(eldestTraceId); + } + } + + private void evictOldSessionTraces() { + while (latestTraceBySessionId.size() > MAX_SESSION_TRACES) { + String eldestSessionId = latestTraceBySessionId.keySet().iterator().next(); + TraceView removed = latestTraceBySessionId.remove(eldestSessionId); + if (removed != null) { + traceAssemblies.remove(removed.traceId()); + } + } + } + + private void removeSessionTraceByTraceId(String traceId) { + String matchedSessionId = null; + for (Map.Entry entry : latestTraceBySessionId.entrySet()) { + if (Objects.equals(entry.getValue().traceId(), traceId)) { + matchedSessionId = entry.getKey(); + break; + } + } + if (matchedSessionId != null) { + latestTraceBySessionId.remove(matchedSessionId); + } + } + + private SpanView toSpanView(SpanData span) { + Map attributes = sanitizeAttributes(span); + long startEpochNanos = span.getStartEpochNanos(); + long endEpochNanos = span.getEndEpochNanos(); + long durationMs = Math.max(0L, (endEpochNanos - startEpochNanos) / 1_000_000L); + long startEpochMs = startEpochNanos / 1_000_000L; + long endEpochMs = endEpochNanos / 1_000_000L; + return new SpanView(span.getName(), span.getSpanContext().getSpanId(), span.getParentSpanContext().getSpanId(), + span.getKind().name(), span.getStatus().getStatusCode().name(), startEpochMs, endEpochMs, durationMs, + attributes, List.of()); + } + + private Map sanitizeAttributes(SpanData span) { + Map attributes = new LinkedHashMap<>(); + int omittedAttributeCount = 0; + span.getAttributes().forEach((key, value) -> { + String attributeKey = key.getKey(); + if (!SAFE_ATTRIBUTE_KEYS.contains(attributeKey)) { + return; + } + attributes.put(attributeKey, sanitizeAttributeValue(String.valueOf(value))); + }); + omittedAttributeCount = Math.max(0, span.getTotalAttributeCount() - attributes.size()); + if (omittedAttributeCount > 0) { + attributes.put(META_OMITTED_ATTRIBUTE_COUNT, String.valueOf(omittedAttributeCount)); + } + return attributes; + } + + private String sanitizeAttributeValue(String value) { + if (!StringUtils.hasText(value)) { + return value; + } + String normalizedValue = value.replace('\r', ' ').replace('\n', ' ').trim(); + if (normalizedValue.length() <= MAX_ATTRIBUTE_VALUE_LENGTH) { + return normalizedValue; + } + return normalizedValue.substring(0, MAX_ATTRIBUTE_VALUE_LENGTH) + "..."; + } + + private static final class TraceAssembly { + + private String sessionId; + + private String traceId; + + private String runtimeRequestId; + + private String agentId; + + private final Map spansBySpanId = new LinkedHashMap<>(); + + private TraceView toTraceView() { + Map nodesById = new LinkedHashMap<>(); + for (SpanView span : spansBySpanId.values()) { + nodesById.put(span.getSpanId(), new MutableTreeNode(span)); + } + + List roots = new ArrayList<>(); + for (MutableTreeNode node : nodesById.values()) { + String parentSpanId = node.span.getParentSpanId(); + MutableTreeNode parent = nodesById.get(parentSpanId); + if (!StringUtils.hasText(parentSpanId) || ROOT_PARENT_SPAN_ID.equals(parentSpanId) || parent == null) { + roots.add(node); + } + else { + parent.children.add(node); + } + } + + roots.sort(Comparator.comparingLong(node -> node.span.getStartEpochMs())); + List rootViews = roots.stream().map(MutableTreeNode::toImmutable).toList(); + SpanView rootSpan = rootViews.isEmpty() ? null : rootViews.get(0); + long startedAt = spansBySpanId.values() + .stream() + .mapToLong(SpanView::getStartEpochMs) + .min() + .orElse(0L); + long endedAt = spansBySpanId.values().stream().mapToLong(SpanView::getEndEpochMs).max().orElse(0L); + long durationMs = Math.max(0L, endedAt - startedAt); + return new TraceView(sessionId, traceId, runtimeRequestId, agentId, startedAt, endedAt, durationMs, + spansBySpanId.size(), rootSpan, rootViews); + } + + } + + private static final class MutableTreeNode { + + private final SpanView span; + + private final List children = new ArrayList<>(); + + private MutableTreeNode(SpanView span) { + this.span = span; + } + + private SpanView toImmutable() { + children.sort(Comparator.comparingLong(node -> node.span.getStartEpochMs())); + return new SpanView(span.getName(), span.getSpanId(), span.getParentSpanId(), span.getKind(), span.getStatus(), + span.getStartEpochMs(), span.getEndEpochMs(), span.getDurationMs(), span.getAttributes(), + children.stream().map(MutableTreeNode::toImmutable).toList()); + } + + } + + public record TraceView(String sessionId, String traceId, String runtimeRequestId, String agentId, + long startEpochMs, long endEpochMs, long durationMs, int spanCount, SpanView rootSpan, + List rootSpans) { + } + + @Getter + public static final class SpanView { + + private final String name; + + private final String spanId; + + private final String parentSpanId; + + private final String kind; + + private final String status; + + private final long startEpochMs; + + private final long endEpochMs; + + private final long durationMs; + + private final Map attributes; + + private final List children; + + private SpanView(String name, String spanId, String parentSpanId, String kind, String status, long startEpochMs, + long endEpochMs, long durationMs, Map attributes, List children) { + this.name = name; + this.spanId = spanId; + this.parentSpanId = parentSpanId; + this.kind = kind; + this.status = status; + this.startEpochMs = startEpochMs; + this.endEpochMs = endEpochMs; + this.durationMs = durationMs; + this.attributes = attributes; + this.children = children; + } + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/properties/AgentScopeObservabilityProperties.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/properties/AgentScopeObservabilityProperties.java new file mode 100644 index 000000000..9ac50e391 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/properties/AgentScopeObservabilityProperties.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.properties; + +import com.alibaba.cloud.ai.dataagent.constant.Constant; +import lombok.Getter; +import lombok.Setter; +import org.springframework.boot.context.properties.ConfigurationProperties; + +@Getter +@Setter +@ConfigurationProperties(prefix = AgentScopeObservabilityProperties.CONFIG_PREFIX) +public class AgentScopeObservabilityProperties { + + public static final String CONFIG_PREFIX = Constant.PROJECT_PROPERTIES_PREFIX + ".agentscope.observability"; + + /** + * 是否启用 AgentScope 原生 tracing。 + */ + private boolean enabled = true; + + /** + * 是否优先复用现有 Langfuse OpenTelemetry tracer 作为导出通道。 + */ + private boolean useLangfuseTracer = true; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java index 29471c5ba..00f91cd34 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java @@ -23,6 +23,7 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; @@ -68,7 +69,8 @@ public class LangfuseService { // --- Token 累计器,按 threadId 隔离 --- private static final ConcurrentHashMap TOKEN_ACCUMULATOR = new ConcurrentHashMap<>(); - public LangfuseService(Tracer langfuseTracer, @Value("${langfuse.enabled:true}") boolean enabled) { + public LangfuseService(@Qualifier("langfuseTracer") Tracer langfuseTracer, + @Value("${langfuse.enabled:true}") boolean enabled) { this.tracer = langfuseTracer; this.enabled = enabled; } diff --git a/data-agent-management/src/main/resources/application.yml b/data-agent-management/src/main/resources/application.yml index ef4520de4..2466aae4e 100644 --- a/data-agent-management/src/main/resources/application.yml +++ b/data-agent-management/src/main/resources/application.yml @@ -61,6 +61,10 @@ spring: host: ${LANGFUSE_HOST:} public-key: ${LANGFUSE_PUBLIC_KEY:} secret-key: ${LANGFUSE_SECRET_KEY:} + agentscope: + observability: + enabled: ${AGENTSCOPE_OBSERVABILITY_ENABLED:true} + use-langfuse-tracer: ${AGENTSCOPE_OBSERVABILITY_USE_LANGFUSE_TRACER:true} webflux: multipart: max-file-size: 10MB diff --git a/pom.xml b/pom.xml index 5e5421262..86fae1148 100644 --- a/pom.xml +++ b/pom.xml @@ -52,22 +52,20 @@ 3.1.0 3.5.3 - 3.0.0 1.2.22 42.4.1 3.2.1 2.18.0 - 0.17.0 5.4.1 1.27.1 4.1.114.Final 4.1.114.Final - 4.37.0 2.3.232 8.1.3.140 1.21.4 8.18.0 1.32.0 + 2.24.0-alpha 4.2.2 0.8.12 @@ -100,11 +98,6 @@ pom import - - com.aliyun - gpdb20160503 - ${gpdb.version} - com.alibaba druid @@ -148,11 +141,6 @@ commons-io ${commons-io.version} - - com.atlassian.commonmark - commonmark - ${commonmark.version} - org.apache.httpcomponents.client5 httpclient5 @@ -163,16 +151,6 @@ commons-compress ${commons-compress.version} - - com.github.victools - jsonschema-generator - ${jsonschema.version} - - - com.github.victools - jsonschema-module-jackson - ${jsonschema.version} - io.netty netty-resolver-dns-native-macos From 74859421fd79060742a1b3a7f64795b4dfae0f67 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Thu, 23 Apr 2026 22:30:22 +0800 Subject: [PATCH 05/22] feat: add sql visibility --- .../src/components/agent/DataSourceConfig.vue | 440 ++++++++++- .../src/services/agentDatasource.ts | 121 +-- .../src/services/datasource.ts | 18 + .../datasource/DatasourceExplorerService.java | 739 ++++++++++++++++-- .../DatasourceExplorerToolProvider.java | 3 +- .../controller/AgentDatasourceController.java | 147 ++-- .../controller/GlobalExceptionHandler.java | 35 +- .../dto/datasource/SchemaInitRequest.java | 19 +- .../datasource/TableColumnsSelectionDTO.java | 30 + .../UpdateDatasourceColumnsDTO.java | 32 + .../ai/dataagent/entity/AgentDatasource.java | 4 + .../entity/AgentDatasourceColumn.java | 46 ++ .../mapper/AgentDatasourceColumnsMapper.java | 54 ++ .../datasource/AgentDatasourceService.java | 8 +- .../impl/AgentDatasourceServiceImpl.java | 279 ++++++- .../service/schema/SchemaServiceImpl.java | 121 +++ .../cloud/ai/dataagent/util/SqlUtil.java | 53 +- .../src/main/resources/sql/h2/schema-h2.sql | 21 +- .../src/main/resources/sql/schema.sql | 22 +- .../src/test/resources/sql/schema.sql | 2 +- 20 files changed, 1948 insertions(+), 246 deletions(-) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java diff --git a/data-agent-frontend/src/components/agent/DataSourceConfig.vue b/data-agent-frontend/src/components/agent/DataSourceConfig.vue index da448cde8..83c65a980 100644 --- a/data-agent-frontend/src/components/agent/DataSourceConfig.vue +++ b/data-agent-frontend/src/components/agent/DataSourceConfig.vue @@ -1,4 +1,4 @@ - > = ref({}); const updateLoadingStates: Ref> = ref({}); const agentDatasourceList: Ref = ref([]); + const selectedColumns: Ref>> = ref({}); + const columnOptionsByDatasource: Ref>> = ref({}); + const columnRestrictionEnabled: Ref>> = ref({}); + const columnLoadingStates: Ref> = ref({}); + const columnDialogVisible: Ref = ref(false); + const currentColumnDatasource: Ref = ref(null); + const currentColumnTables: Ref = ref([]); + const savingColumnVisibility: Ref = ref(false); // 逻辑外键管理相关状态 const foreignKeyDialogVisible: Ref = ref(false); @@ -892,9 +1027,20 @@ const datasourceItem = { ...item.datasource }; datasourceItem.status = item.isActive === 1 ? 'active' : 'inactive'; - // 初始化已选择的表 - if (item.selectTables && item.datasource?.id) { - selectedTables.value[item.datasource.id] = [...item.selectTables]; + if (item.datasource?.id) { + if (item.selectTables) { + selectedTables.value[item.datasource.id] = [...item.selectTables]; + } + selectedColumns.value[item.datasource.id] = Object.entries(item.selectColumns || {}).reduce< + Record + >((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}); + columnRestrictionEnabled.value[item.datasource.id] = {}; + Object.keys(selectedColumns.value[item.datasource.id]).forEach(tableName => { + columnRestrictionEnabled.value[item.datasource.id][tableName] = true; + }); } return datasourceItem; @@ -905,6 +1051,136 @@ } }; + const getAgentDatasourceByDatasourceId = (datasourceId: number): AgentDatasource | undefined => { + return agentDatasourceList.value.find(item => item.datasource?.id === datasourceId); + }; + + const getErrorMessage = (error: unknown, fallback: string): string => { + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; + }; + + const applyAgentDatasourceSnapshot = (snapshot: AgentDatasource): void => { + const datasourceId = snapshot.datasource?.id; + if (!datasourceId || !snapshot.datasource) { + return; + } + + const nextSnapshot: AgentDatasource = { + ...snapshot, + selectTables: [...(snapshot.selectTables || [])], + selectColumns: Object.entries(snapshot.selectColumns || {}).reduce>( + (result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, + {}, + ), + }; + + const agentDatasourceIndex = agentDatasourceList.value.findIndex( + item => item.datasource?.id === datasourceId, + ); + if (agentDatasourceIndex >= 0) { + agentDatasourceList.value[agentDatasourceIndex] = nextSnapshot; + } else { + agentDatasourceList.value.push(nextSnapshot); + } + + const datasourceSnapshot: Datasource = { + ...nextSnapshot.datasource, + status: nextSnapshot.isActive === 1 ? 'active' : 'inactive', + }; + const datasourceIndex = datasource.value.findIndex(item => item.id === datasourceId); + if (datasourceIndex >= 0) { + datasource.value[datasourceIndex] = datasourceSnapshot; + } else { + datasource.value.push(datasourceSnapshot); + } + + selectedTables.value[datasourceId] = [...(nextSnapshot.selectTables || [])]; + selectedColumns.value[datasourceId] = Object.entries(nextSnapshot.selectColumns || {}).reduce< + Record + >((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}); + columnRestrictionEnabled.value[datasourceId] = {}; + (nextSnapshot.selectTables || []).forEach(tableName => { + columnRestrictionEnabled.value[datasourceId][tableName] = + (nextSnapshot.selectColumns?.[tableName] || []).length > 0; + }); + }; + + const getSelectedTablesForDatasource = (datasourceId: number): string[] => { + const currentTables = selectedTables.value[datasourceId]; + if (currentTables && currentTables.length > 0) { + return [...currentTables]; + } + return [...(getAgentDatasourceByDatasourceId(datasourceId)?.selectTables || [])]; + }; + + const normalizeNameList = (values: string[] = []): string[] => { + return [...values] + .map(value => value.trim()) + .filter(Boolean) + .sort((left, right) => left.localeCompare(right)); + }; + + const hasPendingTableChanges = (datasourceRow: Datasource): boolean => { + if (!datasourceRow.id) { + return false; + } + const savedTables = normalizeNameList( + getAgentDatasourceByDatasourceId(datasourceRow.id)?.selectTables || [], + ); + const currentTables = normalizeNameList(selectedTables.value[datasourceRow.id] || []); + return savedTables.join('|') !== currentTables.join('|'); + }; + + const resolveConfiguredColumns = ( + selectColumns: Record | undefined, + tableName: string, + ): string[] => { + if (!selectColumns) { + return []; + } + if (selectColumns[tableName]) { + return [...selectColumns[tableName]]; + } + const matchedKey = Object.keys(selectColumns).find( + key => key.toLowerCase() === tableName.toLowerCase(), + ); + return matchedKey ? [...(selectColumns[matchedKey] || [])] : []; + }; + + const getColumnLoadingKey = (datasourceId: number, tableName: string): string => { + return `${datasourceId}:${tableName}`; + }; + + const loadColumnsForTable = async (datasourceId: number, tableName: string): Promise => { + const loadingKey = getColumnLoadingKey(datasourceId, tableName); + columnLoadingStates.value[loadingKey] = true; + try { + const columns = await agentDatasourceService.getVisibleTableColumns( + String(props.agentId), + datasourceId, + tableName, + ); + if (!columnOptionsByDatasource.value[datasourceId]) { + columnOptionsByDatasource.value[datasourceId] = {}; + } + columnOptionsByDatasource.value[datasourceId][tableName] = columns; + } catch (error) { + ElMessage.error(getErrorMessage(error, `加载表 ${tableName} 的字段失败`)); + console.error('Failed to load datasource columns:', error); + } finally { + columnLoadingStates.value[loadingKey] = false; + } + }; + const handleSelectDatasourceChange = (value: Datasource) => { if (value === null || value === undefined) { selectedDatasourceId.value = null; @@ -1274,34 +1550,142 @@ updateLoadingStates.value[datasource.id] = true; try { - const response = await agentDatasourceService.updateDatasourceTables( - String(props.agentId), - { - datasourceId: datasource.id, - tables: selectedTables.value[datasource.id] || [], - }, - ); + const response = await agentDatasourceService.updateDatasourceTables(String(props.agentId), { + datasourceId: datasource.id, + tables: selectedTables.value[datasource.id] || [], + }); - if (response.success) { + if (response.success && response.data) { + applyAgentDatasourceSnapshot(response.data); ElMessage.success('数据表更新成功'); - // 更新本地存储的已选择表 - const agentDatasource = agentDatasourceList.value.find( - item => item.datasource?.id === datasource.id, - ); - if (agentDatasource) { - agentDatasource.selectTables = [...(selectedTables.value[datasource.id] || [])]; - } } else { - ElMessage.error('数据表更新失败'); + ElMessage.error(response.message || '数据表更新失败'); } } catch (error) { - ElMessage.error('数据表更新失败'); + ElMessage.error(getErrorMessage(error, '数据表更新失败')); console.error('Failed to update datasource tables:', error); } finally { updateLoadingStates.value[datasource.id] = false; } }; + const toggleColumnRestriction = (tableName: string, enabled: boolean | string | number) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + if (!columnRestrictionEnabled.value[datasourceId]) { + columnRestrictionEnabled.value[datasourceId] = {}; + } + columnRestrictionEnabled.value[datasourceId][tableName] = Boolean(enabled); + if (!enabled) { + selectedColumns.value[datasourceId][tableName] = []; + } + }; + + const openColumnVisibilityDialog = async (datasourceRow: Datasource) => { + if (!datasourceRow.id) { + return; + } + if (hasPendingTableChanges(datasourceRow)) { + ElMessage.warning('请先点击“更新数据表”保存当前表配置,再设置字段可见性'); + return; + } + + const tables = getSelectedTablesForDatasource(datasourceRow.id); + if (tables.length === 0) { + ElMessage.warning('请先选择并保存数据表,再配置字段可见性'); + return; + } + + currentColumnDatasource.value = datasourceRow; + currentColumnTables.value = [...tables]; + + if (!selectedColumns.value[datasourceRow.id]) { + selectedColumns.value[datasourceRow.id] = {}; + } + if (!columnRestrictionEnabled.value[datasourceRow.id]) { + columnRestrictionEnabled.value[datasourceRow.id] = {}; + } + if (!columnOptionsByDatasource.value[datasourceRow.id]) { + columnOptionsByDatasource.value[datasourceRow.id] = {}; + } + + const agentDatasource = getAgentDatasourceByDatasourceId(datasourceRow.id); + tables.forEach(tableName => { + const configuredColumns = resolveConfiguredColumns(agentDatasource?.selectColumns, tableName); + selectedColumns.value[datasourceRow.id][tableName] = configuredColumns; + columnRestrictionEnabled.value[datasourceRow.id][tableName] = configuredColumns.length > 0; + }); + + await Promise.all(tables.map(tableName => loadColumnsForTable(datasourceRow.id!, tableName))); + columnDialogVisible.value = true; + }; + + const selectAllColumnsForTable = (tableName: string) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + selectedColumns.value[datasourceId][tableName] = [ + ...(columnOptionsByDatasource.value[datasourceId]?.[tableName] || []), + ]; + }; + + const clearColumnsForTable = (tableName: string) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + selectedColumns.value[datasourceId][tableName] = []; + }; + + const saveDatasourceColumns = async () => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + + const invalidTables = currentColumnTables.value.filter(tableName => { + return ( + columnRestrictionEnabled.value[datasourceId]?.[tableName] && + !(selectedColumns.value[datasourceId]?.[tableName] || []).length + ); + }); + if (invalidTables.length > 0) { + ElMessage.warning(`请至少为以下数据表选择一个字段:${invalidTables.join('、')}`); + return; + } + + savingColumnVisibility.value = true; + try { + const tables = currentColumnTables.value + .filter(tableName => columnRestrictionEnabled.value[datasourceId]?.[tableName]) + .map(tableName => ({ + tableName, + columns: [...(selectedColumns.value[datasourceId]?.[tableName] || [])], + })); + + const response = await agentDatasourceService.updateDatasourceColumns(String(props.agentId), { + datasourceId, + tables, + }); + + if (response.success && response.data) { + applyAgentDatasourceSnapshot(response.data); + ElMessage.success('字段可见性更新成功'); + columnDialogVisible.value = false; + } else { + ElMessage.error(response.message || '字段可见性更新失败'); + } + } catch (error) { + ElMessage.error(getErrorMessage(error, '字段可见性更新失败')); + console.error('Failed to update datasource columns:', error); + } finally { + savingColumnVisibility.value = false; + } + }; + // 全选表 const selectAllTables = (datasource: Datasource) => { if (!datasource.id || !tableLists.value[datasource.id]) return; @@ -1590,6 +1974,14 @@ editingDatasource, tableLists, selectedTables, + selectedColumns, + columnOptionsByDatasource, + columnRestrictionEnabled, + columnLoadingStates, + columnDialogVisible, + currentColumnDatasource, + currentColumnTables, + savingColumnVisibility, tableLoadingStates, updateLoadingStates, initAgentDatasource, @@ -1605,6 +1997,12 @@ deleteDatasource, loadDatasourceTables, updateDatasourceTables, + openColumnVisibilityDialog, + saveDatasourceColumns, + selectAllColumnsForTable, + clearColumnsForTable, + toggleColumnRestriction, + getColumnLoadingKey, selectAllTables, clearAllTables, truncateText, diff --git a/data-agent-frontend/src/services/agentDatasource.ts b/data-agent-frontend/src/services/agentDatasource.ts index bf49b3624..2137c9fed 100644 --- a/data-agent-frontend/src/services/agentDatasource.ts +++ b/data-agent-frontend/src/services/agentDatasource.ts @@ -28,26 +28,44 @@ interface UpdateDatasourceTablesDto { tables?: string[]; } +interface TableColumnsSelectionDto { + tableName: string; + columns?: string[]; +} + +interface UpdateDatasourceColumnsDto { + datasourceId?: number; + tables?: TableColumnsSelectionDto[]; +} + const BASE_URL_FUNC = (agentId: string) => `/api/agent/${agentId}/datasources`; +const extractApiErrorMessage = (error: unknown, fallback: string): string => { + if (axios.isAxiosError(error)) { + const responseMessage = error.response?.data?.message; + if (typeof responseMessage === 'string' && responseMessage.trim()) { + return responseMessage; + } + if (typeof error.message === 'string' && error.message.trim()) { + return error.message; + } + } + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; +}; + class AgentDatasourceService { - /** - * 初始化数据源Schema - * @param agentId 智能体ID - */ async initSchema(agentId: string): Promise> { try { const response = await axios.post>(`${BASE_URL_FUNC(agentId)}/init`); return response.data; } catch (error) { - throw new Error(`初始化Schema失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '初始化 Schema 失败')); } } - /** - * 获取智能体的数据源列表 - * @param agentId 智能体ID - */ async getAgentDatasource(agentId: number): Promise { try { const response = await axios.get>( @@ -58,36 +76,24 @@ class AgentDatasourceService { } throw new Error(response.data.message); } catch (error) { - throw new Error(`获取数据源列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '获取数据源列表失败')); } } - /** - * 获取当前激活的智能体 - * @param agentId 智能体ID - */ async getActiveAgentDatasource(agentId: number): Promise { try { const response = await axios.get>( - BASE_URL_FUNC(String(agentId)) + '/active', + `${BASE_URL_FUNC(String(agentId))}/active`, ); - if (response.data.success) { - if (response.data.data === undefined) { - throw new Error('后端错误'); - } + if (response.data.success && response.data.data) { return response.data.data; } - throw new Error(response.data.message); + throw new Error(response.data.message || '后端返回了空数据'); } catch (error) { - throw new Error(`获取数据源列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '获取当前启用数据源失败')); } } - /** - * 为智能体添加数据源 - * @param agentId 智能体ID - * @param datasourceId 数据源ID - */ async addDatasourceToAgent( agentId: string, datasourceId: number, @@ -98,15 +104,10 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`添加数据源失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '添加数据源失败')); } } - /** - * 从智能体移除数据源 - * @param agentId 智能体ID - * @param datasourceId 数据源ID - */ async removeDatasourceFromAgent( agentId: string, datasourceId: number, @@ -117,15 +118,10 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`移除数据源失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '移除数据源失败')); } } - /** - * 启用/禁用智能体的数据源 - * @param agentId 智能体ID - * @param dto 切换参数 - */ async toggleDatasourceForAgent( agentId: string, dto: ToggleDatasourceDto, @@ -137,24 +133,55 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`切换数据源状态失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '切换数据源状态失败')); } } - /** - * 更新数据源的表列表 - * @param agentId 智能体ID - * @param dto 更新参数 - */ async updateDatasourceTables( agentId: string, dto: UpdateDatasourceTablesDto, - ): Promise> { + ): Promise> { + try { + const response = await axios.post>( + `${BASE_URL_FUNC(agentId)}/tables`, + dto, + ); + return response.data; + } catch (error) { + throw new Error(extractApiErrorMessage(error, '更新数据表列表失败')); + } + } + + async updateDatasourceColumns( + agentId: string, + dto: UpdateDatasourceColumnsDto, + ): Promise> { try { - const response = await axios.post>(`${BASE_URL_FUNC(agentId)}/tables`, dto); + const response = await axios.post>( + `${BASE_URL_FUNC(agentId)}/columns`, + dto, + ); return response.data; } catch (error) { - throw new Error(`更新数据源表列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '更新字段可见性失败')); + } + } + + async getVisibleTableColumns( + agentId: string, + datasourceId: number, + tableName: string, + ): Promise { + try { + const response = await axios.get>( + `${BASE_URL_FUNC(agentId)}/${datasourceId}/tables/${encodeURIComponent(tableName)}/columns`, + ); + if (response.data.success) { + return response.data.data || []; + } + throw new Error(response.data.message); + } catch (error) { + throw new Error(extractApiErrorMessage(error, `加载表 ${tableName} 的字段失败`)); } } } diff --git a/data-agent-frontend/src/services/datasource.ts b/data-agent-frontend/src/services/datasource.ts index 184858fd2..d0b9f0d6c 100644 --- a/data-agent-frontend/src/services/datasource.ts +++ b/data-agent-frontend/src/services/datasource.ts @@ -45,6 +45,7 @@ export interface AgentDatasource { updateTime?: string; datasource?: Datasource; selectTables?: string[]; + selectColumns?: Record; } // 定义数据源类型接口 @@ -97,6 +98,23 @@ class DatasourceService { } } + async getTableColumns(id: number, tableName: string): Promise { + try { + const response = await axios.get>( + `${API_BASE_URL}/${id}/tables/${encodeURIComponent(tableName)}/columns`, + ); + if (response.data.success) { + return response.data.data || []; + } + throw new Error(response.data.message); + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 400) { + return []; + } + throw error; + } + } + // 4. 创建数据源 async createDatasource(datasource: Datasource): Promise { const response = await axios.post(API_BASE_URL, datasource); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index 099b9d1b9..47a03dec9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -40,14 +40,33 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.regex.Pattern; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.AllColumns; +import net.sf.jsqlparser.statement.select.AllTableColumns; +import net.sf.jsqlparser.statement.select.FromItem; +import net.sf.jsqlparser.statement.select.Join; +import net.sf.jsqlparser.statement.select.LateralSubSelect; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.ParenthesedFromItem; +import net.sf.jsqlparser.statement.select.ParenthesedSelect; +import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; -import net.sf.jsqlparser.util.TablesNamesFinder; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.statement.select.TableFunction; +import net.sf.jsqlparser.statement.select.WithItem; import org.apache.commons.lang3.StringUtils; import org.springframework.ai.document.Document; import org.springframework.stereotype.Service; @@ -70,6 +89,9 @@ public class DatasourceExplorerService { private static final TypeReference> STRING_LIST_TYPE = new TypeReference<>() { }; + private static final String HIDDEN_FIELD_INFERENCE_WARNING = + " Answer strictly from returned columns only. Never infer hidden fields from visible values such as email local parts, IDs, codes, or aliases."; + private final AgentDatasourceService agentDatasourceService; private final DatasourceService datasourceService; @@ -101,9 +123,8 @@ private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceE List> tables = context.visibleTables() .stream() .sorted(String.CASE_INSENSITIVE_ORDER) - .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables(), - context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()))) + .map(tableName -> toTableEntry(context, tableName, tableDocumentMap.get(normalizeTableName(tableName)), + filterRelations(context, tableName))) .limit(limit) .toList(); return baseResult(context, DatasourceExplorerAction.LIST_TABLES, tables.size() + " tables available") @@ -119,9 +140,8 @@ private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceE Map tableDocumentMap = loadTableDocumentMap(context, context.visibleTables()); List> matchedTables = context.visibleTables() .stream() - .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables(), - context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()))) + .map(tableName -> toTableEntry(context, tableName, tableDocumentMap.get(normalizeTableName(tableName)), + filterRelations(context, tableName))) .filter(table -> query.isEmpty() || containsQuery(table, query)) .limit(limit) .toList(); @@ -135,21 +155,20 @@ private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceE private DatasourceExplorerResult getTableSchema(ExplorerContext context, DatasourceExplorerRequest request) throws Exception { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); List columns = context.accessor() .showColumns(context.dbConfig(), DbQueryParameter.from(context.dbConfig()) .setSchema(context.dbConfig().getSchema()) .setTable(tableName)); Document tableDocument = loadTableDocumentMap(context, List.of(tableName)).get(normalizeTableName(tableName)); - List> columnEntries = columns.stream().map(this::toColumnEntry).toList(); - List relations = filterRelations(context, tableName); - List> relationEntries = relations.stream() - .map(this::toRelationEntry) + Map columnDocumentMap = loadColumnDocumentMap(context, tableName); + List> columnEntries = applyVisibleColumnFilter(context, tableName, columns).stream() + .map(column -> toColumnEntry(column, columnDocumentMap.get(normalizeColumnName(column.getName())))) .toList(); - Map tableEntry = toTableEntry(tableName, tableDocument, context.explicitSelectedTables(), - relations); + List relations = filterRelations(context, tableName); + List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); + Map tableEntry = toTableEntry(context, tableName, tableDocument, relations); return baseResult(context, DatasourceExplorerAction.GET_TABLE_SCHEMA, "Loaded schema for table '%s'".formatted(tableName)) .tables(List.of(tableEntry)) @@ -160,8 +179,7 @@ private DatasourceExplorerResult getTableSchema(ExplorerContext context, Datasou } private DatasourceExplorerResult getRelatedTables(ExplorerContext context, DatasourceExplorerRequest request) { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); List relations = filterRelations(context, tableName); List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); Set relatedTables = relations.stream() @@ -171,9 +189,8 @@ private DatasourceExplorerResult getRelatedTables(ExplorerContext context, Datas .collect(Collectors.toCollection(LinkedHashSet::new)); Map tableDocumentMap = loadTableDocumentMap(context, new ArrayList<>(relatedTables)); List> tableEntries = relatedTables.stream() - .map(relatedTable -> toTableEntry(relatedTable, tableDocumentMap.get(normalizeTableName(relatedTable)), - context.explicitSelectedTables(), - context.relationsByTable().getOrDefault(normalizeTableName(relatedTable), List.of()))) + .map(relatedTable -> toTableEntry(context, relatedTable, tableDocumentMap.get(normalizeTableName(relatedTable)), + filterRelations(context, relatedTable))) .toList(); return baseResult(context, DatasourceExplorerAction.GET_RELATED_TABLES, "Found %d related tables for '%s'".formatted(tableEntries.size(), tableName)) @@ -185,13 +202,15 @@ private DatasourceExplorerResult getRelatedTables(ExplorerContext context, Datas private DatasourceExplorerResult previewRows(ExplorerContext context, DatasourceExplorerRequest request) throws Exception { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); int limit = normalizeLimit(request.getLimit()); - String sql = SqlUtil.buildSelectSql(context.dbConfig().getDialectType(), tableName, "*", limit); + String sql = SqlUtil.buildSelectSql(context.dbConfig().getDialectType(), + SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), tableName), + resolvePreviewColumnSelection(context, tableName), limit); ResultSetBO resultSet = executeSql(context, sql); return baseResult(context, DatasourceExplorerAction.PREVIEW_ROWS, - "Previewed %d rows from '%s'".formatted(resultSet.getData().size(), tableName)) + ("Previewed %d rows from '%s'".formatted(resultSet.getData().size(), tableName)) + + HIDDEN_FIELD_INFERENCE_WARNING) .tables(List.of(Map.of("name", tableName))) .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) @@ -208,13 +227,14 @@ private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplo throw new IllegalArgumentException("search action 必须提供 sql"); } int limit = normalizeLimit(request.getLimit()); - String guardedSql = guardReadonlySql(context, rawSql, limit); - ResultSetBO resultSet = executeSql(context, guardedSql); + SqlGuardedQuery guardedQuery = guardReadonlySql(context, rawSql, limit); + ResultSetBO resultSet = filterResultSet(executeSql(context, guardedQuery.sql()), guardedQuery); return baseResult(context, DatasourceExplorerAction.SEARCH, - "Executed readonly search and returned %d rows".formatted(resultSet.getData().size())) + ("Executed readonly search and returned %d rows".formatted(resultSet.getData().size())) + + HIDDEN_FIELD_INFERENCE_WARNING) .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) - .sql(guardedSql) + .sql(guardedQuery.sql()) .nextSuggestedActions(List.of("get_table_schema", "preview_rows", "find_tables")) .truncated(resultSet.getData().size() >= limit) .build(); @@ -234,16 +254,31 @@ private ExplorerContext resolveContext(String agentId) throws Exception { : agentDatasource.getSelectTables(); List visibleTables = explicitSelectedTables.isEmpty() ? datasourceService.getDatasourceTables(datasource.getId()) : explicitSelectedTables; + Map> visibleTablesByName = indexTablesByIdentity(visibleTables); + Map> visibleTablesByLeafName = indexTablesByLeafName(visibleTables); Set visibleTableNameSet = visibleTables.stream() .map(this::normalizeTableName) .collect(Collectors.toCollection(LinkedHashSet::new)); + Map> visibleColumnsByTable = buildVisibleColumnsByTable(agentDatasource, visibleTablesByName, + visibleTablesByLeafName); + Map> visibleColumnNameSetByTable = visibleColumnsByTable.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, + entry -> entry.getValue() + .stream() + .map(this::normalizeColumnName) + .collect(Collectors.toCollection(LinkedHashSet::new)), + (left, right) -> left, LinkedHashMap::new)); List logicalRelations = datasourceService.getLogicalRelations(datasource.getId()); List physicalRelations = loadPhysicalRelations(accessor, dbConfig, visibleTables); - List unifiedRelations = buildUnifiedRelations(visibleTableNameSet, physicalRelations, - logicalRelations == null ? List.of() : logicalRelations); + List unifiedRelations = buildUnifiedRelations(visibleTablesByName, visibleTablesByLeafName, + physicalRelations, logicalRelations == null ? List.of() : logicalRelations); return new ExplorerContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), - Set.copyOf(visibleTableNameSet), List.copyOf(explicitSelectedTables), - List.copyOf(unifiedRelations), indexRelationsByTable(unifiedRelations)); + Set.copyOf(visibleTableNameSet), toImmutableListIndex(visibleTablesByName), + toImmutableListIndex(visibleTablesByLeafName), List.copyOf(explicitSelectedTables), + Map.copyOf(visibleColumnsByTable), toImmutableSetIndex(visibleColumnNameSetByTable), + Set.copyOf(visibleColumnsByTable.keySet()), List.copyOf(unifiedRelations), + indexRelationsByTable(unifiedRelations)); } private List loadPhysicalRelations(Accessor accessor, DbConfigBO dbConfig, List tables) { @@ -283,25 +318,16 @@ private ResultSetBO executeSql(ExplorerContext context, String sql) throws Excep return resultSet; } - private String guardReadonlySql(ExplorerContext context, String rawSql, int limit) { + private SqlGuardedQuery guardReadonlySql(ExplorerContext context, String rawSql, int limit) { String compactSql = stripTrailingSemicolons(rawSql); if (compactSql.isEmpty()) { throw new IllegalArgumentException("SQL 不能为空"); } Statement statement = parseSingleSelectStatement(compactSql); - Set referencedTables = extractReferencedTables(statement); - if (!referencedTables.isEmpty()) { - List forbiddenTables = referencedTables.stream() - .filter(table -> !context.visibleTableNameSet().contains(normalizeTableName(table))) - .toList(); - if (!forbiddenTables.isEmpty()) { - throw new IllegalArgumentException("SQL 引用了当前 Agent 不可见的表: " + forbiddenTables); - } - } - if (hasLimit(compactSql)) { - return compactSql; - } - return wrapLimitSql(context.dbConfig().getDialectType(), compactSql, limit); + SqlValidationResult validationResult = new SqlColumnAccessValidator(context).validate((Select) statement); + String guardedSql = hasLimit(compactSql) ? compactSql + : wrapLimitSql(context.dbConfig().getDialectType(), compactSql, limit); + return new SqlGuardedQuery(guardedSql, validationResult.referencedTables(), validationResult.allowedResultHeaders()); } private String stripTrailingSemicolons(String sql) { @@ -338,18 +364,6 @@ private Statement parseSingleSelectStatement(String sql) { return statement; } - private Set extractReferencedTables(Statement statement) { - try { - return new TablesNamesFinder().getTableList(statement) - .stream() - .map(this::normalizeTableName) - .collect(Collectors.toCollection(LinkedHashSet::new)); - } - catch (Exception ex) { - throw new IllegalArgumentException("无法从 SQL 中提取引用表", ex); - } - } - private String wrapLimitSql(String dialectType, String sql, int limit) { String normalizedDialect = StringUtils.defaultString(dialectType).toLowerCase(Locale.ROOT); if (normalizedDialect.contains("sqlserver") || normalizedDialect.contains("sql_server")) { @@ -361,12 +375,6 @@ private String wrapLimitSql(String dialectType, String sql, int limit) { return "SELECT * FROM (%s) dataagent_safe_limit LIMIT %d".formatted(sql, limit); } - private void assertVisibleTable(ExplorerContext context, String tableName) { - if (!context.visibleTableNameSet().contains(normalizeTableName(tableName))) { - throw new IllegalArgumentException("Table '%s' is not visible for current agent".formatted(tableName)); - } - } - private String requireSingleTableName(DatasourceExplorerRequest request) { if (StringUtils.isNotBlank(request.getTableName())) { return request.getTableName().trim(); @@ -399,12 +407,23 @@ private Map loadTableDocumentMap(ExplorerContext context, List } } - private Map toTableEntry(String tableName, Document tableDocument, - List explicitSelectedTables, List relations) { + private Map loadColumnDocumentMap(ExplorerContext context, String tableName) { + try { + return schemaService.getColumnDocumentsByTableName(context.datasource().getId(), List.of(tableName)) + .stream() + .collect(Collectors.toMap(doc -> normalizeColumnName(String.valueOf(doc.getMetadata().get("name"))), + doc -> doc, (left, right) -> left, LinkedHashMap::new)); + } + catch (Exception ex) { + return Collections.emptyMap(); + } + } + + private Map toTableEntry(ExplorerContext context, String tableName, Document tableDocument, + List relations) { Map tableEntry = new LinkedHashMap<>(); tableEntry.put("name", tableName); - tableEntry.put("selected", explicitSelectedTables.isEmpty() || explicitSelectedTables.stream() - .anyMatch(candidate -> normalizeTableName(candidate).equals(normalizeTableName(tableName)))); + tableEntry.put("selected", isSelectedTable(context, tableName)); String unifiedForeignKeys = summarizeRelations(relations); if (tableDocument != null) { tableEntry.put("schema", tableDocument.getMetadata().getOrDefault("schema", "")); @@ -420,17 +439,37 @@ else if (StringUtils.isNotBlank(unifiedForeignKeys)) { return tableEntry; } - private Map toColumnEntry(ColumnInfoBO columnInfo) { + private Map toColumnEntry(ColumnInfoBO columnInfo, Document columnDocument) { Map columnEntry = new LinkedHashMap<>(); columnEntry.put("name", columnInfo.getName()); columnEntry.put("type", columnInfo.getType()); - columnEntry.put("description", StringUtils.defaultString(columnInfo.getDescription())); + columnEntry.put("description", resolveColumnDescription(columnInfo, columnDocument)); columnEntry.put("primary", columnInfo.isPrimary()); columnEntry.put("notnull", columnInfo.isNotnull()); - columnEntry.put("samples", parseSamples(columnInfo.getSamples())); + columnEntry.put("samples", resolveColumnSamples(columnInfo, columnDocument)); return columnEntry; } + private String resolveColumnDescription(ColumnInfoBO columnInfo, Document columnDocument) { + if (StringUtils.isNotBlank(columnInfo.getDescription())) { + return columnInfo.getDescription(); + } + if (columnDocument == null) { + return StringUtils.EMPTY; + } + return String.valueOf(columnDocument.getMetadata().getOrDefault("description", "")); + } + + private List resolveColumnSamples(ColumnInfoBO columnInfo, Document columnDocument) { + if (StringUtils.isNotBlank(columnInfo.getSamples())) { + return parseSamples(columnInfo.getSamples()); + } + if (columnDocument == null) { + return List.of(); + } + return parseSamples(String.valueOf(columnDocument.getMetadata().getOrDefault("samples", ""))); + } + private List parseSamples(String samples) { if (StringUtils.isBlank(samples)) { return List.of(); @@ -444,7 +483,11 @@ private List parseSamples(String samples) { } private List filterRelations(ExplorerContext context, String tableName) { - return context.relationsByTable().getOrDefault(normalizeTableName(tableName), List.of()); + return context.relationsByTable() + .getOrDefault(normalizeTableName(tableName), List.of()) + .stream() + .filter(relation -> isRelationVisible(context, relation)) + .toList(); } private Map toRelationEntry(UnifiedRelation relation) { @@ -461,18 +504,21 @@ private Map toRelationEntry(UnifiedRelation relation) { return relationEntry; } - private List buildUnifiedRelations(Set visibleTableNameSet, + private List buildUnifiedRelations(Map> visibleTablesByName, + Map> visibleTablesByLeafName, List physicalRelations, List logicalRelations) { Map relationMap = new LinkedHashMap<>(); for (ForeignKeyInfoBO physicalRelation : physicalRelations) { - UnifiedRelation relation = toUnifiedRelation(physicalRelation); - if (isVisibleRelation(visibleTableNameSet, relation)) { + UnifiedRelation relation = canonicalizeRelation(visibleTablesByName, visibleTablesByLeafName, + toUnifiedRelation(physicalRelation)); + if (relation != null) { mergeRelation(relationMap, relation); } } for (LogicalRelation logicalRelation : logicalRelations) { - UnifiedRelation relation = toUnifiedRelation(logicalRelation); - if (isVisibleRelation(visibleTableNameSet, relation)) { + UnifiedRelation relation = canonicalizeRelation(visibleTablesByName, visibleTablesByLeafName, + toUnifiedRelation(logicalRelation)); + if (relation != null) { mergeRelation(relationMap, relation); } } @@ -497,9 +543,18 @@ private UnifiedRelation toUnifiedRelation(LogicalRelation relation) { "logical", true, false); } - private boolean isVisibleRelation(Set visibleTableNameSet, UnifiedRelation relation) { - return visibleTableNameSet.contains(normalizeTableName(relation.sourceTable())) - && visibleTableNameSet.contains(normalizeTableName(relation.targetTable())); + private UnifiedRelation canonicalizeRelation(Map> visibleTablesByName, + Map> visibleTablesByLeafName, UnifiedRelation relation) { + Optional sourceTable = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + relation.sourceTable(), true); + Optional targetTable = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + relation.targetTable(), true); + if (sourceTable.isEmpty() || targetTable.isEmpty()) { + return null; + } + return new UnifiedRelation(sourceTable.get(), relation.sourceColumn(), targetTable.get(), relation.targetColumn(), + relation.relationType(), relation.description(), relation.sourceType(), relation.virtual(), + relation.declaredInDatabase()); } private void mergeRelation(Map relationMap, UnifiedRelation incoming) { @@ -568,6 +623,37 @@ private List> toRows(ResultSetBO resultSet) { }).toList(); } + private ResultSetBO filterResultSet(ResultSetBO resultSet, SqlGuardedQuery guardedQuery) { + Set allowedHeaders = guardedQuery.allowedResultHeaders(); + if (allowedHeaders == null || allowedHeaders.isEmpty()) { + return resultSet; + } + List originalColumns = Optional.ofNullable(resultSet.getColumn()).orElse(List.of()); + List keptIndexes = new ArrayList<>(); + List keptColumns = new ArrayList<>(); + for (int index = 0; index < originalColumns.size(); index++) { + String columnName = originalColumns.get(index); + if (allowedHeaders.contains(normalizeColumnName(columnName))) { + keptIndexes.add(index); + keptColumns.add(columnName); + } + } + if (keptIndexes.size() == originalColumns.size()) { + return resultSet; + } + List> filteredRows = Optional.ofNullable(resultSet.getData()).orElse(List.of()).stream().map(row -> { + Map filteredRow = new LinkedHashMap<>(); + for (Integer keptIndex : keptIndexes) { + String columnName = originalColumns.get(keptIndex); + filteredRow.put(columnName, row.get(columnName)); + } + return filteredRow; + }).toList(); + resultSet.setColumn(keptColumns); + resultSet.setData(filteredRows); + return resultSet; + } + private boolean containsQuery(Map table, String query) { return table.values() .stream() @@ -577,18 +663,175 @@ private boolean containsQuery(Map table, String query) { .anyMatch(value -> value.contains(query)); } - private String normalizeTableName(String tableName) { - String normalized = StringUtils.trimToEmpty(tableName); + private List applyVisibleColumnFilter(ExplorerContext context, String tableName, List columns) { + return Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(column -> isColumnVisible(context, tableName, column.getName())) + .toList(); + } + + private String resolvePreviewColumnSelection(ExplorerContext context, String tableName) { + List visibleColumns = context.visibleColumnsByTable().get(normalizeTableName(tableName)); + if (visibleColumns == null) { + return "*"; + } + if (visibleColumns.isEmpty()) { + throw new IllegalArgumentException("表 '%s' 当前没有可预览字段,请先调整字段级可见性配置".formatted(tableName)); + } + return visibleColumns.stream() + .map(columnName -> SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), columnName)) + .collect(Collectors.joining(", ")); + } + + private boolean isRelationVisible(ExplorerContext context, UnifiedRelation relation) { + return isColumnVisible(context, relation.sourceTable(), relation.sourceColumn()) + && isColumnVisible(context, relation.targetTable(), relation.targetColumn()); + } + + private boolean isColumnVisible(ExplorerContext context, String tableName, String columnName) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return true; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + return visibleColumns != null && visibleColumns.contains(normalizeColumnName(columnName)); + } + + private Map> buildVisibleColumnsByTable(AgentDatasource agentDatasource, + Map> visibleTablesByName, Map> visibleTablesByLeafName) { + Map> selectedColumns = Optional.ofNullable(agentDatasource.getSelectColumns()).orElse(Map.of()); + Map> visibleColumnsByTable = new LinkedHashMap<>(); + selectedColumns.forEach((tableName, columns) -> { + Optional resolvedTableName = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + tableName, true); + if (resolvedTableName.isEmpty()) { + return; + } + List sanitizedColumns = Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .collect(Collectors.toCollection(LinkedHashSet::new)) + .stream() + .toList(); + if (!sanitizedColumns.isEmpty()) { + visibleColumnsByTable.put(normalizeTableName(resolvedTableName.get()), List.copyOf(sanitizedColumns)); + } + }); + return visibleColumnsByTable; + } + + private Map> toImmutableListIndex(Map> source) { + Map> immutableIndex = new LinkedHashMap<>(); + source.forEach((key, value) -> immutableIndex.put(key, List.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private Map> toImmutableSetIndex(Map> source) { + Map> immutableIndex = new LinkedHashMap<>(); + source.forEach((key, value) -> immutableIndex.put(key, Set.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private Map> indexTablesByIdentity(List tableNames) { + return indexTables(tableNames, false); + } + + private Map> indexTablesByLeafName(List tableNames) { + return indexTables(tableNames, true); + } + + private Map> indexTables(List tableNames, boolean leafOnly) { + Map> index = new LinkedHashMap<>(); + for (String tableName : Optional.ofNullable(tableNames).orElse(List.of())) { + if (StringUtils.isBlank(tableName)) { + continue; + } + String normalizedKey = leafOnly ? normalizeTableLeafName(tableName) : normalizeTableName(tableName); + index.computeIfAbsent(normalizedKey, key -> new LinkedHashSet<>()).add(tableName); + } + Map> immutableIndex = new LinkedHashMap<>(); + index.forEach((key, value) -> immutableIndex.put(key, List.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private String resolveVisibleTableName(ExplorerContext context, String tableName) { + return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) + .orElseThrow(() -> buildInvisibleTableException(context, tableName)); + } + + private Optional findVisibleTableName(Map> visibleTablesByName, + Map> visibleTablesByLeafName, String tableName, boolean allowQualifiedFallback) { + String normalizedTableName = normalizeTableName(tableName); + List exactMatches = visibleTablesByName.getOrDefault(normalizedTableName, List.of()); + if (exactMatches.size() == 1) { + return Optional.of(exactMatches.get(0)); + } + if (exactMatches.size() > 1) { + throw new IllegalArgumentException( + "Table '%s' maps to multiple visible tables: %s".formatted(tableName, String.join(", ", exactMatches))); + } + if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { + return Optional.empty(); + } + List leafMatches = visibleTablesByLeafName.getOrDefault(normalizeTableLeafName(tableName), List.of()); + if (leafMatches.size() == 1) { + return Optional.of(leafMatches.get(0)); + } + if (leafMatches.size() > 1) { + throw new IllegalArgumentException("Table '%s' is ambiguous across visible tables: %s" + .formatted(tableName, String.join(", ", leafMatches))); + } + return Optional.empty(); + } + + private IllegalArgumentException buildInvisibleTableException(ExplorerContext context, String tableName) { + return new IllegalArgumentException("Table '%s' is not visible for current agent. Visible tables: %s" + .formatted(tableName, String.join(", ", context.visibleTables()))); + } + + private boolean isSelectedTable(ExplorerContext context, String tableName) { + if (context.explicitSelectedTables().isEmpty()) { + return true; + } + String normalizedTableName = normalizeTableName(tableName); + return context.explicitSelectedTables() + .stream() + .map(this::normalizeTableName) + .anyMatch(normalizedTableName::equals); + } + + private boolean isQualifiedIdentifier(String value) { + return normalizeIdentifier(value).contains("."); + } + + private String normalizeIdentifier(String value) { + String normalized = StringUtils.trimToEmpty(value); normalized = StringUtils.removeStart(normalized, "`"); normalized = StringUtils.removeEnd(normalized, "`"); normalized = StringUtils.removeStart(normalized, "\""); normalized = StringUtils.removeEnd(normalized, "\""); normalized = StringUtils.removeStart(normalized, "["); normalized = StringUtils.removeEnd(normalized, "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeTableName(String tableName) { + return normalizeIdentifier(tableName); + } + + private String normalizeTableLeafName(String tableName) { + String normalized = normalizeIdentifier(tableName); if (normalized.contains(".")) { - normalized = normalized.substring(normalized.lastIndexOf('.') + 1); + return normalized.substring(normalized.lastIndexOf('.') + 1); } - return normalized.toLowerCase(Locale.ROOT); + return normalized; + } + + private String normalizeColumnName(String columnName) { + return normalizeTableLeafName(columnName); } private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(ExplorerContext context, @@ -601,8 +844,10 @@ private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(Expl private record ExplorerContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, Accessor accessor, List visibleTables, Set visibleTableNameSet, - List explicitSelectedTables, List unifiedRelations, - Map> relationsByTable) { + Map> visibleTablesByName, Map> visibleTablesByLeafName, + List explicitSelectedTables, Map> visibleColumnsByTable, + Map> visibleColumnNameSetByTable, Set columnRestrictedTables, + List unifiedRelations, Map> relationsByTable) { } private record UnifiedRelation(String sourceTable, String sourceColumn, String targetTable, String targetColumn, @@ -610,4 +855,326 @@ private record UnifiedRelation(String sourceTable, String sourceColumn, String t boolean declaredInDatabase) { } + private record SqlGuardedQuery(String sql, Set referencedTables, Set allowedResultHeaders) { + } + + private record SqlValidationResult(Set referencedTables, Set allowedResultHeaders) { + } + + private record SourceBinding(String referenceName, String tableName) { + + boolean isBaseTable() { + return StringUtils.isNotBlank(tableName); + } + } + + private record SelectScope(Map sourcesByReference, List baseSources) { + + SourceBinding resolve(String reference) { + return sourcesByReference.get(reference); + } + } + + private final class SqlColumnAccessValidator { + + private final ExplorerContext context; + + private SqlColumnAccessValidator(ExplorerContext context) { + this.context = context; + } + + private SqlValidationResult validate(Select select) { + Set referencedTables = new LinkedHashSet<>(); + validateSelect(select, new LinkedHashSet<>(), referencedTables); + return new SqlValidationResult(Set.copyOf(referencedTables), extractAllowedResultHeaders(select)); + } + + private void validateSelect(Select select, Set cteNames, Set referencedTables) { + Set nextCteNames = new LinkedHashSet<>(cteNames); + if (select.getWithItemsList() != null) { + for (WithItem withItem : select.getWithItemsList()) { + if (withItem.getAlias() != null && StringUtils.isNotBlank(withItem.getAlias().getName())) { + nextCteNames.add(normalizeTableName(withItem.getAlias().getName())); + } + } + for (WithItem withItem : select.getWithItemsList()) { + validateSelect(withItem.getSelect(), nextCteNames, referencedTables); + } + } + if (select instanceof PlainSelect plainSelect) { + validatePlainSelect(plainSelect, nextCteNames, referencedTables); + return; + } + if (select instanceof SetOperationList setOperationList) { + for (Select childSelect : Optional.ofNullable(setOperationList.getSelects()).orElse(List.of())) { + validateSelect(childSelect, nextCteNames, referencedTables); + } + return; + } + if (select instanceof ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), nextCteNames, referencedTables); + return; + } + throw new IllegalArgumentException("当前 SQL 包含暂不支持的查询结构: " + select.getClass().getSimpleName()); + } + + private void validatePlainSelect(PlainSelect plainSelect, Set cteNames, Set referencedTables) { + SelectScope scope = buildScope(plainSelect, cteNames, referencedTables); + Set selectAliases = extractSelectAliases(plainSelect.getSelectItems()); + for (SelectItem selectItem : Optional.ofNullable(plainSelect.getSelectItems()).orElse(List.of())) { + validateExpression(selectItem.getExpression(), scope, cteNames, "SELECT", Set.of()); + } + validateExpression(plainSelect.getWhere(), scope, cteNames, "WHERE", Set.of()); + validateExpression(plainSelect.getHaving(), scope, cteNames, "HAVING", selectAliases); + validateExpression(plainSelect.getQualify(), scope, cteNames, "QUALIFY", selectAliases); + if (plainSelect.getGroupBy() != null && plainSelect.getGroupBy().getGroupByExpressions() != null) { + for (Object groupByExpression : Optional.ofNullable(plainSelect.getGroupBy().getGroupByExpressions().getExpressions()) + .orElse(List.of())) { + if (groupByExpression instanceof Expression expression) { + validateExpression(expression, scope, cteNames, "GROUP BY", selectAliases); + } + } + } + for (OrderByElement orderByElement : Optional.ofNullable(plainSelect.getOrderByElements()).orElse(List.of())) { + validateExpression(orderByElement.getExpression(), scope, cteNames, "ORDER BY", selectAliases); + } + } + + private SelectScope buildScope(PlainSelect plainSelect, Set cteNames, Set referencedTables) { + Map sourcesByReference = new LinkedHashMap<>(); + List baseSources = new ArrayList<>(); + addFromItemSources(plainSelect.getFromItem(), cteNames, referencedTables, sourcesByReference, baseSources); + for (Join join : Optional.ofNullable(plainSelect.getJoins()).orElse(List.of())) { + addFromItemSources(join.getRightItem(), cteNames, referencedTables, sourcesByReference, baseSources); + if (join.getUsingColumns() != null && !join.getUsingColumns().isEmpty()) { + throw new IllegalArgumentException("当前 SQL 使用了 JOIN ... USING 语法。字段级可见性校验要求改写成显式 ON alias.column = alias.column"); + } + for (Expression onExpression : Optional.ofNullable(join.getOnExpressions()).orElse(List.of())) { + validateExpression(onExpression, new SelectScope(Map.copyOf(sourcesByReference), List.copyOf(baseSources)), + cteNames, "JOIN ON", Set.of()); + } + } + return new SelectScope(Map.copyOf(sourcesByReference), List.copyOf(baseSources)); + } + + private void addFromItemSources(FromItem fromItem, Set cteNames, Set referencedTables, + Map sourcesByReference, List baseSources) { + if (fromItem == null) { + return; + } + if (fromItem instanceof Table table) { + String normalizedTableReference = normalizeTableName(extractTableReference(table)); + String aliasName = table.getAlias() == null ? null : normalizeTableName(table.getAlias().getName()); + if (cteNames.contains(normalizedTableReference)) { + registerSource(sourcesByReference, + new SourceBinding(StringUtils.defaultIfBlank(aliasName, normalizedTableReference), null)); + return; + } + String resolvedTableName = resolveVisibleTableName(context, extractTableReference(table)); + referencedTables.add(normalizeTableName(resolvedTableName)); + SourceBinding sourceBinding = new SourceBinding(StringUtils.defaultIfBlank(aliasName, + normalizeTableName(resolvedTableName)), resolvedTableName); + registerSource(sourcesByReference, sourceBinding); + registerSource(sourcesByReference, + new SourceBinding(normalizeTableName(resolvedTableName), resolvedTableName)); + String tableLeafName = normalizeTableLeafName(resolvedTableName); + if (!tableLeafName.equals(normalizeTableName(resolvedTableName))) { + registerSource(sourcesByReference, new SourceBinding(tableLeafName, resolvedTableName)); + } + baseSources.add(sourceBinding); + return; + } + if (fromItem instanceof LateralSubSelect lateralSubSelect) { + validateSelect(lateralSubSelect.getSelect(), cteNames, referencedTables); + registerDerivedSource(lateralSubSelect, sourcesByReference); + return; + } + if (fromItem instanceof ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), cteNames, referencedTables); + registerDerivedSource(parenthesedSelect, sourcesByReference); + return; + } + if (fromItem instanceof ParenthesedFromItem parenthesedFromItem) { + Map nestedSources = new LinkedHashMap<>(); + List nestedBaseSources = new ArrayList<>(); + addFromItemSources(parenthesedFromItem.getFromItem(), cteNames, referencedTables, nestedSources, nestedBaseSources); + for (Join join : Optional.ofNullable(parenthesedFromItem.getJoins()).orElse(List.of())) { + addFromItemSources(join.getRightItem(), cteNames, referencedTables, nestedSources, nestedBaseSources); + if (join.getUsingColumns() != null && !join.getUsingColumns().isEmpty()) { + throw new IllegalArgumentException("当前 SQL 使用了 JOIN ... USING 语法。字段级可见性校验要求改写成显式 ON alias.column = alias.column"); + } + for (Expression onExpression : Optional.ofNullable(join.getOnExpressions()).orElse(List.of())) { + validateExpression(onExpression, + new SelectScope(Map.copyOf(nestedSources), List.copyOf(nestedBaseSources)), + cteNames, "JOIN ON", Set.of()); + } + } + if (parenthesedFromItem.getAlias() != null && StringUtils.isNotBlank(parenthesedFromItem.getAlias().getName())) { + registerDerivedSource(parenthesedFromItem, sourcesByReference); + } + else { + nestedSources.values().forEach(binding -> registerSource(sourcesByReference, binding)); + baseSources.addAll(nestedBaseSources); + } + return; + } + if (fromItem instanceof TableFunction tableFunction) { + registerDerivedSource(tableFunction, sourcesByReference); + return; + } + throw new IllegalArgumentException("当前 SQL 包含暂不支持的 FROM 结构: " + fromItem.getClass().getSimpleName()); + } + + private void registerDerivedSource(FromItem fromItem, Map sourcesByReference) { + if (fromItem.getAlias() == null || StringUtils.isBlank(fromItem.getAlias().getName())) { + return; + } + registerSource(sourcesByReference, new SourceBinding(normalizeTableName(fromItem.getAlias().getName()), null)); + } + + private void registerSource(Map sourcesByReference, SourceBinding sourceBinding) { + SourceBinding existingBinding = sourcesByReference.get(sourceBinding.referenceName()); + if (existingBinding != null && !Objects.equals(existingBinding.tableName(), sourceBinding.tableName())) { + throw new IllegalArgumentException("Table reference '%s' is ambiguous in current SQL scope; please use aliases" + .formatted(sourceBinding.referenceName())); + } + sourcesByReference.put(sourceBinding.referenceName(), sourceBinding); + } + + private void validateExpression(Expression expression, SelectScope scope, Set cteNames, String clause, + Set allowedAliases) { + if (expression == null) { + return; + } + expression.accept(new ExpressionVisitorAdapter() { + @Override + public void visit(Function function) { + if (function.isAllColumns() && !"count".equalsIgnoreCase(function.getName())) { + throw new IllegalArgumentException("子句 %s 中检测到 %s(*)。字段级可见性校验禁止使用除 COUNT(*) 外的星号聚合,请显式列出字段" + .formatted(clause, StringUtils.defaultIfBlank(function.getName(), "function"))); + } + super.visit(function); + } + + @Override + public void visit(AllColumns allColumns) { + throw new IllegalArgumentException("子句 %s 中检测到 SELECT *。请改成显式列名,避免越权读取隐藏字段".formatted(clause)); + } + + @Override + public void visit(AllTableColumns allTableColumns) { + throw new IllegalArgumentException( + "子句 %s 中检测到 %s.*。请改成显式列名,避免越权读取隐藏字段".formatted(clause, allTableColumns.getTable())); + } + + @Override + public void visit(Column column) { + validateColumnReference(scope, clause, column, allowedAliases); + } + + @Override + public void visit(ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), cteNames, new LinkedHashSet<>()); + } + + @Override + public void visit(Select select) { + validateSelect(select, cteNames, new LinkedHashSet<>()); + } + }); + } + + private void validateColumnReference(SelectScope scope, String clause, Column column, Set allowedAliases) { + String normalizedColumnName = normalizeColumnName(column.getColumnName()); + if (StringUtils.isBlank(normalizedColumnName)) { + throw new IllegalArgumentException("子句 %s 中存在无法识别的字段引用: %s".formatted(clause, column)); + } + Table table = column.getTable(); + String tableReference = table == null ? StringUtils.EMPTY : normalizeTableName(extractTableReference(table)); + if (StringUtils.isBlank(tableReference)) { + if (allowedAliases.contains(normalizedColumnName)) { + return; + } + if (scope.baseSources().isEmpty()) { + return; + } + if (scope.baseSources().size() > 1) { + throw new IllegalArgumentException( + "子句 %s 中的字段 '%s' 没有带表前缀。当前 SQL 涉及多张基础表,无法安全判断字段归属,请改成 alias.%s" + .formatted(clause, column.getColumnName(), column.getColumnName())); + } + SourceBinding sourceBinding = scope.baseSources().get(0); + assertVisibleColumn(sourceBinding.tableName(), column.getColumnName(), clause, column.toString()); + return; + } + SourceBinding sourceBinding = scope.resolve(tableReference); + if (sourceBinding == null) { + throw new IllegalArgumentException( + "子句 %s 中引用了未知表/别名 '%s'。请检查 SQL 中的表别名是否和 FROM/JOIN 定义一致".formatted(clause, table.getName())); + } + if (!sourceBinding.isBaseTable()) { + return; + } + assertVisibleColumn(sourceBinding.tableName(), column.getColumnName(), clause, column.toString()); + } + + private void assertVisibleColumn(String tableName, String columnName, String clause, String expression) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + if (visibleColumns == null || !visibleColumns.contains(normalizeColumnName(columnName))) { + String visibleColumnSummary = Optional.ofNullable(context.visibleColumnsByTable().get(normalizedTableName)) + .orElse(List.of()) + .stream() + .collect(Collectors.joining(", ")); + throw new IllegalArgumentException( + "子句 %s 中的字段引用 '%s' 不被允许。表 '%s' 已启用字段级可见性控制,仅允许字段: [%s]" + .formatted(clause, expression, tableName, visibleColumnSummary)); + } + } + + private Set extractSelectAliases(List> selectItems) { + return Optional.ofNullable(selectItems) + .orElse(List.of()) + .stream() + .map(SelectItem::getAlias) + .filter(alias -> alias != null && StringUtils.isNotBlank(alias.getName())) + .map(alias -> normalizeColumnName(alias.getName())) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + private Set extractAllowedResultHeaders(Select select) { + if (!(select instanceof PlainSelect plainSelect)) { + return null; + } + Set headers = new LinkedHashSet<>(); + for (SelectItem selectItem : Optional.ofNullable(plainSelect.getSelectItems()).orElse(List.of())) { + Expression expression = selectItem.getExpression(); + if (expression instanceof AllColumns || expression instanceof AllTableColumns) { + return null; + } + if (selectItem.getAlias() != null && StringUtils.isNotBlank(selectItem.getAlias().getName())) { + headers.add(normalizeColumnName(selectItem.getAlias().getName())); + continue; + } + if (expression instanceof Column column) { + headers.add(normalizeColumnName(column.getColumnName())); + continue; + } + return null; + } + return headers; + } + + private String extractTableReference(Table table) { + String fullyQualifiedName = table == null ? StringUtils.EMPTY : table.getFullyQualifiedName(); + if (StringUtils.isNotBlank(fullyQualifiedName)) { + return fullyQualifiedName; + } + return table == null ? StringUtils.EMPTY : table.getName(); + } + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index a470267da..6575492f4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -141,7 +141,8 @@ private String buildDescription(Datasource datasource, AgentDatasource agentData 4. Treat the unified relations field as the primary source for table-to-table relationship reasoning and join planning. 5. The foreignKeys field inside table metadata is kept only for compatibility; prefer relations for agent reasoning. 6. Recommended call order: LIST_TABLES -> GET_TABLE_SCHEMA -> GET_RELATED_TABLES -> PREVIEW_ROWS -> SEARCH. - 7. %s + 7. Never infer hidden fields from visible values. For example, do not derive a username or person name from an email local-part, ID, code, or alias. + 8. %s """.formatted(datasource.getName(), datasource.getType(), visibleTables); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java index f9962af1e..d4dbfafd9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java @@ -15,23 +15,35 @@ */ package com.alibaba.cloud.ai.dataagent.controller; +import com.alibaba.cloud.ai.dataagent.dto.datasource.TableColumnsSelectionDTO; import com.alibaba.cloud.ai.dataagent.dto.datasource.ToggleDatasourceDTO; +import com.alibaba.cloud.ai.dataagent.dto.datasource.UpdateDatasourceColumnsDTO; import com.alibaba.cloud.ai.dataagent.dto.datasource.UpdateDatasourceTablesDTO; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.exception.InternalServerException; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; +import jakarta.validation.Valid; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.*; +import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; /** - * Agent Schema Initialization Controller Handles agent's database Schema initialization - * to vector storage + * Agent datasource controller. */ @Slf4j @RestController @@ -42,60 +54,49 @@ public class AgentDatasourceController { private final AgentDatasourceService agentDatasourceService; - /** - * Initialize agent's database Schema to vector storage Corresponds to the "Initialize - * Information Source" function on the frontend - */ @PostMapping("/init") public ApiResponse initSchema(@PathVariable Long agentId) { - // 防止前端恶意请求,dto数据应该在后端获取 try { AgentDatasource agentDatasource = agentDatasourceService.getCurrentAgentDatasource(agentId); log.info("Initializing schema for agent: {}", agentId); - // Extract data source ID and table list from request Integer datasourceId = agentDatasource.getDatasourceId(); List tables = Optional.ofNullable(agentDatasource.getSelectTables()).orElse(List.of()); - // Validate request parameters if (datasourceId == null) { - throw new InvalidInputException("数据源ID不能为空"); + throw new InvalidInputException("datasourceId cannot be null"); } - if (tables.isEmpty()) { - throw new InvalidInputException("表列表不能为空"); + throw new InvalidInputException("tables cannot be empty"); } - // Execute Schema initialization - Boolean result = agentDatasourceService.initializeSchemaForAgentWithDatasource(agentId, datasourceId, - tables); - - if (result) { + Boolean result = agentDatasourceService.initializeSchemaForAgentWithDatasource(agentId, datasourceId, tables); + if (Boolean.TRUE.equals(result)) { log.info("Successfully initialized schema for agent: {}, tables: {}", agentId, tables.size()); - return ApiResponse.success("Schema初始化成功"); - } - else { - throw new InternalServerException("Schema初始化失败"); + return ApiResponse.success("Schema initialized successfully"); } + throw new InternalServerException("Schema initialization failed"); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { log.error("Failed to initialize schema for agent: {}", agentId, e); - throw new InternalServerException("Schema初始化失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Schema initialization failed: %s".formatted(e.getMessage())); } } - /** Get list of data sources configured for agent */ @GetMapping public ApiResponse> getAgentDatasource(@PathVariable Long agentId) { try { log.info("Getting datasources for agent: {}", agentId); List datasources = agentDatasourceService.getAgentDatasource(agentId); log.info("Successfully retrieved {} datasources for agent: {}", datasources.size(), agentId); - return ApiResponse.success("操作成功", datasources); + return ApiResponse.success("success", datasources); } catch (Exception e) { log.error("Failed to get datasources for agent: {}", agentId, e); - throw new InvalidInputException("获取数据源失败:%s".formatted(e.getMessage()), List.of()); + throw new InvalidInputException("Failed to get datasources: %s".formatted(e.getMessage()), List.of()); } } @@ -104,59 +105,110 @@ public ApiResponse getActiveAgentDatasource(@PathVariable Long try { log.info("Getting active datasource for agent: {}", agentId); AgentDatasource datasource = agentDatasourceService.getCurrentAgentDatasource(agentId); - return ApiResponse.success("操作成功", datasource); + return ApiResponse.success("success", datasource); } catch (Exception e) { log.error("Failed to get active datasource for agent: {}", agentId, e); - throw new InvalidInputException("获取数据源失败:%s".formatted(e.getMessage()), List.of()); + throw new InvalidInputException("Failed to get active datasource: %s".formatted(e.getMessage()), List.of()); } } - /** Add data source for agent */ @PostMapping("/{datasourceId}") public ApiResponse addDatasourceToAgent(@PathVariable Long agentId, @PathVariable Integer datasourceId) { try { if (datasourceId == null) { - throw new InvalidInputException("数据源ID不能为空"); + throw new InvalidInputException("datasourceId cannot be null"); } - AgentDatasource agentDatasource = agentDatasourceService.addDatasourceToAgent(agentId, datasourceId); - return ApiResponse.success("数据源添加成功", agentDatasource); + return ApiResponse.success("Datasource added successfully", agentDatasource); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { - throw new InternalServerException("数据源添加失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Failed to add datasource: %s".formatted(e.getMessage())); } } - // 更新选择的数据表 @PostMapping("/tables") - public ApiResponse updateDatasourceTables(@PathVariable Long agentId, + public ApiResponse updateDatasourceTables(@PathVariable Long agentId, @RequestBody @Validated UpdateDatasourceTablesDTO dto) { try { dto.setTables(Optional.ofNullable(dto.getTables()).orElse(List.of())); - agentDatasourceService.updateDatasourceTables(agentId, dto.getDatasourceId(), dto.getTables()); - return ApiResponse.success("更新成功"); + AgentDatasource agentDatasource = agentDatasourceService.updateDatasourceTables(agentId, dto.getDatasourceId(), + dto.getTables()); + return ApiResponse.success("Update successful", agentDatasource); + } + catch (IllegalArgumentException e) { + log.warn("Invalid datasource tables update request, agentId={}, datasourceId={}, message={}", agentId, + dto.getDatasourceId(), e.getMessage()); + throw new InvalidInputException(e.getMessage()); } catch (Exception e) { - log.error("Error: ", e); - throw new InternalServerException("更新失败:%s".formatted(e.getMessage())); + log.error("Error updating datasource tables", e); + throw new InternalServerException("Update failed: %s".formatted(e.getMessage())); + } + } + + @PostMapping("/columns") + public ApiResponse updateDatasourceColumns(@PathVariable Long agentId, + @RequestBody @Valid UpdateDatasourceColumnsDTO dto) { + try { + List tableSelections = Optional.ofNullable(dto.getTables()).orElse(List.of()); + Map> columnsByTable = new LinkedHashMap<>(); + for (TableColumnsSelectionDTO tableSelection : tableSelections) { + if (tableSelection == null) { + continue; + } + columnsByTable.put(tableSelection.getTableName(), + Optional.ofNullable(tableSelection.getColumns()).orElse(List.of())); + } + AgentDatasource agentDatasource = agentDatasourceService.updateDatasourceColumns(agentId, dto.getDatasourceId(), + columnsByTable); + return ApiResponse.success("Update successful", agentDatasource); + } + catch (IllegalArgumentException e) { + log.warn("Invalid datasource columns update request, agentId={}, datasourceId={}, message={}", agentId, + dto.getDatasourceId(), e.getMessage()); + throw new InvalidInputException(e.getMessage()); + } + catch (Exception e) { + log.error("Error updating datasource columns", e); + throw new InternalServerException("Update failed: %s".formatted(e.getMessage())); + } + } + + @GetMapping("/{datasourceId}/tables/{tableName}/columns") + public ApiResponse> getVisibleTableColumns(@PathVariable Long agentId, @PathVariable Integer datasourceId, + @PathVariable String tableName) { + try { + List columns = agentDatasourceService.getVisibleTableColumns(agentId, datasourceId, tableName); + return ApiResponse.success("success", columns); + } + catch (IllegalArgumentException e) { + log.warn("Invalid visible columns request, agentId={}, datasourceId={}, tableName={}, message={}", agentId, + datasourceId, tableName, e.getMessage()); + throw new InvalidInputException(e.getMessage()); + } + catch (Exception e) { + log.error("Error loading visible columns, agentId={}, datasourceId={}, tableName={}", agentId, datasourceId, + tableName, e); + throw new InternalServerException("Load columns failed: %s".formatted(e.getMessage())); } } - /** Remove data source association from agent */ @DeleteMapping("/{datasourceId}") public ApiResponse removeDatasourceFromAgent(@PathVariable Long agentId, @PathVariable Integer datasourceId) { try { agentDatasourceService.removeDatasourceFromAgent(agentId, datasourceId); - return ApiResponse.success("数据源已移除"); + return ApiResponse.success("Datasource removed"); } catch (Exception e) { - throw new InternalServerException("移除失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Remove failed: %s".formatted(e.getMessage())); } } - /** 启用/禁用智能体的数据源 */ @PutMapping("/toggle") public ApiResponse toggleDatasourceForAgent(@PathVariable Long agentId, @RequestBody ToggleDatasourceDTO dto) { @@ -164,14 +216,17 @@ public ApiResponse toggleDatasourceForAgent(@PathVariable Long Boolean isActive = dto.getIsActive(); Integer datasourceId = dto.getDatasourceId(); if (isActive == null || datasourceId == null) { - throw new InvalidInputException("激活状态不能为空"); + throw new InvalidInputException("isActive and datasourceId cannot be null"); } AgentDatasource agentDatasource = agentDatasourceService.toggleDatasourceForAgent(agentId, dto.getDatasourceId(), isActive); - return ApiResponse.success(isActive ? "数据源已启用" : "数据源已禁用", agentDatasource); + return ApiResponse.success(isActive ? "Datasource enabled" : "Datasource disabled", agentDatasource); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { - throw new InternalServerException("操作失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Operation failed: %s".formatted(e.getMessage())); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java index 865aed9d1..155071c0d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java @@ -18,14 +18,17 @@ import com.alibaba.cloud.ai.dataagent.exception.InternalServerException; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; +import org.springframework.web.bind.MethodArgumentNotValidException; +import org.springframework.web.bind.support.WebExchangeBindException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.bind.annotation.RestControllerAdvice; /** - * 全局异常处理器 (WebFlux 版本) + * Global exception handler. */ @Slf4j @RestControllerAdvice @@ -38,6 +41,26 @@ public ApiResponse handleInvalidInputException(InvalidInputException e) return ApiResponse.error(e.getMessage(), e.getData()); } + @ExceptionHandler(MethodArgumentNotValidException.class) + @ResponseStatus(HttpStatus.BAD_REQUEST) + public ApiResponse handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { + String message = buildValidationMessage(e.getBindingResult().getFieldErrors().stream() + .map(error -> error.getField() + ": " + error.getDefaultMessage()) + .collect(Collectors.toList())); + log.warn("Method argument not valid: {}", message); + return ApiResponse.error(message); + } + + @ExceptionHandler(WebExchangeBindException.class) + @ResponseStatus(HttpStatus.BAD_REQUEST) + public ApiResponse handleWebExchangeBindException(WebExchangeBindException e) { + String message = buildValidationMessage(e.getFieldErrors().stream() + .map(error -> error.getField() + ": " + error.getDefaultMessage()) + .collect(Collectors.toList())); + log.warn("Web exchange bind not valid: {}", message); + return ApiResponse.error(message); + } + @ExceptionHandler(InternalServerException.class) @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) public ApiResponse handleInternalServerException(InternalServerException e) { @@ -49,7 +72,15 @@ public ApiResponse handleInternalServerException(InternalServerException @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) public ApiResponse handleGenericException(Exception e) { log.error("Unexpected error: {}", e.getMessage(), e); - return ApiResponse.error("服务器内部错误"); + return ApiResponse.error("Internal server error"); + } + + private String buildValidationMessage(java.util.List messages) { + String message = messages.stream().filter(item -> item != null && !item.isBlank()).collect(Collectors.joining("; ")); + if (message.isBlank()) { + return "Request validation failed"; + } + return message; } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java index 48f766e2f..6d4a95038 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.List; +import java.util.Map; import java.util.Objects; public class SchemaInitRequest implements Serializable { @@ -27,6 +28,8 @@ public class SchemaInitRequest implements Serializable { private List tables; + private Map> visibleColumnsByTable; + public DbConfigBO getDbConfig() { return dbConfig; } @@ -43,9 +46,18 @@ public void setTables(List tables) { this.tables = tables; } + public Map> getVisibleColumnsByTable() { + return visibleColumnsByTable; + } + + public void setVisibleColumnsByTable(Map> visibleColumnsByTable) { + this.visibleColumnsByTable = visibleColumnsByTable; + } + @Override public String toString() { - return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", tables=" + tables + '}'; + return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", tables=" + tables + ", visibleColumnsByTable=" + + visibleColumnsByTable + '}'; } @Override @@ -55,12 +67,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; SchemaInitRequest that = (SchemaInitRequest) o; - return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(tables, that.tables); + return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(tables, that.tables) + && Objects.equals(visibleColumnsByTable, that.visibleColumnsByTable); } @Override public int hashCode() { - return Objects.hash(dbConfig, tables); + return Objects.hash(dbConfig, tables, visibleColumnsByTable); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java new file mode 100644 index 000000000..a3d531bf7 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.dto.datasource; + +import jakarta.validation.constraints.NotBlank; +import java.util.List; +import lombok.Data; + +@Data +public class TableColumnsSelectionDTO { + + @NotBlank(message = "tableName cannot be blank") + private String tableName; + + private List columns; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java new file mode 100644 index 000000000..7bc5247f4 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.dto.datasource; + +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import java.util.List; +import lombok.Data; + +@Data +public class UpdateDatasourceColumnsDTO { + + @NotNull(message = "datasourceId cannot be null") + private Integer datasourceId; + + @Valid + private List tables; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java index fdd7ae0b5..265417106 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonFormat; import java.time.LocalDateTime; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -50,6 +51,9 @@ public class AgentDatasource { // 当前数据源选中的表 private List selectTables; + // 当前数据源按表配置的字段白名单;未配置的表默认整表可见 + private Map> selectColumns; + public AgentDatasource(Long agentId, Integer datasourceId) { this.agentId = agentId; this.datasourceId = datasourceId; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java new file mode 100644 index 000000000..3bf94bf55 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.entity; + +import com.fasterxml.jackson.annotation.JsonFormat; +import java.time.LocalDateTime; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.format.annotation.DateTimeFormat; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AgentDatasourceColumn { + + private Integer id; + + private Integer agentDatasourceId; + + private String tableName; + + private String columnName; + + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss") + @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss") + private LocalDateTime createTime; + + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss") + @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss") + private LocalDateTime updateTime; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java new file mode 100644 index 000000000..d03683d04 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.mapper; + +import com.alibaba.cloud.ai.dataagent.entity.AgentDatasourceColumn; +import java.util.List; +import org.apache.ibatis.annotations.Delete; +import org.apache.ibatis.annotations.Insert; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; + +@Mapper +public interface AgentDatasourceColumnsMapper { + + @Select("SELECT * FROM agent_datasource_columns WHERE agent_datasource_id = #{agentDatasourceId} ORDER BY table_name, column_name") + List getAgentDatasourceColumns(@Param("agentDatasourceId") int agentDatasourceId); + + @Delete("DELETE FROM agent_datasource_columns WHERE agent_datasource_id = #{agentDatasourceId}") + int removeAllColumns(@Param("agentDatasourceId") int agentDatasourceId); + + @Delete("") + int removeColumnsOutsideTables(@Param("agentDatasourceId") int agentDatasourceId, + @Param("tables") List tables); + + @Insert("") + int insertColumns(@Param("rows") List rows); + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java index 742928589..e0294f185 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java @@ -17,6 +17,7 @@ import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import java.util.List; +import java.util.Map; public interface AgentDatasourceService { @@ -38,6 +39,11 @@ default AgentDatasource getCurrentAgentDatasource(Long agentId) { AgentDatasource toggleDatasourceForAgent(Long agentId, Integer datasourceId, Boolean isActive); - void updateDatasourceTables(Long agentId, Integer datasourceId, List tables); + AgentDatasource updateDatasourceTables(Long agentId, Integer datasourceId, List tables); + + AgentDatasource updateDatasourceColumns(Long agentId, Integer datasourceId, Map> columnsByTable) + throws Exception; + + List getVisibleTableColumns(Long agentId, Integer datasourceId, String tableName) throws Exception; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java index 47a7b9678..31ae3e8c4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java @@ -18,14 +18,22 @@ import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; import com.alibaba.cloud.ai.dataagent.dto.datasource.SchemaInitRequest; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; +import com.alibaba.cloud.ai.dataagent.entity.AgentDatasourceColumn; import com.alibaba.cloud.ai.dataagent.entity.Datasource; import com.alibaba.cloud.ai.dataagent.mapper.AgentDatasourceMapper; +import com.alibaba.cloud.ai.dataagent.mapper.AgentDatasourceColumnsMapper; import com.alibaba.cloud.ai.dataagent.mapper.AgentDatasourceTablesMapper; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; import com.alibaba.cloud.ai.dataagent.service.schema.SchemaService; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -45,6 +53,8 @@ public class AgentDatasourceServiceImpl implements AgentDatasourceService { private final AgentDatasourceTablesMapper tablesMapper; + private final AgentDatasourceColumnsMapper columnsMapper; + @Override public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer datasourceId, List tables) { Assert.notNull(agentId, "Agent ID cannot be null"); @@ -63,11 +73,17 @@ public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer data // Create database configuration DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); + AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + if (agentDatasource == null) { + throw new RuntimeException("Agent datasource relation not found with agentId=%s, datasourceId=%s" + .formatted(agentId, datasourceId)); + } // Create SchemaInitRequest SchemaInitRequest schemaInitRequest = new SchemaInitRequest(); schemaInitRequest.setDbConfig(dbConfig); schemaInitRequest.setTables(tables); + schemaInitRequest.setVisibleColumnsByTable(loadSelectedColumns(agentDatasource.getId())); log.info("Created SchemaInitRequest for agent: {}, dbConfig: {}, tables: {}", agentIdStr, dbConfig, tables); @@ -86,17 +102,8 @@ public List getAgentDatasource(Long agentId) { Assert.notNull(agentId, "Agent ID cannot be null"); List adentDatasources = agentDatasourceMapper.selectByAgentIdWithDatasource(agentId); - // Manually fill in the data source information (since MyBatis Plus does not - // directly support complex join query result mapping) for (AgentDatasource agentDatasource : adentDatasources) { - if (agentDatasource.getDatasourceId() != null) { - Datasource datasource = datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); - agentDatasource.setDatasource(datasource); - } - // 获取选中的数据表 - int id = agentDatasource.getId(); - List tables = tablesMapper.getAgentDatasourceTables(id); - agentDatasource.setSelectTables(Optional.ofNullable(tables).orElse(List.of())); + enrichAgentDatasource(agentDatasource); } return adentDatasources; @@ -119,6 +126,7 @@ public AgentDatasource addDatasourceToAgent(Long agentId, Integer datasourceId) // 删除已有的表 tablesMapper.removeAllTables(existing.getId()); + columnsMapper.removeAllColumns(existing.getId()); // Query and return the updated association result = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); @@ -131,6 +139,7 @@ public AgentDatasource addDatasourceToAgent(Long agentId, Integer datasourceId) result = agentDatasource; } result.setSelectTables(List.of()); + result.setSelectColumns(Map.of()); return result; } @@ -157,12 +166,14 @@ public AgentDatasource toggleDatasourceForAgent(Long agentId, Integer datasource } // Return the updated association record - return agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + enrichAgentDatasource(agentDatasource); + return agentDatasource; } @Override @Transactional - public void updateDatasourceTables(Long agentId, Integer datasourceId, List tables) { + public AgentDatasource updateDatasourceTables(Long agentId, Integer datasourceId, List tables) { if (agentId == null || datasourceId == null || tables == null) { throw new IllegalArgumentException("参数不能为空"); } @@ -170,12 +181,252 @@ public void updateDatasourceTables(Long agentId, Integer datasourceId, List normalizedTables; + try { + TableResolutionIndex datasourceTableIndex = buildTableResolutionIndex( + datasourceService.getDatasourceTables(datasourceId)); + normalizedTables = sanitizeRequestedTables(tables, datasourceTableIndex); + } + catch (Exception ex) { + throw new IllegalArgumentException("Failed to validate datasource tables: %s".formatted(ex.getMessage()), ex); + } + if (normalizedTables.isEmpty()) { tablesMapper.removeAllTables(datasource.getId()); + columnsMapper.removeAllColumns(datasource.getId()); } else { - tablesMapper.updateAgentDatasourceTables(datasource.getId(), tables); + tablesMapper.updateAgentDatasourceTables(datasource.getId(), normalizedTables); + columnsMapper.removeColumnsOutsideTables(datasource.getId(), normalizedTables); + } + return refreshAgentDatasource(agentId, datasourceId); + } + + @Override + @Transactional + public AgentDatasource updateDatasourceColumns(Long agentId, Integer datasourceId, + Map> columnsByTable) + throws Exception { + if (agentId == null || datasourceId == null || columnsByTable == null) { + throw new IllegalArgumentException("参数不能为空"); + } + AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + if (agentDatasource == null) { + throw new IllegalArgumentException("未找到对应的数据源关联记录"); + } + + TableResolutionIndex allowedTables = loadAllowedTables(agentDatasource, datasourceId); + Map> sanitizedColumnsByTable = sanitizeColumnsByTable(datasourceId, columnsByTable, + allowedTables); + + columnsMapper.removeAllColumns(agentDatasource.getId()); + List rows = new ArrayList<>(); + sanitizedColumnsByTable.forEach((tableName, columns) -> columns.forEach(columnName -> rows + .add(new AgentDatasourceColumn(null, agentDatasource.getId(), tableName, columnName, null, null)))); + if (!rows.isEmpty()) { + columnsMapper.insertColumns(rows); + } + return refreshAgentDatasource(agentId, datasourceId); + } + + @Override + public List getVisibleTableColumns(Long agentId, Integer datasourceId, String tableName) throws Exception { + if (agentId == null || datasourceId == null || tableName == null || tableName.isBlank()) { + throw new IllegalArgumentException("agentId, datasourceId and tableName cannot be blank"); + } + AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + if (agentDatasource == null) { + throw new IllegalArgumentException("鏈壘鍒板搴旂殑鏁版嵁婧愬叧鑱旇褰?"); + } + TableResolutionIndex allowedTables = loadAllowedTables(agentDatasource, datasourceId); + String actualTableName = resolveTableName(tableName, allowedTables, false); + if (actualTableName == null) { + throw new IllegalArgumentException( + "Table '%s' does not exist or is not visible in current agent datasource".formatted(tableName)); + } + return datasourceService.getTableColumns(datasourceId, actualTableName); + } + + private void enrichAgentDatasource(AgentDatasource agentDatasource) { + if (agentDatasource == null) { + return; + } + if (agentDatasource.getDatasourceId() != null && agentDatasource.getDatasource() == null) { + Datasource datasource = datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); + agentDatasource.setDatasource(datasource); + } + List tables = tablesMapper.getAgentDatasourceTables(agentDatasource.getId()); + agentDatasource.setSelectTables(Optional.ofNullable(tables).orElse(List.of())); + agentDatasource.setSelectColumns(loadSelectedColumns(agentDatasource.getId())); + } + + private AgentDatasource refreshAgentDatasource(Long agentId, Integer datasourceId) { + AgentDatasource refreshed = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + enrichAgentDatasource(refreshed); + return refreshed; + } + + private Map> loadSelectedColumns(int agentDatasourceId) { + List rows = Optional.ofNullable(columnsMapper.getAgentDatasourceColumns(agentDatasourceId)) + .orElse(List.of()); + Map> columnsByTable = new LinkedHashMap<>(); + for (AgentDatasourceColumn row : rows) { + if (row == null) { + continue; + } + columnsByTable.computeIfAbsent(row.getTableName(), key -> new ArrayList<>()).add(row.getColumnName()); + } + columnsByTable.replaceAll((tableName, columns) -> List.copyOf(columns)); + return Map.copyOf(columnsByTable); + } + + private TableResolutionIndex loadAllowedTables(AgentDatasource agentDatasource, Integer datasourceId) throws Exception { + List datasourceTables = datasourceService.getDatasourceTables(datasourceId); + TableResolutionIndex datasourceTableIndex = buildTableResolutionIndex(datasourceTables); + List selectedTables = Optional.ofNullable(tablesMapper.getAgentDatasourceTables(agentDatasource.getId())) + .orElse(List.of()); + List visibleTables = selectedTables.isEmpty() ? datasourceTables + : sanitizeRequestedTables(selectedTables, datasourceTableIndex, true); + return buildTableResolutionIndex(visibleTables); + } + + private Map> sanitizeColumnsByTable(Integer datasourceId, Map> columnsByTable, + TableResolutionIndex allowedTables) throws Exception { + Map> sanitized = new LinkedHashMap<>(); + for (Map.Entry> entry : columnsByTable.entrySet()) { + String requestedTableName = entry.getKey(); + String actualTableName = resolveTableName(requestedTableName, allowedTables, false); + if (actualTableName == null) { + throw new IllegalArgumentException("字段白名单配置包含当前 agent 不可见的数据表: " + requestedTableName); + } + + Map actualColumns = datasourceService.getTableColumns(datasourceId, actualTableName) + .stream() + .collect(LinkedHashMap::new, (map, columnName) -> map.put(normalizeIdentifier(columnName), columnName), + Map::putAll); + LinkedHashSet dedupedColumns = new LinkedHashSet<>(); + for (String requestedColumn : Optional.ofNullable(entry.getValue()).orElse(List.of())) { + String normalizedColumn = normalizeIdentifier(requestedColumn); + String actualColumnName = actualColumns.get(normalizedColumn); + if (actualColumnName == null) { + throw new IllegalArgumentException( + "表 '%s' 中不存在字段 '%s',无法保存字段级可见性配置".formatted(actualTableName, requestedColumn)); + } + dedupedColumns.add(actualColumnName); + } + if (!dedupedColumns.isEmpty()) { + sanitized.put(actualTableName, List.copyOf(dedupedColumns)); + } + } + return sanitized; + } + + private List normalizeTableNames(List tables) { + return tables.stream() + .map(String::trim) + .filter(tableName -> !tableName.isEmpty()) + .collect(Collectors.toCollection(LinkedHashSet::new)) + .stream() + .toList(); + } + + private List sanitizeRequestedTables(List tables, TableResolutionIndex tableIndex) { + return sanitizeRequestedTables(tables, tableIndex, false); + } + + private List sanitizeRequestedTables(List tables, TableResolutionIndex tableIndex, + boolean allowQualifiedFallback) { + LinkedHashSet resolvedTables = new LinkedHashSet<>(); + for (String tableName : normalizeTableNames(tables)) { + String resolvedTableName = resolveTableName(tableName, tableIndex, allowQualifiedFallback); + if (resolvedTableName == null) { + throw new IllegalArgumentException( + "Table '%s' does not exist or is not visible in current datasource".formatted(tableName)); + } + resolvedTables.add(resolvedTableName); + } + return List.copyOf(resolvedTables); + } + + private TableResolutionIndex buildTableResolutionIndex(List tableNames) { + return new TableResolutionIndex(indexTableNames(tableNames, false), indexTableNames(tableNames, true)); + } + + private Map> indexTableNames(List tableNames, boolean leafOnly) { + Map> index = new LinkedHashMap<>(); + for (String tableName : Optional.ofNullable(tableNames).orElse(List.of())) { + if (tableName == null || tableName.isBlank()) { + continue; + } + String normalizedTableName = leafOnly ? normalizeLeafIdentifier(tableName) : normalizeIdentifier(tableName); + index.computeIfAbsent(normalizedTableName, key -> new LinkedHashSet<>()).add(tableName); + } + Map> immutableIndex = new LinkedHashMap<>(); + index.forEach((key, value) -> immutableIndex.put(key, List.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private String resolveTableName(String requestedTableName, TableResolutionIndex tableIndex, + boolean allowQualifiedFallback) { + String normalizedTableName = normalizeIdentifier(requestedTableName); + List exactMatches = tableIndex.exactTables().getOrDefault(normalizedTableName, List.of()); + if (exactMatches.size() == 1) { + return exactMatches.get(0); } + if (exactMatches.size() > 1) { + throw new IllegalArgumentException( + "Table '%s' maps to multiple datasource tables: %s".formatted(requestedTableName, exactMatches)); + } + if (isQualifiedIdentifier(requestedTableName) && !allowQualifiedFallback) { + return null; + } + List leafMatches = tableIndex.leafTables().getOrDefault(normalizeLeafIdentifier(requestedTableName), + List.of()); + if (leafMatches.size() == 1) { + return leafMatches.get(0); + } + if (leafMatches.size() > 1) { + throw new IllegalArgumentException("Table '%s' is ambiguous across datasource tables: %s" + .formatted(requestedTableName, leafMatches)); + } + return null; + } + + private boolean isQualifiedIdentifier(String value) { + return normalizeIdentifier(value).contains("."); + } + + private String normalizeIdentifier(String value) { + String normalized = Optional.ofNullable(value).orElse("").trim(); + normalized = stripWrapping(normalized, "`"); + normalized = stripWrapping(normalized, "\""); + normalized = stripWrapping(normalized, "[", "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeLeafIdentifier(String value) { + String normalized = normalizeIdentifier(value); + if (normalized.contains(".")) { + return normalized.substring(normalized.lastIndexOf('.') + 1); + } + return normalized; + } + + private String stripWrapping(String value, String wrapper) { + return stripWrapping(value, wrapper, wrapper); + } + + private String stripWrapping(String value, String prefix, String suffix) { + String normalized = value; + if (normalized.startsWith(prefix)) { + normalized = normalized.substring(prefix.length()); + } + if (normalized.endsWith(suffix)) { + normalized = normalized.substring(0, normalized.length() - suffix.length()); + } + return normalized; + } + + private record TableResolutionIndex(Map> exactTables, Map> leafTables) { } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java index 2512845dd..0c3c47bc7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.service.schema; +import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; import com.alibaba.cloud.ai.dataagent.bo.schema.ForeignKeyInfoBO; import com.alibaba.cloud.ai.dataagent.bo.schema.TableInfoBO; @@ -64,6 +65,8 @@ @AllArgsConstructor public class SchemaServiceImpl implements SchemaService { + private static final String FOREIGN_KEY_SEPARATOR = "、"; + private final ExecutorService dbOperationExecutor; private final AccessorFactory accessorFactory; @@ -168,6 +171,7 @@ public Boolean schema(Integer datasourceId, SchemaInitRequest schemaInitRequest) } log.info("Successfully processed all tables for datasource: {}", datasourceId); + applyVisibleColumnRestrictions(tables, schemaInitRequest.getVisibleColumnsByTable()); // 转换为文档 List columnDocs = convertColumnsToDocuments(datasourceId, tables); @@ -283,6 +287,123 @@ protected void clearSchemaDataForDatasource(Integer datasourceId) throws Excepti agentVectorStoreService.deleteDocumentsByMetadata(metadata); } + private void applyVisibleColumnRestrictions(List tables, Map> visibleColumnsByTable) { + Map> normalizedRestrictions = normalizeVisibleColumnRestrictions(visibleColumnsByTable); + if (normalizedRestrictions.isEmpty()) { + return; + } + for (TableInfoBO table : Optional.ofNullable(tables).orElse(List.of())) { + Set visibleColumns = resolveVisibleColumns(normalizedRestrictions, table.getName()); + if (visibleColumns == null) { + continue; + } + List filteredColumns = Optional.ofNullable(table.getColumns()) + .orElse(List.of()) + .stream() + .filter(column -> visibleColumns.contains(normalizeIdentifier(column.getName()))) + .toList(); + table.setColumns(filteredColumns); + List filteredPrimaryKeys = Optional.ofNullable(table.getPrimaryKeys()) + .orElse(List.of()) + .stream() + .filter(primaryKey -> visibleColumns.contains(normalizeIdentifier(primaryKey))) + .toList(); + table.setPrimaryKeys(filteredPrimaryKeys); + table.setForeignKey(filterForeignKeyText(table.getForeignKey(), normalizedRestrictions)); + } + } + + private Map> normalizeVisibleColumnRestrictions(Map> visibleColumnsByTable) { + Map> normalizedRestrictions = new LinkedHashMap<>(); + Optional.ofNullable(visibleColumnsByTable).orElse(Map.of()).forEach((tableName, columns) -> { + String normalizedTableName = normalizeIdentifier(tableName); + if (StringUtils.isBlank(normalizedTableName)) { + return; + } + Set normalizedColumns = Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .map(this::normalizeIdentifier) + .filter(StringUtils::isNotBlank) + .collect(Collectors.toCollection(LinkedHashSet::new)); + if (!normalizedColumns.isEmpty()) { + normalizedRestrictions.put(normalizedTableName, Set.copyOf(normalizedColumns)); + } + }); + return normalizedRestrictions; + } + + private String filterForeignKeyText(String foreignKeyText, Map> visibleColumnsByTable) { + if (StringUtils.isBlank(foreignKeyText)) { + return foreignKeyText; + } + return Arrays.stream(foreignKeyText.split(Pattern.quote(FOREIGN_KEY_SEPARATOR))) + .map(StringUtils::trimToEmpty) + .filter(StringUtils::isNotBlank) + .filter(relation -> isForeignKeyRelationVisible(relation, visibleColumnsByTable)) + .collect(Collectors.joining(FOREIGN_KEY_SEPARATOR)); + } + + private boolean isForeignKeyRelationVisible(String relation, Map> visibleColumnsByTable) { + String[] parts = StringUtils.splitByWholeSeparatorPreserveAllTokens(relation, "="); + if (parts == null || parts.length != 2) { + return false; + } + return isColumnReferenceVisible(parts[0], visibleColumnsByTable) + && isColumnReferenceVisible(parts[1], visibleColumnsByTable); + } + + private boolean isColumnReferenceVisible(String reference, Map> visibleColumnsByTable) { + String normalizedReference = normalizeIdentifier(reference); + int lastDot = normalizedReference.lastIndexOf('.'); + if (lastDot <= 0 || lastDot >= normalizedReference.length() - 1) { + return false; + } + String tableName = normalizedReference.substring(0, lastDot); + String columnName = normalizedReference.substring(lastDot + 1); + Set visibleColumns = resolveVisibleColumns(visibleColumnsByTable, tableName); + if (visibleColumns == null) { + return true; + } + return visibleColumns.contains(columnName); + } + + private Set resolveVisibleColumns(Map> visibleColumnsByTable, String tableName) { + String normalizedTableName = normalizeIdentifier(tableName); + Set exactMatch = visibleColumnsByTable.get(normalizedTableName); + if (exactMatch != null) { + return exactMatch; + } + String normalizedLeafTableName = normalizeLeafIdentifier(normalizedTableName); + List> leafMatches = visibleColumnsByTable.entrySet() + .stream() + .filter(entry -> normalizeLeafIdentifier(entry.getKey()).equals(normalizedLeafTableName)) + .map(Map.Entry::getValue) + .distinct() + .toList(); + if (leafMatches.size() == 1) { + return leafMatches.get(0); + } + return null; + } + + private String normalizeIdentifier(String value) { + String normalized = StringUtils.trimToEmpty(value); + normalized = StringUtils.removeStart(normalized, "`"); + normalized = StringUtils.removeEnd(normalized, "`"); + normalized = StringUtils.removeStart(normalized, "\""); + normalized = StringUtils.removeEnd(normalized, "\""); + normalized = StringUtils.removeStart(normalized, "["); + normalized = StringUtils.removeEnd(normalized, "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeLeafIdentifier(String value) { + String normalized = normalizeIdentifier(value); + int lastDot = normalized.lastIndexOf('.'); + return lastDot >= 0 ? normalized.substring(lastDot + 1) : normalized; + } + @Override public List getTableDocumentsByDatasource(Integer datasourceId, String query) { Assert.notNull(datasourceId, "datasourceId cannot be null"); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/SqlUtil.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/SqlUtil.java index bd297c59b..611500f4e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/SqlUtil.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/SqlUtil.java @@ -16,25 +16,17 @@ package com.alibaba.cloud.ai.dataagent.util; import com.alibaba.cloud.ai.dataagent.enums.BizDataSourceTypeEnum; +import java.util.Arrays; +import java.util.Locale; +import java.util.stream.Collectors; import lombok.experimental.UtilityClass; /** - * SQL 工具类 - * - * @author Yang Yufeng - * @version 1.0 + * SQL utilities. */ @UtilityClass public class SqlUtil { - /** - * 构建SELECT SQL语句 - * @param typeName 数据源类型 - * @param tableName 表名 - * @param columnNames 列名 - * @param limit 查询数量限制 - * @return SELECT SQL语句 - */ public static String buildSelectSql(String typeName, String tableName, String columnNames, int limit) { if (tableName == null || tableName.isEmpty()) { throw new IllegalArgumentException("Table name cannot be empty"); @@ -44,17 +36,42 @@ public static String buildSelectSql(String typeName, String tableName, String co } if (BizDataSourceTypeEnum.isSqlServerDialect(typeName)) { - // SQL Server 使用 TOP return String.format("SELECT TOP %d %s FROM %s", limit, columnNames, tableName); } - else if (BizDataSourceTypeEnum.isOracleDialect(typeName)) { - // Oracle 使用 FETCH FIRST (Oracle 12c+) + if (BizDataSourceTypeEnum.isOracleDialect(typeName)) { return String.format("SELECT %s FROM %s FETCH FIRST %d ROWS ONLY", columnNames, tableName, limit); } - else { - // MySQL, PostgreSQL, H2, SQLite 通用 LIMIT - return String.format("SELECT %s FROM %s LIMIT %d", columnNames, tableName, limit); + return String.format("SELECT %s FROM %s LIMIT %d", columnNames, tableName, limit); + } + + public static String quoteIdentifier(String typeName, String identifier) { + if (identifier == null || identifier.isBlank()) { + throw new IllegalArgumentException("Identifier cannot be empty"); + } + String trimmed = identifier.trim(); + if ("*".equals(trimmed)) { + return trimmed; + } + String normalizedType = typeName == null ? "" : typeName.toLowerCase(Locale.ROOT); + boolean mysqlLikeDialect = BizDataSourceTypeEnum.isMysqlDialect(normalizedType); + String quoteStart = mysqlLikeDialect ? "`" : "\""; + String quoteEnd = mysqlLikeDialect ? "`" : "\""; + return Arrays.stream(trimmed.split("\\.")) + .map(String::trim) + .filter(part -> !part.isEmpty()) + .map(part -> wrapIdentifierPart(part, quoteStart, quoteEnd)) + .collect(Collectors.joining(".")); + } + + private static String wrapIdentifierPart(String identifierPart, String quoteStart, String quoteEnd) { + String normalizedPart = identifierPart; + if ((normalizedPart.startsWith("`") && normalizedPart.endsWith("`")) + || (normalizedPart.startsWith("\"") && normalizedPart.endsWith("\"")) + || (normalizedPart.startsWith("[") && normalizedPart.endsWith("]"))) { + normalizedPart = normalizedPart.substring(1, normalizedPart.length() - 1); } + String escapedPart = normalizedPart.replace(quoteEnd, quoteEnd + quoteEnd); + return quoteStart + escapedPart + quoteEnd; } } diff --git a/data-agent-management/src/main/resources/sql/h2/schema-h2.sql b/data-agent-management/src/main/resources/sql/h2/schema-h2.sql index 45538f1cf..f81d122ea 100644 --- a/data-agent-management/src/main/resources/sql/h2/schema-h2.sql +++ b/data-agent-management/src/main/resources/sql/h2/schema-h2.sql @@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS agent_knowledge ( title VARCHAR(255) NOT NULL COMMENT '知识的标题 (用户定义, 用于在UI上展示和识别)', type VARCHAR(50) NOT NULL COMMENT '知识类型: DOCUMENT-文档, QA-问答, FAQ-常见问题', question TEXT COMMENT '问题 (仅当type为QA或FAQ时使用)', - content MEDIUMTEXT COMMENT '知识内容 (对于QA/FAQ是答案; 对于DOCUMENT, 此字段通常为空)', + content MEDIUMTEXT COMMENT '知识内容 (对于QA/FAQ是答案,对于DOCUMENT此字段通常为空)', is_recall INT DEFAULT 1 COMMENT '业务状态: 1=召回, 0=非召回', embedding_status VARCHAR(20) DEFAULT NULL COMMENT '向量化状态:PENDING待处理,PROCESSING处理中,COMPLETED已完成,FAILED失败', error_msg VARCHAR(255) DEFAULT NULL COMMENT '操作失败的错误信息', @@ -232,13 +232,28 @@ CREATE TABLE IF NOT EXISTS agent_datasource_tables table_name VARCHAR(255) NOT NULL COMMENT '数据表名', create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NULL COMMENT '创建时间', update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NULL COMMENT '更新时间', - CONSTRAINT uk_agent_datasource_tables_agent_datasource_id_table_name + CONSTRAINT uk_agent_ds_tables_ds_table UNIQUE (agent_datasource_id, table_name), - CONSTRAINT fk_agent_datasource_tables_agent_datasource_id + CONSTRAINT fk_agent_ds_tables_agent_ds FOREIGN KEY (agent_datasource_id) REFERENCES agent_datasource (id) ON UPDATE CASCADE ON DELETE CASCADE ) ENGINE = InnoDB COMMENT = '某个智能体某个数据源所选中的数据表'; +CREATE TABLE IF NOT EXISTS agent_datasource_columns +( + id INT AUTO_INCREMENT PRIMARY KEY, + agent_datasource_id INT NOT NULL COMMENT '智能体数据源ID', + table_name VARCHAR(255) NOT NULL COMMENT '数据表名', + column_name VARCHAR(255) NOT NULL COMMENT '字段名', + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NULL COMMENT '创建时间', + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NULL COMMENT '更新时间', + CONSTRAINT uk_agent_ds_cols_ds_table_col + UNIQUE (agent_datasource_id, table_name, column_name), + CONSTRAINT fk_agent_ds_cols_agent_ds + FOREIGN KEY (agent_datasource_id) REFERENCES agent_datasource (id) + ON UPDATE CASCADE ON DELETE CASCADE + ) ENGINE = InnoDB COMMENT = '某个智能体某个数据源所选中的字段白名单'; + -- 模型配置表 CREATE TABLE IF NOT EXISTS `model_config` ( diff --git a/data-agent-management/src/main/resources/sql/schema.sql b/data-agent-management/src/main/resources/sql/schema.sql index 7358a0c21..b254c7c8b 100644 --- a/data-agent-management/src/main/resources/sql/schema.sql +++ b/data-agent-management/src/main/resources/sql/schema.sql @@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS `agent_knowledge` ( `title` varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT '知识的标题 (用户定义, 用于在UI上展示和识别)', `type` varchar(50) COLLATE utf8mb4_bin NOT NULL COMMENT '知识类型: DOCUMENT-文档, QA-问答, FAQ-常见问题', `question` text COLLATE utf8mb4_bin COMMENT '问题 (仅当type为QA或FAQ时使用)', - `content` mediumtext COLLATE utf8mb4_bin COMMENT '知识内容 (对于QA/FAQ是答案; 对于DOCUMENT, 此字段通常为空)', + `content` mediumtext COLLATE utf8mb4_bin COMMENT '知识内容 (对于QA/FAQ是答案,对于DOCUMENT此字段通常为空)', `is_recall` int(11) DEFAULT 1 COMMENT '业务状态: 1=召回, 0=非召回', `embedding_status` varchar(20) COLLATE utf8mb4_bin DEFAULT NULL COMMENT '向量化状态:PENDING待处理,PROCESSING处理中,COMPLETED已完成,FAILED失败', `error_msg` varchar(255) COLLATE utf8mb4_bin DEFAULT NULL COMMENT '操作失败的错误信息', @@ -228,14 +228,30 @@ create table if not exists agent_datasource_tables table_name varchar(255) not null comment '数据表名', create_time timestamp default CURRENT_TIMESTAMP null comment '创建时间', update_time timestamp default CURRENT_TIMESTAMP null comment '更新时间', - constraint agent_datasource_tables_agent_datasource_id_table_name_uindex + constraint uk_agent_ds_tables_ds_table unique (agent_datasource_id, table_name), - constraint agent_datasource_tables_agent_datasource_id_fk + constraint fk_agent_ds_tables_agent_ds foreign key (agent_datasource_id) references agent_datasource (id) on update cascade on delete cascade ) comment '某个智能体某个数据源所选中的数据表'; +create table if not exists agent_datasource_columns +( + id int auto_increment primary key, + agent_datasource_id int not null comment '智能体数据源ID', + table_name varchar(255) not null comment '数据表名', + column_name varchar(255) not null comment '字段名', + create_time timestamp default CURRENT_TIMESTAMP null comment '创建时间', + update_time timestamp default CURRENT_TIMESTAMP null comment '更新时间', + constraint uk_agent_ds_cols_ds_table_col + unique (agent_datasource_id, table_name, column_name), + constraint fk_agent_ds_cols_agent_ds + foreign key (agent_datasource_id) references agent_datasource (id) + on update cascade on delete cascade +) + comment '某个智能体某个数据源所选中的字段白名单'; + -- 模型配置表 CREATE TABLE IF NOT EXISTS `model_config` ( diff --git a/data-agent-management/src/test/resources/sql/schema.sql b/data-agent-management/src/test/resources/sql/schema.sql index 7491a7d9d..c85241ab5 100644 --- a/data-agent-management/src/test/resources/sql/schema.sql +++ b/data-agent-management/src/test/resources/sql/schema.sql @@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS `agent_knowledge` ( `title` varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT '知识的标题 (用户定义, 用于在UI上展示和识别)', `type` varchar(50) COLLATE utf8mb4_bin NOT NULL COMMENT '知识类型: DOCUMENT-文档, QA-问答, FAQ-常见问题', `question` text COLLATE utf8mb4_bin COMMENT '问题 (仅当type为QA或FAQ时使用)', - `content` mediumtext COLLATE utf8mb4_bin COMMENT '知识内容 (对于QA/FAQ是答案; 对于DOCUMENT, 此字段通常为空)', + `content` mediumtext COLLATE utf8mb4_bin COMMENT '知识内容 (对于QA/FAQ是答案,对于DOCUMENT此字段通常为空)', `is_recall` int(11) DEFAULT 1 COMMENT '业务状态: 1=召回, 0=非召回', `embedding_status` varchar(20) COLLATE utf8mb4_bin DEFAULT NULL COMMENT '向量化状态:PENDING待处理,PROCESSING处理中,COMPLETED已完成,FAILED失败', `error_msg` varchar(255) COLLATE utf8mb4_bin DEFAULT NULL COMMENT '操作失败的错误信息', From bb5f293e5f935a1301455b192fd82722a593fec5 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Fri, 24 Apr 2026 17:58:43 +0800 Subject: [PATCH 06/22] feat: enhance trace and add enhance sqlcheck --- data-agent-frontend/src/views/AgentRun.vue | 1462 ++++++++++++++++- .../tool/sqlguard/SqlGuardCheckRequest.java | 23 +- .../tool/sqlguard/SqlGuardCheckResult.java | 43 +- .../tool/sqlguard/SqlGuardToolProvider.java | 56 +- .../sqlguard/SqlVerifyExplainService.java | 504 +++++- .../observability/SessionTraceStore.java | 40 +- .../src/main/resources/prompts/commonagent.md | 34 +- 7 files changed, 2029 insertions(+), 133 deletions(-) diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 0a4d8bea2..0ac684b0c 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -361,15 +361,35 @@ - +
- Trace ID: {{ sessionTrace.traceId }} - Span 数: {{ sessionTrace.spanCount }} - 耗时: {{ formatTraceDuration(sessionTrace.durationMs) }} - 开始时间: {{ formatTraceTime(sessionTrace.startEpochMs) }} + Trace ID: {{ sessionTrace.traceId }} + Span 数: {{ sessionTrace.spanCount }} + 耗时: {{ formatTraceDuration(sessionTrace.durationMs) }} + 开始时间: {{ formatTraceTime(sessionTrace.startEpochMs) }} + + Request: {{ sessionTrace.runtimeRequestId }} + + + Agent: {{ sessionTrace.agentId }} + +
+
+ + 刷新
- 刷新
-
-
-
- {{ row.span.name }} - {{ row.span.kind }} - - {{ row.span.status }} - - {{ formatTraceDuration(row.span.durationMs) }} +
+
+
+ Span 列表 + + {{ filteredTraceSpans.length }}/{{ flattenedTraceSpans.length }} +
-
- spanId: {{ row.span.spanId }} - parent: {{ row.span.parentSpanId || '-' }} - 开始: {{ formatTraceTime(row.span.startEpochMs) }} + +
+
-
- 属性 -
-
- {{ entry.key }} - {{ entry.value }} +
+ +
+ +
@@ -605,6 +761,8 @@ const traceDialogVisible = ref(false); const traceLoading = ref(false); const traceError = ref(''); + const traceSearchKeyword = ref(''); + const selectedTraceSpanId = ref(''); const sessionTrace = ref(null); // 监听NL2SQL开关变化 @@ -651,6 +809,8 @@ currentSession.value = session; sessionTrace.value = null; traceError.value = ''; + traceSearchKeyword.value = ''; + selectedTraceSpanId.value = ''; traceDialogVisible.value = false; try { @@ -1243,20 +1403,702 @@ sessionState.markdownReportContent = ''; }; + type TraceAttributeEntry = { key: string; value: string }; + type FlattenedTraceSpan = { + span: TraceSpan; + depth: number; + attributeEntries: TraceAttributeEntry[]; + searchableText: string; + }; + type TraceMessageKind = 'system' | 'user' | 'assistant' | 'tool-call' | 'tool-result' | 'other'; + type ParsedTraceMessage = { + id: string; + kind: TraceMessageKind; + label: string; + title: string; + content: string; + details: string; + skills: string[]; + }; + type ParsedTraceConversationGroup = { + attributeKey: string; + title: string; + messages: ParsedTraceMessage[]; + }; + + const TRACE_MESSAGE_CONTAINER_KEYS = [ + 'messagelist', + 'message_list', + 'messages', + 'history', + 'conversation', + 'input', + 'output', + 'prompt', + 'response', + ]; + const TRACE_SKILL_KEYS = ['skills', 'skillNames', 'skill_names', 'availableSkills']; + const TRACE_TOOL_RESULT_KEY_HINTS = [ + 'tool.call.result', + 'tool_call_result', + 'toolresponse', + 'tool_response', + 'tool.result', + 'function.output', + 'function.result', + ]; + const TRACE_TOOL_CALL_KEY_HINTS = [ + 'tool.call', + 'tool_call', + 'function.input', + 'function.arguments', + ]; + const TRACE_MESSAGE_KIND_PRIORITY: Record = { + 'tool-result': 6, + 'tool-call': 5, + system: 4, + assistant: 3, + user: 2, + other: 1, + }; + + const isTraceRecord = (value: unknown): value is Record => + value !== null && typeof value === 'object' && !Array.isArray(value); + + const normalizeTraceKey = (value: string) => value.trim().toLowerCase(); + + const normalizeTraceFingerprintText = (value: string) => + value + .trim() + .replace(/\s+/g, ' ') + .toLowerCase(); + + const tryParseTraceJson = (value: string): unknown | null => { + if (!isStructuredTraceValue(value)) { + return null; + } + try { + return JSON.parse(value); + } catch (error) { + return null; + } + }; + + const stringifyTracePayload = (value: unknown) => { + if (typeof value === 'string') { + return value; + } + if (value === null || value === undefined) { + return ''; + } + try { + return JSON.stringify(value, null, 2); + } catch (error) { + return String(value); + } + }; + + const sortTraceValue = (value: unknown): unknown => { + if (Array.isArray(value)) { + return value.map(item => sortTraceValue(item)); + } + if (!isTraceRecord(value)) { + return value; + } + return Object.keys(value) + .sort((left, right) => left.localeCompare(right)) + .reduce>((accumulator, key) => { + accumulator[key] = sortTraceValue(value[key]); + return accumulator; + }, {}); + }; + + const stringifyTraceSemanticPayload = (value: unknown) => { + const normalizedValue = unwrapTraceTextEnvelope(value); + if (typeof normalizedValue === 'string') { + const parsed = tryParseTraceJson(normalizedValue); + if (parsed !== null) { + return JSON.stringify(sortTraceValue(parsed)); + } + return normalizeTraceFingerprintText(normalizedValue); + } + if (normalizedValue === null || normalizedValue === undefined) { + return ''; + } + if (isTraceRecord(normalizedValue) || Array.isArray(normalizedValue)) { + return JSON.stringify(sortTraceValue(normalizedValue)); + } + return normalizeTraceFingerprintText(String(normalizedValue)); + }; + + const compactTraceObject = (value: Record) => { + return Object.fromEntries( + Object.entries(value).filter(([, item]) => { + if (item === null || item === undefined) { + return false; + } + if (typeof item === 'string') { + return item.trim().length > 0; + } + if (Array.isArray(item)) { + return item.length > 0; + } + if (isTraceRecord(item)) { + return Object.keys(item).length > 0; + } + return true; + }), + ); + }; + + const unwrapTraceTextEnvelope = (value: unknown): unknown => { + if (!isTraceRecord(value)) { + if (typeof value !== 'string') { + return value; + } + const parsed = tryParseTraceJson(value); + return parsed !== null ? unwrapTraceTextEnvelope(parsed) : value; + } + const keys = Object.keys(value); + const textValue = typeof value.text === 'string' ? value.text.trim() : ''; + const isTextEnvelope = + textValue.length > 0 && + keys.every(key => ['text', 'type', 'mimeType', 'metadata'].includes(key)); + if (!isTextEnvelope) { + return value; + } + const parsedText = tryParseTraceJson(textValue); + return parsedText !== null ? unwrapTraceTextEnvelope(parsedText) : textValue; + }; + + const unwrapTraceToolCallPayload = (value: Record) => { + if (isTraceRecord(value.param)) { + return value.param; + } + return value; + }; + + const resolveTraceToolTitle = ( + value: Record, + attributeKey: string, + kind: TraceMessageKind, + ) => { + const nameCandidate = + typeof value.name === 'string' + ? value.name + : typeof value.toolName === 'string' + ? value.toolName + : typeof value.function === 'string' + ? value.function + : typeof value.action === 'string' + ? value.action + : typeof value.messageType === 'string' + ? value.messageType + : ''; + if (nameCandidate.trim()) { + return nameCandidate.trim(); + } + return kind === 'tool-call' || kind === 'tool-result' ? '' : attributeKey; + }; + + const buildTraceToolCallContent = (value: Record) => { + const unwrappedValue = unwrapTraceToolCallPayload(value); + const preferredPayload = + unwrappedValue.input ?? + (isTraceRecord(unwrappedValue.metadata) ? unwrappedValue.metadata.arguments : undefined) ?? + unwrappedValue.arguments ?? + unwrappedValue.content ?? + unwrappedValue; + if (typeof preferredPayload === 'string') { + const parsed = tryParseTraceJson(preferredPayload); + return parsed !== null ? stringifyTracePayload(parsed) : preferredPayload; + } + return stringifyTracePayload(preferredPayload); + }; + + const buildTraceToolResultContent = (value: Record) => { + const preferredPayload = unwrapTraceTextEnvelope(value.output ?? value.content ?? value.result ?? value); + return stringifyTracePayload(preferredPayload); + }; + + const extractStringList = (value: unknown): string[] => { + if (typeof value === 'string') { + return value + .split(/[,,\n]/) + .map(item => item.trim()) + .filter(Boolean); + } + if (!Array.isArray(value)) { + return []; + } + return value + .map(item => { + if (typeof item === 'string') { + return item.trim(); + } + if (isTraceRecord(item)) { + const candidate = item.name ?? item.title ?? item.id ?? item.skillName; + return typeof candidate === 'string' ? candidate.trim() : ''; + } + return ''; + }) + .filter(Boolean); + }; + + const extractTraceSkills = (record: Record) => { + for (const key of TRACE_SKILL_KEYS) { + if (key in record) { + const skills = extractStringList(record[key]); + if (skills.length > 0) { + return skills; + } + } + } + const properties = isTraceRecord(record.properties) ? record.properties : null; + if (properties) { + for (const key of TRACE_SKILL_KEYS) { + if (key in properties) { + const skills = extractStringList(properties[key]); + if (skills.length > 0) { + return skills; + } + } + } + } + return []; + }; + + const extractTraceText = (value: unknown): string => { + if (typeof value === 'string') { + return value; + } + if (value === null || value === undefined) { + return ''; + } + if (Array.isArray(value)) { + const textParts = value + .map(item => extractTraceText(item)) + .map(item => item.trim()) + .filter(Boolean); + return textParts.join('\n'); + } + if (!isTraceRecord(value)) { + return String(value); + } + + const directKeys = ['content', 'text', 'message', 'prompt', 'instruction', 'instructions', 'result']; + for (const key of directKeys) { + if (key in value) { + const text = extractTraceText(value[key]); + if (text.trim()) { + return text; + } + } + } + + if (Array.isArray(value.content)) { + const textParts = value.content + .map(item => { + if (typeof item === 'string') { + return item; + } + if (isTraceRecord(item)) { + return extractTraceText(item.text ?? item.content ?? item.value ?? item.output); + } + return ''; + }) + .filter(Boolean); + if (textParts.length > 0) { + return textParts.join('\n'); + } + } + + return ''; + }; + + const inferTraceMessageKind = ( + record: Record, + fallbackKey = '', + ): TraceMessageKind => { + const normalizedFallbackKey = normalizeTraceKey(fallbackKey); + if ( + Array.isArray(record.toolCalls) || + Array.isArray(record.tool_calls) || + record.toolCall || + record.functionCall + ) { + return 'tool-call'; + } + if ( + Array.isArray(record.toolResponses) || + Array.isArray(record.tool_results) || + Array.isArray(record.responses) || + record.toolResponse || + record.toolResult + ) { + return 'tool-result'; + } + if (TRACE_TOOL_RESULT_KEY_HINTS.some(hint => normalizedFallbackKey.includes(hint))) { + return 'tool-result'; + } + if (TRACE_TOOL_CALL_KEY_HINTS.some(hint => normalizedFallbackKey.includes(hint))) { + return 'tool-call'; + } + + const roleCandidate = [record.role, record.type, record.messageType, record.name, fallbackKey] + .map(item => (typeof item === 'string' ? normalizeTraceKey(item) : '')) + .find(Boolean); + + if (!roleCandidate) { + return 'other'; + } + if (roleCandidate.includes('system')) { + return 'system'; + } + if (roleCandidate.includes('user') || roleCandidate.includes('human')) { + return 'user'; + } + if (roleCandidate.includes('assistant') || roleCandidate.includes('model')) { + return 'assistant'; + } + if (roleCandidate.includes('tool')) { + return 'tool-result'; + } + return 'other'; + }; + + const getTraceMessageFingerprint = (message: ParsedTraceMessage) => { + return [ + message.kind, + normalizeTraceFingerprintText(message.title), + normalizeTraceFingerprintText(message.content), + normalizeTraceFingerprintText(message.details), + message.skills.map(normalizeTraceFingerprintText).sort().join(','), + ].join('|'); + }; + + const getTraceMessageDedupFingerprint = (message: ParsedTraceMessage) => { + if (message.kind === 'tool-call' || message.kind === 'tool-result') { + return [ + message.kind, + stringifyTraceSemanticPayload(message.content), + stringifyTraceSemanticPayload(message.details), + ].join('|'); + } + return getTraceMessageFingerprint(message); + }; + + const rankTraceMessage = (message: ParsedTraceMessage) => { + return ( + TRACE_MESSAGE_KIND_PRIORITY[message.kind] * 1000 + + message.skills.length * 20 + + message.title.trim().length * 2 + + message.content.trim().length + ); + }; + + const getTraceConversationGroupFingerprint = (group: ParsedTraceConversationGroup) => { + return group.messages + .map((message, index) => `${index}:${getTraceMessageDedupFingerprint(message)}`) + .join('||'); + }; + + const rankTraceConversationGroup = (group: ParsedTraceConversationGroup) => { + const messageScore = group.messages.reduce((sum, message) => sum + rankTraceMessage(message), 0); + const attributeDepth = group.attributeKey.split('.').length; + return messageScore * 100 - attributeDepth * 10 - group.attributeKey.length; + }; + + const dedupeTraceConversationGroups = (groups: ParsedTraceConversationGroup[]) => { + const bestGroupByFingerprint = new Map(); + groups.forEach(group => { + const fingerprint = getTraceConversationGroupFingerprint(group); + const current = bestGroupByFingerprint.get(fingerprint); + if (!current || rankTraceConversationGroup(group) > rankTraceConversationGroup(current)) { + bestGroupByFingerprint.set(fingerprint, group); + } + }); + return Array.from(bestGroupByFingerprint.values()); + }; + + const traceMessageLabelMap: Record = { + system: 'SYSTEM', + user: 'USER', + assistant: 'ASSISTANT', + 'tool-call': 'TOOL CALL', + 'tool-result': 'TOOL RESULT', + other: 'OTHER', + }; + + const createParsedTraceMessage = ( + id: string, + kind: TraceMessageKind, + options: Partial, + ): ParsedTraceMessage => ({ + id, + kind, + label: traceMessageLabelMap[kind], + title: options.title ?? '', + content: options.content ?? '', + details: options.details ?? '', + skills: options.skills ?? [], + }); + + const normalizeTraceToolCallMessages = ( + value: unknown, + attributeKey: string, + path: string, + ): ParsedTraceMessage[] => { + const calls = Array.isArray(value) ? value : [value]; + return calls + .map((call, index) => { + if (!isTraceRecord(call)) { + return createParsedTraceMessage(`${attributeKey}-${path}-tool-call-${index}`, 'tool-call', { + content: stringifyTracePayload(call), + }); + } + const toolName = + typeof call.name === 'string' + ? call.name + : typeof call.toolName === 'string' + ? call.toolName + : typeof call.function === 'string' + ? call.function + : typeof call.id === 'string' + ? call.id + : '未命名工具'; + const callContent = extractTraceText(call.arguments ?? call.input ?? call.content) || stringifyTracePayload(call); + const detailSource = compactTraceObject({ + id: call.id, + type: call.type, + arguments: call.arguments, + input: call.input, + }); + return createParsedTraceMessage(`${attributeKey}-${path}-tool-call-${index}`, 'tool-call', { + title: toolName, + content: callContent, + details: stringifyTracePayload(detailSource), + }); + }) + .map(message => ({ + ...message, + details: message.details === '{}' ? '' : message.details, + })); + }; + + const normalizeTraceToolResultMessages = ( + value: unknown, + attributeKey: string, + path: string, + ): ParsedTraceMessage[] => { + const responses = Array.isArray(value) ? value : [value]; + return responses.map((response, index) => { + if (!isTraceRecord(response)) { + return createParsedTraceMessage(`${attributeKey}-${path}-tool-result-${index}`, 'tool-result', { + content: stringifyTracePayload(response), + }); + } + const toolName = + typeof response.name === 'string' + ? response.name + : typeof response.toolName === 'string' + ? response.toolName + : typeof response.id === 'string' + ? response.id + : '工具返回'; + const content = + extractTraceText(response.output ?? response.content ?? response.result) || + stringifyTracePayload(response.output ?? response.content ?? response.result ?? response); + const detailSource = compactTraceObject({ + id: response.id, + status: response.status, + error: response.error, + }); + return createParsedTraceMessage(`${attributeKey}-${path}-tool-result-${index}`, 'tool-result', { + title: toolName, + content, + details: stringifyTracePayload(detailSource) === '{}' ? '' : stringifyTracePayload(detailSource), + }); + }); + }; + + const normalizeTraceMessagesFromNode = ( + value: unknown, + attributeKey: string, + path: string, + ): ParsedTraceMessage[] => { + if (Array.isArray(value)) { + return value.flatMap((item, index) => + normalizeTraceMessagesFromNode(item, attributeKey, `${path}-${index}`), + ); + } + if (!isTraceRecord(value)) { + if (typeof value === 'string' && value.trim()) { + return [ + createParsedTraceMessage(`${attributeKey}-${path}-text`, 'other', { + content: value, + }), + ]; + } + return []; + } + + const messages: ParsedTraceMessage[] = []; + const kind = inferTraceMessageKind(value, attributeKey); + const content = extractTraceText(value); + const skills = extractTraceSkills(value); + const titleCandidate = + typeof value.name === 'string' + ? value.name + : typeof value.title === 'string' + ? value.title + : typeof value.messageType === 'string' + ? value.messageType + : ''; + + const toolCalls = value.toolCalls ?? value.tool_calls ?? value.toolCall ?? value.functionCall; + const toolResponses = + value.toolResponses ?? value.tool_results ?? value.responses ?? value.toolResponse ?? value.toolResult; + const properties = isTraceRecord(value.properties) ? value.properties : null; + const propertyDetails = properties ? stringifyTracePayload(properties) : ''; + + if (kind !== 'tool-call' && kind !== 'tool-result' && (content.trim() || skills.length > 0)) { + messages.push( + createParsedTraceMessage(`${attributeKey}-${path}-base`, kind, { + title: titleCandidate, + content, + details: propertyDetails === '{}' ? '' : propertyDetails, + skills, + }), + ); + } + + if (toolCalls) { + messages.push(...normalizeTraceToolCallMessages(toolCalls, attributeKey, path)); + } + if (toolResponses) { + messages.push(...normalizeTraceToolResultMessages(toolResponses, attributeKey, path)); + } + + if ( + messages.length === 0 && + (kind === 'tool-call' || kind === 'tool-result') && + (content.trim() || Object.keys(value).length > 0) + ) { + const fallbackContent = + kind === 'tool-call' + ? buildTraceToolCallContent(value) + : buildTraceToolResultContent(value); + messages.push( + createParsedTraceMessage(`${attributeKey}-${path}-tool-object`, kind, { + title: titleCandidate || resolveTraceToolTitle(value, attributeKey, kind), + content: fallbackContent, + details: propertyDetails === '{}' ? '' : propertyDetails, + skills, + }), + ); + } + + if (messages.length === 0 && content.trim()) { + messages.push( + createParsedTraceMessage(`${attributeKey}-${path}-fallback`, kind, { + title: titleCandidate, + content, + skills, + }), + ); + } + + return messages; + }; + + const extractTraceConversationGroupsFromEntry = ( + entry: TraceAttributeEntry, + ): ParsedTraceConversationGroup[] => { + const parsedValue = tryParseTraceJson(entry.value); + if (parsedValue === null) { + return []; + } + + const groups: ParsedTraceConversationGroup[] = []; + const pushGroup = (title: string, node: unknown, path: string) => { + const messages = normalizeTraceMessagesFromNode(node, entry.key, path).filter( + message => message.content.trim() || message.details.trim() || message.skills.length > 0, + ); + if (messages.length > 0) { + groups.push({ + attributeKey: path === entry.key ? entry.key : `${entry.key}.${path}`, + title, + messages, + }); + } + }; + + if (Array.isArray(parsedValue)) { + pushGroup(`${entry.key} · messages`, parsedValue, entry.key); + return groups; + } + + if (!isTraceRecord(parsedValue)) { + return groups; + } + + let matchedContainer = false; + Object.entries(parsedValue).forEach(([key, value]) => { + if (!TRACE_MESSAGE_CONTAINER_KEYS.includes(normalizeTraceKey(key))) { + return; + } + const nestedMessages = normalizeTraceMessagesFromNode(value, entry.key, key); + if (nestedMessages.length === 0) { + return; + } + groups.push({ + attributeKey: `${entry.key}.${key}`, + title: `${entry.key} · ${key}`, + messages: nestedMessages, + }); + matchedContainer = true; + }); + + if (!matchedContainer) { + pushGroup(`${entry.key} · message`, parsedValue, entry.key); + } + + return groups; + }; + + const buildTraceAttributeEntries = (attributes: Record): TraceAttributeEntry[] => { + return Object.entries(attributes ?? {}) + .sort(([leftKey], [rightKey]) => leftKey.localeCompare(rightKey)) + .map(([key, value]) => ({ + key, + value, + })); + }; + const flattenTraceSpans = ( spans: TraceSpan[], depth = 0, - ): Array<{ span: TraceSpan; depth: number; attributeEntries: Array<{ key: string; value: string }> }> => { + ): FlattenedTraceSpan[] => { return spans.flatMap(span => { - const attributeEntries = Object.entries(span.attributes ?? {}).map(([key, value]) => ({ - key, - value, - })); + const attributeEntries = buildTraceAttributeEntries(span.attributes ?? {}); + const searchableText = [ + span.name, + span.spanId, + span.parentSpanId, + span.kind, + span.status, + ...attributeEntries.flatMap(entry => [entry.key, entry.value]), + ] + .filter(Boolean) + .join(' ') + .toLowerCase(); return [ { span, depth, attributeEntries, + searchableText, }, ...flattenTraceSpans(span.children ?? [], depth + 1), ]; @@ -1267,6 +2109,70 @@ sessionTrace.value ? flattenTraceSpans(sessionTrace.value.rootSpans ?? []) : [], ); + const normalizedTraceSearchKeyword = computed(() => traceSearchKeyword.value.trim().toLowerCase()); + + const filteredTraceSpans = computed(() => { + const keyword = normalizedTraceSearchKeyword.value; + if (!keyword) { + return flattenedTraceSpans.value; + } + return flattenedTraceSpans.value.filter(row => row.searchableText.includes(keyword)); + }); + + const selectedTraceRow = computed(() => { + if (!flattenedTraceSpans.value.length) { + return null; + } + return ( + flattenedTraceSpans.value.find(row => row.span.spanId === selectedTraceSpanId.value) ?? + filteredTraceSpans.value[0] ?? + flattenedTraceSpans.value[0] + ); + }); + + const selectedTraceAttributeEntries = computed(() => { + const row = selectedTraceRow.value; + if (!row) { + return []; + } + const keyword = normalizedTraceSearchKeyword.value; + if (!keyword) { + return row.attributeEntries; + } + return row.attributeEntries.filter(entry => + `${entry.key} ${entry.value}`.toLowerCase().includes(keyword), + ); + }); + + const parsedTraceConversations = computed(() => { + const row = selectedTraceRow.value; + if (!row) { + return []; + } + const keyword = normalizedTraceSearchKeyword.value; + return dedupeTraceConversationGroups( + row.attributeEntries.flatMap(entry => extractTraceConversationGroupsFromEntry(entry)), + ) + .map(group => { + if (!keyword) { + return group; + } + return { + ...group, + messages: group.messages.filter(message => + `${message.label} ${message.title} ${message.content} ${message.details} ${message.skills.join(' ')}` + .toLowerCase() + .includes(keyword), + ), + }; + }) + .filter(group => group.messages.length > 0); + }); + + const selectTraceSpan = (spanId: string) => { + selectedTraceSpanId.value = spanId; + }; + const formatTraceDuration = (durationMs: number) => { if (durationMs < 1000) { return `${durationMs} ms`; @@ -1284,6 +2190,34 @@ return new Date(epochMs).toLocaleString(); }; + const formatTraceOffset = (epochMs: number) => { + if (!sessionTrace.value?.startEpochMs || !epochMs) { + return '--'; + } + const offset = Math.max(0, epochMs - sessionTrace.value.startEpochMs); + return `+${formatTraceDuration(offset)}`; + }; + + const isStructuredTraceValue = (value: string) => { + const normalizedValue = value?.trim(); + return Boolean( + normalizedValue && + ((normalizedValue.startsWith('{') && normalizedValue.endsWith('}')) || + (normalizedValue.startsWith('[') && normalizedValue.endsWith(']'))), + ); + }; + + const formatStructuredTraceValue = (value: string) => { + if (!isStructuredTraceValue(value)) { + return value; + } + try { + return JSON.stringify(JSON.parse(value), null, 2); + } catch (error) { + return value; + } + }; + const loadLatestTrace = async () => { if (!currentSession.value) { sessionTrace.value = null; @@ -1294,8 +2228,10 @@ traceError.value = ''; try { sessionTrace.value = await ChatService.getSessionTrace(currentSession.value.id); + selectedTraceSpanId.value = sessionTrace.value.rootSpans?.[0]?.spanId ?? ''; } catch (error: any) { sessionTrace.value = null; + selectedTraceSpanId.value = ''; if (error?.response?.status === 404) { traceError.value = '当前会话还没有最近一次 trace,请先执行一轮对话。'; } else { @@ -1573,9 +2509,18 @@ lastRequest, resultSetDisplayConfig, options, + traceSearchKeyword, flattenedTraceSpans, + filteredTraceSpans, + selectedTraceSpanId, + selectedTraceRow, + selectedTraceAttributeEntries, + parsedTraceConversations, formatTraceDuration, formatTraceTime, + formatTraceOffset, + isStructuredTraceValue, + formatStructuredTraceValue, getMarkdownContentFromNode, selectSession, sendMessage, @@ -1594,6 +2539,7 @@ stopStreaming, openTraceDialog, refreshTrace, + selectTraceSpan, deleteSessionState, }; }, @@ -2019,28 +2965,110 @@ display: flex; flex-wrap: wrap; gap: 12px; - color: #606266; - font-size: 13px; + } + + .trace-summary-pill { + display: inline-flex; + align-items: center; + min-height: 32px; + padding: 0 12px; + border: 1px solid #d9ecff; + border-radius: 999px; + background: linear-gradient(135deg, #f5fbff 0%, #eef6ff 100%); + color: #36658f; + font-size: 12px; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + } + + .trace-toolbar-actions { + display: flex; + align-items: center; + gap: 12px; + } + + .trace-search-input { + width: 280px; } .trace-alert { margin-bottom: 16px; } + .trace-explorer { + display: grid; + grid-template-columns: minmax(360px, 44%) minmax(420px, 1fr); + gap: 16px; + min-height: 62vh; + } + + .trace-pane { + border: 1px solid #e4ecf5; + border-radius: 18px; + background: linear-gradient(180deg, #fcfdff 0%, #f7faff 100%); + box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.8); + overflow: hidden; + } + + .trace-pane-header { + display: flex; + justify-content: space-between; + align-items: center; + gap: 12px; + padding: 16px 18px 12px; + border-bottom: 1px solid rgba(31, 94, 155, 0.08); + } + + .trace-pane-title { + font-size: 14px; + font-weight: 600; + color: #1f3b57; + } + + .trace-pane-count { + color: #6d7f92; + font-size: 12px; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + } + .trace-list { display: flex; flex-direction: column; - gap: 12px; - max-height: 65vh; + gap: 10px; + max-height: calc(62vh - 56px); overflow: auto; - padding-right: 6px; + padding: 14px; } .trace-row { - border: 1px solid #ebeef5; - border-radius: 12px; - padding: 12px; + appearance: none; + width: 100%; + border: 1px solid #e6edf5; + border-radius: 16px; + padding: 14px 16px; background: #fff; + text-align: left; + cursor: pointer; + transition: + border-color 0.2s ease, + transform 0.2s ease, + box-shadow 0.2s ease; + } + + .trace-row:hover { + border-color: #8ab8ff; + box-shadow: 0 10px 24px rgba(76, 115, 169, 0.12); + transform: translateY(-1px); + } + + .trace-row.is-selected { + border-color: #4b8dff; + box-shadow: 0 14px 28px rgba(75, 141, 255, 0.16); + background: linear-gradient(135deg, #ffffff 0%, #f3f8ff 100%); + } + + .trace-row.is-error { + border-color: #f3c1c1; + background: linear-gradient(135deg, #ffffff 0%, #fff7f7 100%); } .trace-row-main { @@ -2052,12 +3080,13 @@ .trace-row-name { font-weight: 600; - color: #303133; + color: #20354d; } .trace-row-duration { - color: #409eff; + color: #3d7cff; font-size: 12px; + font-family: 'JetBrains Mono', 'Fira Code', monospace; } .trace-row-meta { @@ -2070,37 +3099,288 @@ word-break: break-all; } - .trace-attributes { - margin-top: 10px; + .trace-pane-detail { + display: flex; + flex-direction: column; + min-height: 0; } - .trace-attributes summary { - cursor: pointer; - color: #606266; + .trace-detail-header { + display: flex; + justify-content: space-between; + align-items: flex-start; + gap: 16px; + padding: 18px 18px 12px; + border-bottom: 1px solid rgba(31, 94, 155, 0.08); + } + + .trace-detail-title { + font-size: 18px; + font-weight: 700; + color: #1b334a; + line-height: 1.4; + } + + .trace-detail-subtitle { + display: flex; + flex-wrap: wrap; + gap: 12px; + margin-top: 8px; + color: #728398; + font-size: 12px; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + } + + .trace-detail-tags { + display: flex; + gap: 8px; + flex-wrap: wrap; + } + + .trace-descriptions { + padding: 16px 18px 0; + } + + .trace-message-panel { + padding: 18px 18px 0; + } + + .trace-message-groups { + display: flex; + flex-direction: column; + gap: 14px; + margin-top: 12px; + } + + .trace-message-group { + border: 1px solid #e7edf5; + border-radius: 16px; + overflow: hidden; + background: #fff; + } + + .trace-message-group-header { + display: flex; + justify-content: space-between; + align-items: center; + gap: 12px; + padding: 12px 14px; + background: linear-gradient(180deg, #f8fbff 0%, #f2f7fc 100%); + border-bottom: 1px solid #edf2f8; + } + + .trace-message-group-title { + color: #20384f; font-size: 13px; + font-weight: 600; } - .trace-attributes-grid { + .trace-message-group-meta { + color: #78889c; + font-size: 12px; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + } + + .trace-message-list { + display: flex; + flex-direction: column; + gap: 12px; + padding: 14px; + } + + .trace-message-item { display: grid; - grid-template-columns: minmax(180px, 220px) minmax(0, 1fr); - gap: 8px 12px; - margin-top: 10px; + grid-template-columns: 92px minmax(0, 1fr); + gap: 12px; + align-items: start; + } + + .trace-message-role { + display: flex; + justify-content: center; + padding-top: 4px; + } + + .trace-message-role-badge { + display: inline-flex; + align-items: center; + justify-content: center; + min-width: 74px; + min-height: 28px; + padding: 0 10px; + border-radius: 999px; + font-size: 11px; + font-weight: 700; + letter-spacing: 0.04em; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + } + + .trace-message-item.is-system .trace-message-role-badge { + color: #7b4f00; + background: #fff3d8; } - .trace-attribute-item { - display: contents; + .trace-message-item.is-user .trace-message-role-badge { + color: #0e4db8; + background: #e7f0ff; + } + + .trace-message-item.is-assistant .trace-message-role-badge { + color: #166534; + background: #e8f7ed; + } + + .trace-message-item.is-tool-call .trace-message-role-badge { + color: #7c2d12; + background: #ffe8dc; + } + + .trace-message-item.is-tool-result .trace-message-role-badge { + color: #5b21b6; + background: #f2eaff; + } + + .trace-message-item.is-other .trace-message-role-badge { + color: #475569; + background: #e9eef5; + } + + .trace-message-body { + padding: 14px 16px; + border-radius: 16px; + border: 1px solid #e8edf4; + background: #fbfdff; + } + + .trace-message-item.is-user .trace-message-body { + background: linear-gradient(180deg, #f5f9ff 0%, #edf4ff 100%); + } + + .trace-message-item.is-assistant .trace-message-body { + background: linear-gradient(180deg, #fbfffc 0%, #f3fbf5 100%); + } + + .trace-message-item.is-system .trace-message-body { + background: linear-gradient(180deg, #fffdf8 0%, #fff8ea 100%); + } + + .trace-message-item.is-tool-call .trace-message-body { + background: linear-gradient(180deg, #fffaf7 0%, #fff1e8 100%); + } + + .trace-message-item.is-tool-result .trace-message-body { + background: linear-gradient(180deg, #fcf9ff 0%, #f5efff 100%); + } + + .trace-message-title { + color: #1d344b; + font-size: 13px; + font-weight: 600; + margin-bottom: 8px; + } + + .trace-message-skills { + display: flex; + flex-wrap: wrap; + gap: 8px; + margin-bottom: 10px; + } + + .trace-skill-chip { + display: inline-flex; + align-items: center; + padding: 0 10px; + min-height: 24px; + border-radius: 999px; + background: #edf4ff; + color: #335f94; + font-size: 11px; + font-weight: 600; + } + + .trace-message-content { + color: #26384b; + font-size: 12px; + line-height: 1.7; + white-space: pre-wrap; + word-break: break-word; + } + + .trace-message-content-structured, + .trace-message-details { + margin: 10px 0 0; + padding: 12px; + border-radius: 12px; + background: rgba(255, 255, 255, 0.72); + border: 1px solid #e7edf5; + overflow: auto; + font-size: 12px; + line-height: 1.6; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + white-space: pre-wrap; + word-break: break-word; + } + + .trace-attribute-panel { + padding: 18px; + min-height: 0; + display: flex; + flex-direction: column; + gap: 12px; + } + + .trace-attribute-table { + border: 1px solid #e7edf5; + border-radius: 14px; + overflow: hidden; + background: #fff; + } + + .trace-attribute-table-header, + .trace-attribute-row { + display: grid; + grid-template-columns: minmax(220px, 260px) minmax(0, 1fr); + } + + .trace-attribute-table-header { + background: #f5f8fc; + color: #60758b; + font-size: 12px; + font-weight: 600; + } + + .trace-attribute-table-header span, + .trace-attribute-key, + .trace-attribute-value { + padding: 12px 14px; + } + + .trace-attribute-row + .trace-attribute-row { + border-top: 1px solid #eef3f8; } .trace-attribute-key { - color: #606266; + color: #41576d; font-size: 12px; word-break: break-all; + background: #fbfcfe; + border-right: 1px solid #eef3f8; + font-family: 'JetBrains Mono', 'Fira Code', monospace; } .trace-attribute-value { - color: #303133; + color: #24384c; font-size: 12px; - word-break: break-all; + line-height: 1.6; + word-break: break-word; + white-space: pre-wrap; + } + + .trace-attribute-value-structured { + margin: 0; + overflow: auto; + font-family: 'JetBrains Mono', 'Fira Code', monospace; + background: #fbfdff; } @keyframes spin { @@ -2131,9 +3411,47 @@ align-items: stretch; } - .trace-attributes-grid { + .trace-toolbar-actions { + flex-direction: column; + align-items: stretch; + } + + .trace-search-input { + width: 100%; + } + + .trace-explorer { grid-template-columns: 1fr; } + + .trace-detail-header { + flex-direction: column; + align-items: stretch; + } + + .trace-message-item { + grid-template-columns: 1fr; + } + + .trace-message-role { + justify-content: flex-start; + padding-top: 0; + } + + .trace-message-group-header { + flex-direction: column; + align-items: flex-start; + } + + .trace-attribute-table-header, + .trace-attribute-row { + grid-template-columns: 1fr; + } + + .trace-attribute-key { + border-right: none; + border-bottom: 1px solid #eef3f8; + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java index 078c4f8c5..b2ce0d363 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java @@ -16,19 +16,32 @@ package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; import com.fasterxml.jackson.databind.JsonNode; +import java.util.List; import lombok.Data; +import org.apache.commons.lang3.StringUtils; @Data class SqlGuardCheckRequest { - private String query; + private String action; - private String sql; + private String query; - private JsonNode tableSchemas; + private String sql; - private JsonNode semanticHits; + private String tableName; - private JsonNode businessKnowledgeHits; + private List columnNames; + private Integer limit; + + private JsonNode tableSchemas; + + private JsonNode semanticHits; + + private JsonNode businessKnowledgeHits; + + String normalizedAction() { + return StringUtils.defaultIfBlank(action, "SQL_VERIFY").trim().toUpperCase(); + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java index 36a530b05..d0019fdd0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java @@ -22,36 +22,47 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; @Data @Builder @JsonInclude(JsonInclude.Include.NON_NULL) class SqlGuardCheckResult { - private String query; + private String action; - private String sql; + private String query; - private String summary; + private String sql; - private String explainedIntent; + private String tableName; - @JsonProperty("isAligned") - private boolean isAligned; + private String summary; - @Builder.Default - private List problems = new ArrayList<>(); + private String explainedIntent; - @Builder.Default - private List fixSuggestions = new ArrayList<>(); + @JsonProperty("isAligned") + private Boolean isAligned; - @Builder.Default - private List usedTables = new ArrayList<>(); + private Long totalRows; - @Builder.Default - private List usedMetrics = new ArrayList<>(); + private Integer inspectedColumnCount; - @Builder.Default - private List ruleChecks = new ArrayList<>(); + @Builder.Default + private List problems = new ArrayList<>(); + @Builder.Default + private List fixSuggestions = new ArrayList<>(); + + @Builder.Default + private List usedTables = new ArrayList<>(); + + @Builder.Default + private List usedMetrics = new ArrayList<>(); + + @Builder.Default + private List ruleChecks = new ArrayList<>(); + + @Builder.Default + private List> columnProfiles = new ArrayList<>(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java index 0c51e029d..9871385c7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -33,17 +33,37 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { { "type": "object", "properties": { + "action": { + "type": "string", + "enum": ["SQL_VERIFY", "DATA_PROFILE"], + "description": "可选。默认 SQL_VERIFY。SQL_VERIFY 用于候选 SQL 的结构与意图校验;DATA_PROFILE 用于查看字段值域、空值率、distinct、top values 与样例。" + }, "query": { "type": "string", - "description": "必填。用户原始问题。" + "description": "SQL_VERIFY 时必填。用户原始问题。" }, "sql": { "type": "string", - "description": "必填。当前准备执行或准备返回给用户的候选 SQL。" + "description": "SQL_VERIFY 时必填。当前准备执行或准备返回给用户的候选 SQL。" + }, + "tableName": { + "type": "string", + "description": "DATA_PROFILE 时必填。目标表名。" + }, + "columnNames": { + "type": "array", + "items": { + "type": "string" + }, + "description": "DATA_PROFILE 时可选。要分析的字段列表;不传时默认取该表前几个可见字段。" + }, + "limit": { + "type": "integer", + "description": "DATA_PROFILE 时可选。样例值和 top values 的返回上限,默认 5,最大 20。" }, "tableSchemas": { "type": "object", - "description": "可选。把 datasource explorer 的 schema 结果原样传入,帮助识别时间列、维度列与表关系。" + "description": "可选。把 datasource explorer 的 schema 结果原样传入,帮助 SQL 校验识别时间列、维度列与表关系。" }, "semanticHits": { "type": "object", @@ -53,17 +73,17 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { "type": "object", "description": "可选。把 domain_business_knowledge.search 的结果原样传入。" } - }, - "required": ["query", "sql"] + } } """; private static final String DESCRIPTION = """ - Single SQL verification tool for SQL-backed answers. - Check whether the candidate SQL really matches the user's intent before execution or final answer. - If verification fails, read isAligned=false plus problems, ruleChecks and fixSuggestions, then rewrite SQL yourself and call sql_guard.check again. - Each problem explains why it is wrong, what was expected, what was actually detected and how to repair it. - Always pass a fresh top-level query and sql. Do not pass previous sql_guard.check output back into the tool. + Unified SQL guard tool for SQL-backed answers. + Action SQL_VERIFY: check whether the candidate SQL really matches the user's intent before execution or final answer. + Action DATA_PROFILE: inspect column value distribution before writing SQL when field semantics are unclear. + For SQL_VERIFY, if verification fails, read isAligned=false plus problems, ruleChecks and fixSuggestions, then rewrite SQL yourself and call sql_guard.check again. + For DATA_PROFILE, use the returned columnProfiles to understand null ratio, distinct count, top values, samples, and whether a field looks categorical, numeric, or temporal. + Always pass fresh top-level parameters for the current action. Do not pass previous sql_guard.check output back into the tool. """; private final ObjectMapper objectMapper; @@ -82,19 +102,23 @@ public Map getToolCallbacks(String agentId) { .description(DESCRIPTION) .inputSchema(INPUT_SCHEMA) .build(); - return Map.of(TOOL_NAME, new SqlGuardToolCallback(toolDefinition, objectMapper, sqlVerifyExplainService)); + return Map.of(TOOL_NAME, + new SqlGuardToolCallback(agentId, toolDefinition, objectMapper, sqlVerifyExplainService)); } private static final class SqlGuardToolCallback implements ToolCallback { + private final String agentId; + private final ToolDefinition toolDefinition; private final ObjectMapper objectMapper; private final SqlVerifyExplainService sqlVerifyExplainService; - private SqlGuardToolCallback(ToolDefinition toolDefinition, ObjectMapper objectMapper, + private SqlGuardToolCallback(String agentId, ToolDefinition toolDefinition, ObjectMapper objectMapper, SqlVerifyExplainService sqlVerifyExplainService) { + this.agentId = agentId; this.toolDefinition = toolDefinition; this.objectMapper = objectMapper; this.sqlVerifyExplainService = sqlVerifyExplainService; @@ -110,7 +134,13 @@ public String call(String toolInput) { try { SqlGuardCheckRequest request = StringUtils.hasText(toolInput) ? objectMapper.readValue(toolInput, SqlGuardCheckRequest.class) : new SqlGuardCheckRequest(); - return objectMapper.writeValueAsString(sqlVerifyExplainService.explain(request)); + String action = request.normalizedAction(); + SqlGuardCheckResult result = switch (action) { + case "DATA_PROFILE" -> sqlVerifyExplainService.inspectProfile(agentId, request); + case "SQL_VERIFY" -> sqlVerifyExplainService.explain(request); + default -> throw new IllegalArgumentException("Unsupported sql_guard.check action: " + action); + }; + return objectMapper.writeValueAsString(result); } catch (Exception ex) { throw new IllegalStateException("Failed to execute sql_guard.check: " + ex.getMessage(), ex); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java index b21ce7276..75ce284d3 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -15,12 +15,26 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; +import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ResultSetBO; +import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; +import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; +import com.alibaba.cloud.ai.dataagent.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; +import com.alibaba.cloud.ai.dataagent.entity.Datasource; +import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; +import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; +import com.alibaba.cloud.ai.dataagent.util.SqlUtil; import com.fasterxml.jackson.databind.JsonNode; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -34,6 +48,16 @@ @Service public class SqlVerifyExplainService { + private static final String ACTION_SQL_VERIFY = "SQL_VERIFY"; + + private static final String ACTION_DATA_PROFILE = "DATA_PROFILE"; + + private static final int DEFAULT_PROFILE_LIMIT = 5; + + private static final int MAX_PROFILE_LIMIT = 20; + + private static final int DEFAULT_PROFILE_COLUMN_COUNT = 3; + private static final Pattern AGGREGATE_PATTERN = Pattern .compile("(?i)\\b(count|sum|avg|average|min|max)\\s*\\(([^)]*)\\)\\s*(?:as\\s+([a-zA-Z0-9_]+))?"); @@ -72,6 +96,19 @@ public class SqlVerifyExplainService { private static final Pattern SQL_FETCH_FIRST_VALUE_PATTERN = Pattern .compile("(?is)\\bfetch\\s+first\\s+(\\d+)\\s+rows\\s+only\\b"); + private final AgentDatasourceService agentDatasourceService; + + private final DatasourceService datasourceService; + + private final AccessorFactory accessorFactory; + + public SqlVerifyExplainService(AgentDatasourceService agentDatasourceService, DatasourceService datasourceService, + AccessorFactory accessorFactory) { + this.agentDatasourceService = agentDatasourceService; + this.datasourceService = datasourceService; + this.accessorFactory = accessorFactory; + } + public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { String query = StringUtils.trimToEmpty(request == null ? null : request.getQuery()); String sql = StringUtils.trimToEmpty(request == null ? null : request.getSql()); @@ -88,6 +125,7 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { } catch (IllegalArgumentException ex) { return SqlGuardCheckResult.builder() + .action(ACTION_SQL_VERIFY) .query(query) .sql(sql) .isAligned(false) @@ -138,6 +176,7 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { fixSuggestions.add("当前规则校验通过;如要进一步提高置信度,可继续核对执行结果与最终答案解释。"); } return SqlGuardCheckResult.builder() + .action(ACTION_SQL_VERIFY) .query(query) .sql(sql) .isAligned(aligned) @@ -151,6 +190,457 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { .build(); } + public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest request) { + String tableName = StringUtils.trimToEmpty(request == null ? null : request.getTableName()); + if (StringUtils.isBlank(tableName)) { + throw new IllegalArgumentException("sql_guard.check with action=DATA_PROFILE requires tableName"); + } + ProfileContext context = resolveProfileContext(agentId); + String actualTableName = resolveVisibleTableName(context, tableName); + List availableColumns = loadTableColumns(context, actualTableName); + List visibleColumns = applyVisibleColumnRestrictions(context, actualTableName, availableColumns); + if (visibleColumns.isEmpty()) { + throw new IllegalArgumentException( + "Table '%s' has no visible columns for current agent".formatted(actualTableName)); + } + List columnsToInspect = resolveColumnsToInspect(request, actualTableName, visibleColumns); + int sampleLimit = normalizeProfileLimit(request == null ? null : request.getLimit()); + long totalRows = querySingleLong(context, "SELECT COUNT(*) AS total_rows FROM " + quoteTable(context, actualTableName), + "total_rows"); + List> columnProfiles = columnsToInspect.stream() + .map(column -> buildColumnProfile(context, actualTableName, column, totalRows, sampleLimit)) + .toList(); + String summary = "Profiled %d columns from '%s' using visible columns only." + .formatted(columnProfiles.size(), actualTableName); + return SqlGuardCheckResult.builder() + .action(ACTION_DATA_PROFILE) + .query(request == null ? null : request.getQuery()) + .tableName(actualTableName) + .summary(summary) + .totalRows(totalRows) + .inspectedColumnCount(columnProfiles.size()) + .usedTables(List.of(actualTableName)) + .columnProfiles(columnProfiles) + .fixSuggestions(List.of( + "Use categorical fields with concentrated topValues as filters or GROUP BY candidates.", + "Use numeric/date fields with min/max ranges as metric, trend, or time-window candidates.")) + .build(); + } + + private ProfileContext resolveProfileContext(String agentId) { + if (!StringUtils.isNumeric(agentId)) { + throw new IllegalArgumentException("sql_guard.check DATA_PROFILE only supports numeric agentId"); + } + Long numericAgentId = Long.valueOf(agentId); + AgentDatasource agentDatasource = agentDatasourceService.getCurrentAgentDatasource(numericAgentId); + Datasource datasource = agentDatasource.getDatasource() != null ? agentDatasource.getDatasource() + : datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); + if (datasource == null) { + throw new IllegalStateException("Active datasource not found for agent " + agentId); + } + DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); + Accessor accessor = accessorFactory.getAccessorByDbConfig(dbConfig); + List explicitSelectedTables = Optional.ofNullable(agentDatasource.getSelectTables()).orElse(List.of()); + List visibleTables; + try { + visibleTables = explicitSelectedTables.isEmpty() ? datasourceService.getDatasourceTables(datasource.getId()) + : explicitSelectedTables; + } + catch (Exception ex) { + throw new IllegalStateException( + "Failed to load visible tables for datasource %s: %s".formatted(datasource.getId(), ex.getMessage()), + ex); + } + Map> visibleTablesByName = indexTables(visibleTables, false); + Map> visibleTablesByLeafName = indexTables(visibleTables, true); + Map> visibleColumnsByTable = buildVisibleColumnsByTable(agentDatasource, visibleTablesByName, + visibleTablesByLeafName); + Map> visibleColumnNameSetByTable = new LinkedHashMap<>(); + visibleColumnsByTable.forEach((key, value) -> visibleColumnNameSetByTable.put(key, + value.stream() + .map(this::normalizeColumnName) + .collect(LinkedHashSet::new, LinkedHashSet::add, LinkedHashSet::addAll))); + return new ProfileContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), + Map.copyOf(visibleTablesByName), Map.copyOf(visibleTablesByLeafName), Map.copyOf(visibleColumnsByTable), + Map.copyOf(visibleColumnNameSetByTable), Set.copyOf(visibleColumnsByTable.keySet())); + } + + private List loadTableColumns(ProfileContext context, String tableName) { + try { + return Optional.ofNullable(context.accessor() + .showColumns(context.dbConfig(), + DbQueryParameter.from(context.dbConfig()).setSchema(context.dbConfig().getSchema()).setTable(tableName))) + .orElse(List.of()); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to load columns for table '%s': %s".formatted(tableName, ex.getMessage()), + ex); + } + } + + private List applyVisibleColumnRestrictions(ProfileContext context, String tableName, + List columns) { + return Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(column -> isColumnVisible(context, tableName, column.getName())) + .toList(); + } + + private List resolveColumnsToInspect(SqlGuardCheckRequest request, String tableName, + List visibleColumns) { + Map columnsByName = new LinkedHashMap<>(); + for (ColumnInfoBO column : visibleColumns) { + columnsByName.put(normalizeColumnName(column.getName()), column); + } + List requestedColumns = Optional.ofNullable(request == null ? null : request.getColumnNames()) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .toList(); + if (requestedColumns.isEmpty()) { + return visibleColumns.stream().limit(DEFAULT_PROFILE_COLUMN_COUNT).toList(); + } + List resolvedColumns = new ArrayList<>(); + for (String requestedColumn : requestedColumns) { + ColumnInfoBO column = columnsByName.get(normalizeColumnName(requestedColumn)); + if (column == null) { + throw new IllegalArgumentException( + "Column '%s' is not visible in table '%s' for current agent".formatted(requestedColumn, tableName)); + } + resolvedColumns.add(column); + } + return resolvedColumns; + } + + private Map buildColumnProfile(ProfileContext context, String tableName, ColumnInfoBO column, + long totalRows, int sampleLimit) { + String quotedTable = quoteTable(context, tableName); + String quotedColumn = SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), column.getName()); + long nullCount = querySingleLong(context, + "SELECT COUNT(*) AS null_rows FROM %s WHERE %s IS NULL".formatted(quotedTable, quotedColumn), "null_rows"); + Double nullRatio = totalRows <= 0 ? 0D : roundRatio((double) nullCount / (double) totalRows); + Long distinctCount = null; + if (supportsDistinctCount(column)) { + distinctCount = querySingleLong(context, "SELECT COUNT(DISTINCT %s) AS distinct_count FROM %s" + .formatted(quotedColumn, quotedTable), "distinct_count"); + } + List> topValues = supportsGroupedTopValues(column) + ? queryTopValues(context, quotedTable, quotedColumn, sampleLimit) : List.of(); + List sampleValues = querySampleValues(context, quotedTable, quotedColumn, sampleLimit, supportsDistinctCount(column)); + String minValue = supportsMinMax(column) + ? querySingleValue(context, + "SELECT MIN(%s) AS min_value FROM %s".formatted(quotedColumn, quotedTable), "min_value") + : null; + String maxValue = supportsMinMax(column) + ? querySingleValue(context, + "SELECT MAX(%s) AS max_value FROM %s".formatted(quotedColumn, quotedTable), "max_value") + : null; + Map profile = new LinkedHashMap<>(); + profile.put("columnName", column.getName()); + profile.put("dataType", column.getType()); + profile.put("notNull", column.isNotnull()); + profile.put("nullCount", nullCount); + profile.put("nullRatio", nullRatio); + profile.put("distinctCount", distinctCount); + profile.put("sampleValues", sampleValues); + profile.put("topValues", topValues); + profile.put("min", minValue); + profile.put("max", maxValue); + profile.put("profileHints", buildProfileHints(column, nullRatio, distinctCount, totalRows, topValues)); + return profile; + } + + private List> queryTopValues(ProfileContext context, String quotedTable, String quotedColumn, + int sampleLimit) { + String sql = applyLimit(""" + SELECT %s AS profile_value, COUNT(*) AS profile_count + FROM %s + WHERE %s IS NOT NULL + GROUP BY %s + ORDER BY profile_count DESC + """.formatted(quotedColumn, quotedTable, quotedColumn, quotedColumn), context.dbConfig().getDialectType(), + sampleLimit); + ResultSetBO resultSet = executeSql(context, sql); + List> values = new ArrayList<>(); + for (Map row : Optional.ofNullable(resultSet.getData()).orElse(List.of())) { + Map entry = new LinkedHashMap<>(); + entry.put("value", row.get("profile_value")); + entry.put("count", parseLong(row.get("profile_count"))); + values.add(entry); + } + return values; + } + + private List querySampleValues(ProfileContext context, String quotedTable, String quotedColumn, int sampleLimit, + boolean distinctPreferred) { + String selectClause = distinctPreferred ? "SELECT DISTINCT %s AS sample_value".formatted(quotedColumn) + : "SELECT %s AS sample_value".formatted(quotedColumn); + String sql = applyLimit(""" + %s + FROM %s + WHERE %s IS NOT NULL + ORDER BY %s + """.formatted(selectClause, quotedTable, quotedColumn, quotedColumn), context.dbConfig().getDialectType(), + sampleLimit); + ResultSetBO resultSet = executeSql(context, sql); + return Optional.ofNullable(resultSet.getData()) + .orElse(List.of()) + .stream() + .map(row -> row.get("sample_value")) + .filter(StringUtils::isNotBlank) + .toList(); + } + + private List buildProfileHints(ColumnInfoBO column, Double nullRatio, Long distinctCount, long totalRows, + List> topValues) { + List hints = new ArrayList<>(); + if (Boolean.TRUE.equals(isLikelyCategorical(column, distinctCount, totalRows, topValues))) { + hints.add("Likely categorical field; suitable for filter or GROUP BY."); + } + if (supportsMinMax(column)) { + hints.add("Likely ordered field; suitable for range filter, metric, or trend axis."); + } + if (nullRatio != null && nullRatio >= 0.5D) { + hints.add("High null ratio; be careful when using it as a hard filter."); + } + if (hints.isEmpty()) { + hints.add("Inspect samples and top values before deciding whether to use it in SQL."); + } + return hints; + } + + private Boolean isLikelyCategorical(ColumnInfoBO column, Long distinctCount, long totalRows, + List> topValues) { + if (!supportsGroupedTopValues(column)) { + return false; + } + if (distinctCount != null && distinctCount > 0 && distinctCount <= 20) { + return true; + } + if (totalRows > 0 && distinctCount != null && distinctCount <= Math.max(10, totalRows / 10)) { + return true; + } + return !topValues.isEmpty() && topValues.size() <= 10; + } + + private long querySingleLong(ProfileContext context, String sql, String columnName) { + return parseLong(querySingleValue(context, sql, columnName)); + } + + private String querySingleValue(ProfileContext context, String sql, String columnName) { + ResultSetBO resultSet = executeSql(context, sql); + List> rows = Optional.ofNullable(resultSet.getData()).orElse(List.of()); + if (rows.isEmpty()) { + return null; + } + Map row = rows.get(0); + if (row.containsKey(columnName)) { + return row.get(columnName); + } + return row.values().stream().findFirst().orElse(null); + } + + private ResultSetBO executeSql(ProfileContext context, String sql) { + try { + ResultSetBO resultSet = context.accessor().executeSqlAndReturnObject(context.dbConfig(), + DbQueryParameter.from(context.dbConfig()).setSchema(context.dbConfig().getSchema()).setSql(sql)); + if (resultSet == null) { + return ResultSetBO.builder().column(List.of()).data(List.of()).build(); + } + if (StringUtils.isNotBlank(resultSet.getErrorMsg())) { + throw new IllegalStateException(resultSet.getErrorMsg()); + } + if (resultSet.getColumn() == null) { + resultSet.setColumn(List.of()); + } + if (resultSet.getData() == null) { + resultSet.setData(List.of()); + } + return resultSet; + } + catch (Exception ex) { + throw new IllegalStateException("Failed to execute profile SQL: " + ex.getMessage(), ex); + } + } + + private int normalizeProfileLimit(Integer requestedLimit) { + if (requestedLimit == null || requestedLimit <= 0) { + return DEFAULT_PROFILE_LIMIT; + } + return Math.min(requestedLimit, MAX_PROFILE_LIMIT); + } + + private boolean supportsDistinctCount(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return !containsAny(normalizedType, "blob", "clob", "text", "ntext", "image", "json", "xml", "bytea"); + } + + private boolean supportsGroupedTopValues(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return !containsAny(normalizedType, "blob", "clob", "ntext", "image", "bytea"); + } + + private boolean supportsMinMax(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return containsAny(normalizedType, "int", "number", "numeric", "decimal", "double", "float", "real", "date", + "time", "year", "timestamp"); + } + + private String normalizeType(ColumnInfoBO column) { + return StringUtils.defaultString(column == null ? null : column.getType()).toLowerCase(Locale.ROOT); + } + + private String applyLimit(String sql, String dialectType, int limit) { + String trimmed = StringUtils.trimToEmpty(sql); + if (trimmed.isEmpty()) { + return trimmed; + } + String normalizedDialect = StringUtils.defaultString(dialectType).toLowerCase(Locale.ROOT); + if (normalizedDialect.contains("sqlserver") || normalizedDialect.contains("sql_server")) { + if (trimmed.matches("(?is)^select\\s+distinct\\b.*")) { + return trimmed.replaceFirst("(?is)^select\\s+distinct\\b", "SELECT DISTINCT TOP %d".formatted(limit)); + } + return trimmed.replaceFirst("(?is)^select\\b", "SELECT TOP %d".formatted(limit)); + } + if (normalizedDialect.contains("oracle")) { + return trimmed + " FETCH FIRST " + limit + " ROWS ONLY"; + } + return trimmed + " LIMIT " + limit; + } + + private String quoteTable(ProfileContext context, String tableName) { + return SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), tableName); + } + + private double roundRatio(double value) { + return Math.round(value * 10000D) / 10000D; + } + + private long parseLong(String value) { + if (StringUtils.isBlank(value)) { + return 0L; + } + try { + return Long.parseLong(value.trim()); + } + catch (NumberFormatException ex) { + try { + return Math.round(Double.parseDouble(value.trim())); + } + catch (NumberFormatException ignored) { + return 0L; + } + } + } + + private String resolveVisibleTableName(ProfileContext context, String tableName) { + return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) + .orElseThrow(() -> new IllegalArgumentException( + "Table '%s' is not visible for current agent. Visible tables: %s".formatted(tableName, + String.join(", ", context.visibleTables())))); + } + + private Optional findVisibleTableName(Map> visibleTablesByName, + Map> visibleTablesByLeafName, String tableName, boolean allowQualifiedFallback) { + String normalizedTableName = normalizeIdentifier(tableName); + List exactMatches = visibleTablesByName.getOrDefault(normalizedTableName, List.of()); + if (exactMatches.size() == 1) { + return Optional.of(exactMatches.get(0)); + } + if (exactMatches.size() > 1) { + throw new IllegalArgumentException( + "Table '%s' maps to multiple visible tables: %s".formatted(tableName, String.join(", ", exactMatches))); + } + if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { + return Optional.empty(); + } + List leafMatches = visibleTablesByLeafName.getOrDefault(normalizeTableLeafName(tableName), List.of()); + if (leafMatches.size() == 1) { + return Optional.of(leafMatches.get(0)); + } + if (leafMatches.size() > 1) { + throw new IllegalArgumentException("Table '%s' is ambiguous across visible tables: %s" + .formatted(tableName, String.join(", ", leafMatches))); + } + return Optional.empty(); + } + + private Map> buildVisibleColumnsByTable(AgentDatasource agentDatasource, + Map> visibleTablesByName, Map> visibleTablesByLeafName) { + Map> selectedColumns = Optional.ofNullable(agentDatasource.getSelectColumns()).orElse(Map.of()); + Map> visibleColumnsByTable = new LinkedHashMap<>(); + selectedColumns.forEach((tableName, columns) -> { + Optional resolvedTableName = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + tableName, true); + if (resolvedTableName.isEmpty()) { + return; + } + List sanitizedColumns = Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + if (!sanitizedColumns.isEmpty()) { + visibleColumnsByTable.put(normalizeTableName(resolvedTableName.get()), sanitizedColumns); + } + }); + return visibleColumnsByTable; + } + + private Map> indexTables(List tableNames, boolean leafOnly) { + Map> index = new LinkedHashMap<>(); + for (String tableName : Optional.ofNullable(tableNames).orElse(List.of())) { + if (StringUtils.isBlank(tableName)) { + continue; + } + String key = leafOnly ? normalizeTableLeafName(tableName) : normalizeTableName(tableName); + index.computeIfAbsent(key, ignored -> new ArrayList<>()).add(tableName); + } + return index; + } + + private boolean isColumnVisible(ProfileContext context, String tableName, String columnName) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return true; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + return visibleColumns != null && visibleColumns.contains(normalizeColumnName(columnName)); + } + + private boolean isQualifiedIdentifier(String value) { + return normalizeIdentifier(value).contains("."); + } + + private String normalizeIdentifier(String value) { + String normalized = StringUtils.trimToEmpty(value); + normalized = StringUtils.removeStart(normalized, "`"); + normalized = StringUtils.removeEnd(normalized, "`"); + normalized = StringUtils.removeStart(normalized, "\""); + normalized = StringUtils.removeEnd(normalized, "\""); + normalized = StringUtils.removeStart(normalized, "["); + normalized = StringUtils.removeEnd(normalized, "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeTableName(String tableName) { + return normalizeIdentifier(tableName); + } + + private String normalizeTableLeafName(String tableName) { + String normalized = normalizeIdentifier(tableName); + int lastDot = normalized.lastIndexOf('.'); + return lastDot >= 0 ? normalized.substring(lastDot + 1) : normalized; + } + + private String normalizeColumnName(String columnName) { + return normalizeTableLeafName(columnName); + } + private void evaluateAggregationRule(String query, String sql, QueryIntent intent, SqlShape shape, List problems, Set fixSuggestions, List ruleChecks) { if (!intent.requiresAggregation()) { @@ -698,14 +1188,20 @@ private Integer extractFirstInt(Pattern pattern, String value) { } } + private record ProfileContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, + Accessor accessor, List visibleTables, Map> visibleTablesByName, + Map> visibleTablesByLeafName, Map> visibleColumnsByTable, + Map> visibleColumnNameSetByTable, Set columnRestrictedTables) { + } + private record QueryIntent(boolean requiresAggregation, boolean requiresGrouping, boolean requiresTimeFilter, - boolean requiresOrdering, boolean requiresLimit, boolean requiresDistinct, boolean requiresTrend, - boolean prefersDescending, boolean prefersAscending, Integer expectedLimit) { + boolean requiresOrdering, boolean requiresLimit, boolean requiresDistinct, boolean requiresTrend, + boolean prefersDescending, boolean prefersAscending, Integer expectedLimit) { } private record SqlShape(List usedTables, List usedMetrics, boolean hasAggregation, boolean hasGroupBy, - boolean hasOrderBy, boolean hasLimit, boolean hasDistinct, boolean hasTimePredicate, boolean hasTimeBucket, - boolean orderDescending, boolean orderDirectionKnown, boolean limitValueKnown, Integer limitValue) { + boolean hasOrderBy, boolean hasLimit, boolean hasDistinct, boolean hasTimePredicate, boolean hasTimeBucket, + boolean orderDescending, boolean orderDirectionKnown, boolean limitValueKnown, Integer limitValue) { } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java index 96dfa13ae..cf941c030 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java @@ -24,10 +24,10 @@ import java.util.Comparator; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; import lombok.Getter; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; @@ -51,14 +51,14 @@ public class SessionTraceStore implements SpanExporter { private static final int MAX_SESSION_TRACES = 128; - private static final int MAX_ATTRIBUTE_VALUE_LENGTH = 256; + private static final int MAX_ATTRIBUTE_VALUE_LENGTH = 2048; private static final String META_OMITTED_ATTRIBUTE_COUNT = "_meta.omitted_attribute_count"; - private static final Set SAFE_ATTRIBUTE_KEYS = Set.of(ATTR_THREAD_ID, ATTR_RUNTIME_REQUEST_ID, ATTR_AGENT_ID, - "dataagent.runtime.human_feedback", "dataagent.runtime.nl2sql_only", "data_agent.agent_id", - "data_agent.thread_id", "data_agent.nl2sql_only", "data_agent.human_feedback", - "gen_ai.usage.prompt_tokens", "gen_ai.usage.completion_tokens", "gen_ai.usage.total_tokens", "error.type"); + private static final List SENSITIVE_ATTRIBUTE_KEY_TOKENS = List.of("password", "passwd", "pwd", "secret", + "api_key", "apikey", "api_token", "apitoken", "access_key", "access_token", "accesstoken", + "refresh_token", "refreshtoken", "id_token", "idtoken", "auth_token", "authtoken", "private_key", + "privatekey", "authorization", "cookie", "credential", "signature"); private final Object monitor = new Object(); @@ -167,32 +167,46 @@ private SpanView toSpanView(SpanData span) { private Map sanitizeAttributes(SpanData span) { Map attributes = new LinkedHashMap<>(); - int omittedAttributeCount = 0; span.getAttributes().forEach((key, value) -> { String attributeKey = key.getKey(); - if (!SAFE_ATTRIBUTE_KEYS.contains(attributeKey)) { - return; - } - attributes.put(attributeKey, sanitizeAttributeValue(String.valueOf(value))); + attributes.put(attributeKey, sanitizeAttributeValue(attributeKey, String.valueOf(value))); }); - omittedAttributeCount = Math.max(0, span.getTotalAttributeCount() - attributes.size()); + int omittedAttributeCount = Math.max(0, span.getTotalAttributeCount() - span.getAttributes().size()); if (omittedAttributeCount > 0) { attributes.put(META_OMITTED_ATTRIBUTE_COUNT, String.valueOf(omittedAttributeCount)); } return attributes; } - private String sanitizeAttributeValue(String value) { + private String sanitizeAttributeValue(String attributeKey, String value) { if (!StringUtils.hasText(value)) { return value; } String normalizedValue = value.replace('\r', ' ').replace('\n', ' ').trim(); + if (isSensitiveAttributeKey(attributeKey)) { + return maskAttributeValue(normalizedValue); + } if (normalizedValue.length() <= MAX_ATTRIBUTE_VALUE_LENGTH) { return normalizedValue; } return normalizedValue.substring(0, MAX_ATTRIBUTE_VALUE_LENGTH) + "..."; } + private boolean isSensitiveAttributeKey(String attributeKey) { + if (!StringUtils.hasText(attributeKey)) { + return false; + } + String normalizedKey = attributeKey.toLowerCase(Locale.ROOT).replace('-', '_').replace('.', '_'); + return SENSITIVE_ATTRIBUTE_KEY_TOKENS.stream().anyMatch(normalizedKey::contains); + } + + private String maskAttributeValue(String value) { + if (value.length() <= 8) { + return "***"; + } + return value.substring(0, 4) + "..." + value.substring(value.length() - 4); + } + private static final class TraceAssembly { private String sessionId; diff --git a/data-agent-management/src/main/resources/prompts/commonagent.md b/data-agent-management/src/main/resources/prompts/commonagent.md index df781219e..4da3a6e60 100644 --- a/data-agent-management/src/main/resources/prompts/commonagent.md +++ b/data-agent-management/src/main/resources/prompts/commonagent.md @@ -2,19 +2,33 @@ 1. 只要问题属于数据库物理结构探索,就优先使用 datasource explorer 工具。 包括:找表、看列、看字段类型、看数据预览、执行只读 SQL、查看物理表关系。 + 2. 只有当数据库本身不能直接表达某些表或列的补充语义时,才使用 `semantic_model.search`。 包括:别名、业务友好名称、枚举含义、字段说明、使用备注、补充性关系提示。 + 3. 如果问题属于业务定义、指标口径、SOP、FAQ、历史案例、领域术语,而不是表结构本身,就使用 `domain_business_knowledge.search`。 + 4. 如果用户问的是表名、列名、字段类型、枚举值、表关系、字段关系,不要先调用 `domain_business_knowledge.search`。 先检查 datasource explorer;如果 datasource explorer 还不够,再考虑 `semantic_model.search`。 -5. 如果你已经准备了候选 SQL,且答案将基于 SQL 返回给用户,在执行 SQL 前先调用一次 `sql_guard.check`。 - 必须传顶层字段:`query`、`sql`;可选再传 `tableSchemas`、`semanticHits`、`businessKnowledgeHits`。 -6. `sql_guard.check` 只做结构与意图校验,不负责自动修复、不负责执行报错修复、也不负责结果回看。 - 重点校验:指标是否对题、是否缺少 `GROUP BY`、时间窗口是否完整、排序 / TopN 是否正确、是否遗漏 `DISTINCT`。 -7. 读取 `sql_guard.check` 结果时,直接看顶层字段:`isAligned`、`problems`、`ruleChecks`、`fixSuggestions`、`summary`。 - `problems` 里会给出为什么错、期望是什么、实际检测到什么、建议怎么修;`ruleChecks` 用来解释这次到底检查了哪些规则、每条规则是通过还是失败。 -8. 如果 `sql_guard.check` 返回 `isAligned=false`,必须根据 `problems` 和 `fixSuggestions` 自己改写 SQL,然后把新的候选 SQL 再次传给 `sql_guard.check`。 + +5. 如果字段语义不稳、你不确定某列是不是枚举列/状态列/时间列/数值列,先调用 `sql_guard.check`,传 `action=DATA_PROFILE`。 + 必传:`tableName`。 + 可选:`columnNames`、`limit`。 + 重点读取:`columnProfiles`、`totalRows`、`summary`。 + +6. 如果你已经准备了候选 SQL,且答案将基于 SQL 返回给用户,在执行 SQL 前先调用 `sql_guard.check`,传 `action=SQL_VERIFY`。 + 必传:`query`、`sql`。 + 可选:`tableSchemas`、`semanticHits`、`businessKnowledgeHits`。 + +7. `sql_guard.check` 是统一 SQL 工具。 + `action=SQL_VERIFY`:只做结构与意图校验,不负责自动修复、不负责执行报错修复、也不负责结果回看。 + `action=DATA_PROFILE`:只做字段值域分析,帮助你决定过滤条件、GROUP BY、时间窗口和指标写法。 + +8. 读取 `action=SQL_VERIFY` 结果时,直接看顶层字段:`isAligned`、`problems`、`ruleChecks`、`fixSuggestions`、`summary`。 + `problems` 会说明为什么错、期望是什么、实际检测到什么、建议怎么修。 + +9. 如果 `action=SQL_VERIFY` 返回 `isAligned=false`,必须根据 `problems` 和 `fixSuggestions` 自己改写 SQL,然后把新的候选 SQL 再次传给 `sql_guard.check`。 不要把上一次 `sql_guard.check` 的输出对象原样回传给工具;每次都要重新传顶层 `query` 和新的 `sql`。 -9. 只有当 `sql_guard.check` 返回 `isAligned=true` 后,才能执行 datasource explorer 的 `SEARCH`。 -10. 如果 SQL 执行报错,或者执行结果看起来不合理,由 agent 根据数据库报错或结果样例自行分析并重写 SQL,再重新走 `sql_guard.check`。 - 不要调用额外的 SQL 自动修复工具。 + +10. 只有当 `action=SQL_VERIFY` 返回 `isAligned=true` 后,才能执行 datasource explorer 的 `SEARCH`。 + 如果 SQL 执行报错,或者执行结果看起来不合理,由 agent 根据数据库报错或结果样例自行分析并重写 SQL,再重新走 `sql_guard.check`。 From a94902b789fbdba0eb3d6ea4789cf4fe4f139421 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Fri, 24 Apr 2026 23:22:09 +0800 Subject: [PATCH 07/22] feat: enhance trace and add execution analysis --- .claude/code-review.md | 328 +++ .../src/components/agent/DataSourceConfig.vue | 76 +- .../src/components/run/HumanFeedback.vue | 133 -- data-agent-frontend/src/services/chat.ts | 56 + data-agent-frontend/src/services/graph.ts | 11 +- data-agent-frontend/src/services/resultSet.ts | 8 - .../src/services/sessionStateManager.ts | 119 +- data-agent-frontend/src/views/AgentRun.vue | 2085 +++++++++-------- .../runtime/AgentRuntimeExtensionFactory.java | 2 +- .../runtime/AgentRuntimeRequestMetadata.java | 3 +- .../SpringToolCallbackAgentAdapter.java | 11 + .../runtime/ToolContextRequestResolver.java | 72 + .../impl/AiAgentRuntimeServiceImpl.java | 75 +- .../datasource/DatasourceExplorerResult.java | 14 + .../datasource/DatasourceExplorerService.java | 230 +- .../DatasourceExplorerToolProvider.java | 35 +- .../DomainBusinessKnowledgeToolSupport.java | 15 +- .../tool/semantic/SemanticModelSearchHit.java | 2 + .../semantic/SemanticModelSearchService.java | 26 +- .../semantic/SemanticModelToolSupport.java | 15 +- .../dataagent/controller/ChatController.java | 41 + .../controller/DataAgentController.java | 5 +- .../dataagent/mapper/ChatMessageMapper.java | 11 +- .../AnswerTraceExplainStore.java | 503 ++++ .../service/chat/ChatMessageService.java | 5 + .../service/chat/ChatMessageServiceImpl.java | 5 + .../DomainKnowledgeSearchService.java | 5 + .../DomainKnowledgeSearchServiceImpl.java | 22 +- 28 files changed, 2734 insertions(+), 1179 deletions(-) create mode 100644 .claude/code-review.md delete mode 100644 data-agent-frontend/src/components/run/HumanFeedback.vue create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java diff --git a/.claude/code-review.md b/.claude/code-review.md new file mode 100644 index 000000000..e56073d9f --- /dev/null +++ b/.claude/code-review.md @@ -0,0 +1,328 @@ +# 代码审查报告 + +## 📋 概览 + +本次提交主要涉及以下改动: +1. **前端**:Trace 对话框优化、数据来源持久化、UI 样式改进 +2. **后端**:工具调用增强、trace 数据结构优化 + +--- + +## ✅ 优点 + +### 1. Trace 对话框优化 +- ✅ **树形展示**:增加了缩进量(32px),层级关系更清晰 +- ✅ **视觉优化**:添加了渐变背景、阴影效果、hover 动画 +- ✅ **消息去重**:实现了智能去重逻辑,避免重复显示工具调用 +- ✅ **截断处理**:对被截断的 JSON 数据进行了容错处理 + +### 2. 数据来源持久化 +- ✅ **sessionStorage 持久化**:刷新页面后数据来源可以恢复 +- ✅ **会话隔离**:每个会话的状态独立保存 +- ✅ **状态同步**:切换会话时正确保存和恢复状态 + +### 3. 代码质量 +- ✅ **类型安全**:使用 TypeScript 接口定义状态结构 +- ✅ **错误处理**:添加了 try-catch 保护 +- ✅ **注释完善**:关键函数都有中文注释 + +--- + +## ⚠️ 问题与建议 + +### 🔴 严重问题 + +#### 1. **parsedTraceConversations 逻辑过于简化** +**位置**:`AgentRun.vue:2461-2530` + +**问题**: +```typescript +// 只处理 agentscope.function.input 和 agentscope.function.output +const inputEntry = row.attributeEntries.find(e => e.key === 'agentscope.function.input'); +const outputEntry = row.attributeEntries.find(e => e.key === 'agentscope.function.output'); + +if (inputEntry && outputEntry) { + // ... 创建消息 + return [{ ... }]; +} + +// 如果没有找到 agentscope.function 属性,返回空 +return []; +``` + +**影响**: +- ❌ 如果 span 中没有 `agentscope.function.input/output`,将不显示任何消息 +- ❌ 其他类型的消息(如 `gen_ai.*`、`tool.*` 等)会被完全忽略 +- ❌ 用户可能看到空白的 trace 详情面板 + +**建议**: +```typescript +// 优先处理 agentscope.function,但如果不存在则回退到默认逻辑 +if (inputEntry && outputEntry) { + // ... 处理 agentscope.function + return [{ ... }]; +} + +// 回退到默认逻辑 +const keyword = normalizedTraceSearchKeyword.value; +return dedupeTraceConversationGroups( + row.attributeEntries.flatMap(entry => extractTraceConversationGroupsFromEntry(entry)), +) + .map(group => { ... }) + .filter(group => group.messages.length > 0); +``` + +--- + +#### 2. **删除了人工反馈功能但没有清理相关代码** +**位置**:`AgentRun.vue` + +**问题**: +- ❌ 删除了 `HumanFeedback` 组件的引用和使用 +- ❌ 但保留了 `handleNl2sqlOnlyChange` 函数(已无用) +- ❌ 保留了 `showHumanFeedback`、`lastRequest` 等变量声明(未使用) + +**建议**: +完全清理人工反馈相关代码: +```typescript +// 删除这些未使用的变量和函数 +const handleNl2sqlOnlyChange = (value: boolean) => { ... }; // 删除 +const showHumanFeedback = ref(false); // 删除 +const lastRequest = ref(null); // 删除 +``` + +--- + +### 🟡 中等问题 + +#### 3. **sessionStorage 可能超出存储限制** +**位置**:`sessionStateManager.ts:66-82` + +**问题**: +```typescript +function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { + try { + const persistable: PersistableState = { + nodeBlocks: state.nodeBlocks, // 可能很大 + answerExplain: state.answerExplain, // 可能很大 + // ... + }; + sessionStorage.setItem(key, JSON.stringify(persistable)); + } catch (error) { + console.error('保存会话状态失败:', error); + } +} +``` + +**影响**: +- ⚠️ `nodeBlocks` 和 `answerExplain` 可能包含大量数据 +- ⚠️ sessionStorage 通常限制为 5-10MB +- ⚠️ 超出限制时会静默失败,用户不知道状态未保存 + +**建议**: +1. 添加存储大小检查 +2. 只保存必要的字段 +3. 对大数据进行压缩或截断 +```typescript +function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { + try { + const persistable: PersistableState = { + // 只保存最后 10 个 nodeBlocks + nodeBlocks: state.nodeBlocks.slice(-10), + answerExplain: state.answerExplain, + // ... + }; + + const json = JSON.stringify(persistable); + const sizeInMB = new Blob([json]).size / (1024 * 1024); + + if (sizeInMB > 4) { + console.warn(`会话状态过大 (${sizeInMB.toFixed(2)}MB),跳过保存`); + return; + } + + sessionStorage.setItem(key, json); + } catch (error) { + console.error('保存会话状态失败:', error); + // 提示用户 + ElMessage.warning('会话状态保存失败,刷新页面后可能丢失部分数据'); + } +} +``` + +--- + +#### 4. **去重逻辑可能过于激进** +**位置**:`AgentRun.vue:2050-2060` + +**问题**: +```typescript +const getTraceMessageDedupFingerprint = (message: ParsedTraceMessage) => { + if (message.kind === 'tool-call' || message.kind === 'tool-result') { + // 不包含 title,只比较 content 和 details + return [ + message.kind, + stringifyTraceSemanticPayload(message.content), + stringifyTraceSemanticPayload(message.details), + ].join('|'); + } + return getTraceMessageFingerprint(message); +}; +``` + +**影响**: +- ⚠️ 如果两个不同的工具调用参数相同,会被认为是重复的 +- ⚠️ 例如:连续两次调用 `GET_TABLE_SCHEMA` 查询同一张表 + +**建议**: +考虑添加时间戳或调用 ID 到指纹中: +```typescript +const getTraceMessageDedupFingerprint = (message: ParsedTraceMessage) => { + if (message.kind === 'tool-call' || message.kind === 'tool-result') { + return [ + message.kind, + message.id, // 添加消息 ID + stringifyTraceSemanticPayload(message.content), + stringifyTraceSemanticPayload(message.details), + ].join('|'); + } + return getTraceMessageFingerprint(message); +}; +``` + +--- + +#### 5. **CSS 样式过于复杂** +**位置**:`AgentRun.vue` 样式部分 + +**问题**: +- ⚠️ 大量的渐变、阴影、动画效果 +- ⚠️ 可能影响性能,特别是在大量元素时 +- ⚠️ 维护成本高 + +**建议**: +1. 考虑使用 CSS 变量统一管理颜色和尺寸 +2. 减少不必要的渐变和阴影 +3. 使用 `will-change` 优化动画性能 + +```css +:root { + --trace-primary-color: #409eff; + --trace-border-color: #e8f1fa; + --trace-hover-shadow: 0 8px 24px rgba(64, 158, 255, 0.15); +} + +.trace-row { + border: 2px solid var(--trace-border-color); + transition: all 0.3s ease; + will-change: transform, box-shadow; +} + +.trace-row:hover { + box-shadow: var(--trace-hover-shadow); + transform: translateY(-2px); +} +``` + +--- + +### 🟢 轻微问题 + +#### 6. **缺少加载状态的用户反馈** +**位置**:`AgentRun.vue:1690-1713` + +**问题**: +```typescript +const loadAnswerExplainByRuntimeRequestId = async (runtimeRequestId: string) => { + answerExplainVisible.value = true; + answerExplainLoading.value = true; + // ... 加载数据 +}; +``` + +**建议**: +添加加载提示: +```typescript +const loadAnswerExplainByRuntimeRequestId = async (runtimeRequestId: string) => { + answerExplainVisible.value = true; + answerExplainLoading.value = true; + + const loadingMessage = ElMessage.info({ + message: '正在加载数据来源...', + duration: 0, + }); + + try { + // ... 加载数据 + } finally { + loadingMessage.close(); + answerExplainLoading.value = false; + } +}; +``` + +--- + +#### 7. **Magic Numbers** +**位置**:多处 + +**问题**: +```typescript +:style="{ paddingLeft: `${row.depth * 32 + 20}px` }" // 32 和 20 是什么? +if (outputEntry.value.length > 10000) { ... } // 10000 是什么? +``` + +**建议**: +使用常量: +```typescript +const TRACE_INDENT_SIZE = 32; +const TRACE_BASE_PADDING = 20; +const MAX_OUTPUT_LENGTH = 10000; + +:style="{ paddingLeft: `${row.depth * TRACE_INDENT_SIZE + TRACE_BASE_PADDING}px` }" +if (outputEntry.value.length > MAX_OUTPUT_LENGTH) { ... } +``` + +--- + +## 📊 统计 + +- **修改文件数**:24 个 +- **前端文件**:5 个 +- **后端文件**:19 个 +- **新增文件**:2 个 +- **删除文件**:1 个 + +--- + +## 🎯 总体评价 + +**评分**:7.5/10 + +**优点**: +- ✅ 功能实现完整,UI 优化明显 +- ✅ 代码结构清晰,类型安全 +- ✅ 持久化方案合理 + +**需要改进**: +- ❌ `parsedTraceConversations` 逻辑过于简化,需要添加回退逻辑 +- ⚠️ sessionStorage 存储大小需要控制 +- ⚠️ 清理未使用的代码 + +--- + +## 🔧 建议的修复优先级 + +1. **高优先级**:修复 `parsedTraceConversations` 的回退逻辑 +2. **中优先级**:添加 sessionStorage 大小限制 +3. **低优先级**:清理未使用的代码、优化 CSS + +--- + +## ✅ 可以提交吗? + +**建议**:⚠️ **修复高优先级问题后再提交** + +**理由**: +- `parsedTraceConversations` 的问题可能导致部分 trace 无法显示 +- 其他问题不影响核心功能,可以后续优化 diff --git a/data-agent-frontend/src/components/agent/DataSourceConfig.vue b/data-agent-frontend/src/components/agent/DataSourceConfig.vue index 83c65a980..fb5e4711d 100644 --- a/data-agent-frontend/src/components/agent/DataSourceConfig.vue +++ b/data-agent-frontend/src/components/agent/DataSourceConfig.vue @@ -591,7 +591,9 @@
未加载到字段信息。 @@ -624,7 +626,9 @@ > - >((result, [tableName, columns]) => { + selectedColumns.value[item.datasource.id] = Object.entries( + item.selectColumns || {}, + ).reduce>((result, [tableName, columns]) => { result[tableName] = [...columns]; return result; }, {}); @@ -1051,7 +1055,9 @@ } }; - const getAgentDatasourceByDatasourceId = (datasourceId: number): AgentDatasource | undefined => { + const getAgentDatasourceByDatasourceId = ( + datasourceId: number, + ): AgentDatasource | undefined => { return agentDatasourceList.value.find(item => item.datasource?.id === datasourceId); }; @@ -1071,13 +1077,12 @@ const nextSnapshot: AgentDatasource = { ...snapshot, selectTables: [...(snapshot.selectTables || [])], - selectColumns: Object.entries(snapshot.selectColumns || {}).reduce>( - (result, [tableName, columns]) => { - result[tableName] = [...columns]; - return result; - }, - {}, - ), + selectColumns: Object.entries(snapshot.selectColumns || {}).reduce< + Record + >((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}), }; const agentDatasourceIndex = agentDatasourceList.value.findIndex( @@ -1101,9 +1106,9 @@ } selectedTables.value[datasourceId] = [...(nextSnapshot.selectTables || [])]; - selectedColumns.value[datasourceId] = Object.entries(nextSnapshot.selectColumns || {}).reduce< - Record - >((result, [tableName, columns]) => { + selectedColumns.value[datasourceId] = Object.entries( + nextSnapshot.selectColumns || {}, + ).reduce>((result, [tableName, columns]) => { result[tableName] = [...columns]; return result; }, {}); @@ -1160,7 +1165,10 @@ return `${datasourceId}:${tableName}`; }; - const loadColumnsForTable = async (datasourceId: number, tableName: string): Promise => { + const loadColumnsForTable = async ( + datasourceId: number, + tableName: string, + ): Promise => { const loadingKey = getColumnLoadingKey(datasourceId, tableName); columnLoadingStates.value[loadingKey] = true; try { @@ -1550,10 +1558,13 @@ updateLoadingStates.value[datasource.id] = true; try { - const response = await agentDatasourceService.updateDatasourceTables(String(props.agentId), { - datasourceId: datasource.id, - tables: selectedTables.value[datasource.id] || [], - }); + const response = await agentDatasourceService.updateDatasourceTables( + String(props.agentId), + { + datasourceId: datasource.id, + tables: selectedTables.value[datasource.id] || [], + }, + ); if (response.success && response.data) { applyAgentDatasourceSnapshot(response.data); @@ -1613,12 +1624,18 @@ const agentDatasource = getAgentDatasourceByDatasourceId(datasourceRow.id); tables.forEach(tableName => { - const configuredColumns = resolveConfiguredColumns(agentDatasource?.selectColumns, tableName); + const configuredColumns = resolveConfiguredColumns( + agentDatasource?.selectColumns, + tableName, + ); selectedColumns.value[datasourceRow.id][tableName] = configuredColumns; - columnRestrictionEnabled.value[datasourceRow.id][tableName] = configuredColumns.length > 0; + columnRestrictionEnabled.value[datasourceRow.id][tableName] = + configuredColumns.length > 0; }); - await Promise.all(tables.map(tableName => loadColumnsForTable(datasourceRow.id!, tableName))); + await Promise.all( + tables.map(tableName => loadColumnsForTable(datasourceRow.id!, tableName)), + ); columnDialogVisible.value = true; }; @@ -1666,10 +1683,13 @@ columns: [...(selectedColumns.value[datasourceId]?.[tableName] || [])], })); - const response = await agentDatasourceService.updateDatasourceColumns(String(props.agentId), { - datasourceId, - tables, - }); + const response = await agentDatasourceService.updateDatasourceColumns( + String(props.agentId), + { + datasourceId, + tables, + }, + ); if (response.success && response.data) { applyAgentDatasourceSnapshot(response.data); diff --git a/data-agent-frontend/src/components/run/HumanFeedback.vue b/data-agent-frontend/src/components/run/HumanFeedback.vue deleted file mode 100644 index e8e081f2d..000000000 --- a/data-agent-frontend/src/components/run/HumanFeedback.vue +++ /dev/null @@ -1,133 +0,0 @@ - - - - - - - diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index b3bf8c14c..ff014c772 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -39,6 +39,55 @@ export interface ChatMessage { titleNeeded?: boolean; } +export interface AnswerTraceSemanticHit { + tableName?: string; + columnName?: string; + businessName?: string; + businessDescription?: string; + matchedBy?: string; + score?: number; + relationHint?: string; +} + +export interface AnswerTraceKnowledgeHit { + vectorType?: string; + knowledgeId?: string; + title?: string; + summary?: string; + snippet?: string; + source?: string; + concreteType?: string; +} + +export interface AnswerTraceToolStep { + toolName?: string; + title?: string; + summary?: string; + detail?: string; + datasource?: string; + timestampEpochMs?: number; +} + +export interface AnswerTraceExplain { + sessionId: string; + runtimeRequestId: string; + agentId?: string; + question?: string; + answer?: string; + datasource?: string; + sql?: string; + sqlExplanation?: string; + semanticHits: AnswerTraceSemanticHit[]; + knowledgeHits: AnswerTraceKnowledgeHit[]; + toolSteps: AnswerTraceToolStep[]; + usedTables: string[]; + usedColumns: string[]; + permissions: Record; + stats: Record; + warnings: string[]; + updatedAt: number; +} + export interface TraceSpan { name: string; spanId: string; @@ -121,6 +170,13 @@ class ChatService { return response.data; } + async getAnswerExplain(sessionId: string, runtimeRequestId: string): Promise { + const response = await axios.get( + `${API_BASE_URL}/sessions/${sessionId}/answers/${runtimeRequestId}/explain`, + ); + return response.data; + } + /** * 保存消息到会话 * @param sessionId 会话ID diff --git a/data-agent-frontend/src/services/graph.ts b/data-agent-frontend/src/services/graph.ts index 42b529b40..89d7e1802 100644 --- a/data-agent-frontend/src/services/graph.ts +++ b/data-agent-frontend/src/services/graph.ts @@ -17,9 +17,8 @@ export interface GraphRequest { agentId: string; threadId?: string; + runtimeRequestId?: string; query: string; - humanFeedback: boolean; - humanFeedbackContent?: string; rejectedPlan: boolean; nl2sqlOnly: boolean; } @@ -67,15 +66,13 @@ class GraphService { if (request.threadId) { params.append('threadId', request.threadId); } + if (request.runtimeRequestId) { + params.append('runtimeRequestId', request.runtimeRequestId); + } params.append('query', request.query); - params.append('humanFeedback', request.humanFeedback.toString()); params.append('rejectedPlan', request.rejectedPlan.toString()); params.append('nl2sqlOnly', request.nl2sqlOnly.toString()); - if (request.humanFeedbackContent) { - params.append('humanFeedbackContent', request.humanFeedbackContent); - } - const url = `${API_BASE_URL}/stream/search?${params.toString()}`; const eventSource = new EventSource(url); diff --git a/data-agent-frontend/src/services/resultSet.ts b/data-agent-frontend/src/services/resultSet.ts index 461833ecd..3e13d6102 100644 --- a/data-agent-frontend/src/services/resultSet.ts +++ b/data-agent-frontend/src/services/resultSet.ts @@ -41,11 +41,3 @@ export interface PaginationConfig { pageSize: number; total: number; } - -/** - * 结果集显示配置 - */ -export interface ResultSetDisplayConfig { - showSqlResults: boolean; - pageSize: number; -} diff --git a/data-agent-frontend/src/services/sessionStateManager.ts b/data-agent-frontend/src/services/sessionStateManager.ts index 6accac334..d287dac7d 100644 --- a/data-agent-frontend/src/services/sessionStateManager.ts +++ b/data-agent-frontend/src/services/sessionStateManager.ts @@ -16,6 +16,7 @@ import { ref, Ref } from 'vue'; import { GraphNodeResponse, GraphRequest } from '@/services/graph.ts'; +import { AnswerTraceExplain } from '@/services/chat.ts'; export interface SessionRuntimeState { isStreaming: boolean; @@ -26,6 +27,86 @@ export interface SessionRuntimeState { htmlReportContent: string; htmlReportSize: number; markdownReportContent: string; + answerExplain: AnswerTraceExplain | null; + answerExplainVisible: boolean; +} + +// 可持久化的状态字段(不包括函数和临时状态) +interface PersistableState { + nodeBlocks: GraphNodeResponse[][]; + persistedBlockCount: number; + lastRequest: GraphRequest | null; + htmlReportContent: string; + htmlReportSize: number; + markdownReportContent: string; + answerExplain: AnswerTraceExplain | null; + answerExplainVisible: boolean; +} + +const STORAGE_KEY_PREFIX = 'session_state_'; +const MAX_STORAGE_SIZE_MB = 4; // 最大存储 4MB +const MAX_NODE_BLOCKS = 10; // 最多保存 10 个 nodeBlocks + +/** + * 从 sessionStorage 加载状态 + */ +function loadStateFromStorage(sessionId: string): Partial | null { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + const stored = sessionStorage.getItem(key); + if (stored) { + return JSON.parse(stored); + } + } catch (error) { + console.error('加载会话状态失败:', error); + } + return null; +} + +/** + * 保存状态到 sessionStorage(带大小限制) + */ +function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + + // 只保存最近的 nodeBlocks,避免数据过大 + const persistable: PersistableState = { + nodeBlocks: state.nodeBlocks.slice(-MAX_NODE_BLOCKS), + persistedBlockCount: state.persistedBlockCount, + lastRequest: state.lastRequest, + htmlReportContent: state.htmlReportContent, + htmlReportSize: state.htmlReportSize, + markdownReportContent: state.markdownReportContent, + answerExplain: state.answerExplain, + answerExplainVisible: state.answerExplainVisible, + }; + + const json = JSON.stringify(persistable); + const sizeInMB = new Blob([json]).size / (1024 * 1024); + + // 检查大小限制 + if (sizeInMB > MAX_STORAGE_SIZE_MB) { + console.warn(`会话状态过大 (${sizeInMB.toFixed(2)}MB),跳过保存`); + return; + } + + sessionStorage.setItem(key, json); + } catch (error) { + console.error('保存会话状态失败:', error); + } +} + +/** + * 从 sessionStorage 删除状态 + */ +function removeStateFromStorage(sessionId: string) { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + sessionStorage.removeItem(key); + } catch (error) { + console.error('删除会话状态失败:', error); + } } /** @@ -40,15 +121,20 @@ export function useSessionStateManager() { */ const getSessionState = (sessionId: string): SessionRuntimeState => { if (!sessionStates.value.has(sessionId)) { + // 尝试从 sessionStorage 加载 + const stored = loadStateFromStorage(sessionId); + sessionStates.value.set(sessionId, { isStreaming: false, - nodeBlocks: [], - persistedBlockCount: 0, + nodeBlocks: stored?.nodeBlocks ?? [], + persistedBlockCount: stored?.persistedBlockCount ?? 0, closeStream: null, - lastRequest: null, - htmlReportContent: '', - htmlReportSize: 0, - markdownReportContent: '', + lastRequest: stored?.lastRequest ?? null, + htmlReportContent: stored?.htmlReportContent ?? '', + htmlReportSize: stored?.htmlReportSize ?? 0, + markdownReportContent: stored?.markdownReportContent ?? '', + answerExplain: stored?.answerExplain ?? null, + answerExplainVisible: stored?.answerExplainVisible ?? false, }); } return sessionStates.value.get(sessionId)!; @@ -62,11 +148,19 @@ export function useSessionStateManager() { viewState: { isStreaming: Ref; nodeBlocks: Ref; + answerExplain?: Ref; + answerExplainVisible?: Ref; }, ) => { const state = getSessionState(sessionId); viewState.isStreaming.value = state.isStreaming; viewState.nodeBlocks.value = state.nodeBlocks; + if (viewState.answerExplain) { + viewState.answerExplain.value = state.answerExplain; + } + if (viewState.answerExplainVisible) { + viewState.answerExplainVisible.value = state.answerExplainVisible; + } }; /** @@ -77,11 +171,22 @@ export function useSessionStateManager() { viewState: { isStreaming: Ref; nodeBlocks: Ref; + answerExplain?: Ref; + answerExplainVisible?: Ref; }, ) => { const state = getSessionState(sessionId); state.isStreaming = viewState.isStreaming.value; state.nodeBlocks = viewState.nodeBlocks.value; + if (viewState.answerExplain) { + state.answerExplain = viewState.answerExplain.value; + } + if (viewState.answerExplainVisible) { + state.answerExplainVisible = viewState.answerExplainVisible.value; + } + + // 保存到 sessionStorage(带大小限制) + saveStateToStorage(sessionId, state); }; /** @@ -93,6 +198,8 @@ export function useSessionStateManager() { state.closeStream(); } sessionStates.value.delete(sessionId); + // 同时删除 sessionStorage 中的数据 + removeStateFromStorage(sessionId); }; /** diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 0ac684b0c..9226af92e 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -55,13 +55,15 @@ :class="message.messageType === 'text' ? ['message-container', message.role] : ''" > -
+
+
+
@@ -184,13 +186,6 @@
- - -
@@ -219,48 +214,11 @@ :onQuestionClick="handlePresetQuestionClick" />
-
- 人工反馈 - - - -
-
- 仅NL2SQL - -
-
- 自动Scroll - -
-
- 显示SQL结果 - - - -
每页数量 @@ -270,6 +228,16 @@
+
+ + 查看数据来源 + +
diff --git a/data-agent-frontend/src/components/run/ChatSessionSidebar.vue b/data-agent-frontend/src/components/run/ChatSessionSidebar.vue index fd1c1503f..432fe2172 100644 --- a/data-agent-frontend/src/components/run/ChatSessionSidebar.vue +++ b/data-agent-frontend/src/components/run/ChatSessionSidebar.vue @@ -278,7 +278,7 @@ } try { - await ChatService.renameSession(session.id, newTitle); + await ChatService.renameSession(session.id, Number(agentId.value), newTitle); session.title = newTitle; session.editing = false; ElMessage.success('会话标题已更新'); @@ -296,6 +296,30 @@ // 计算属性 const agentId = computed(() => route.params.id as string); + const parseAgentId = (value: unknown): number | null => { + if (typeof value === 'number' && Number.isFinite(value)) { + return value; + } + if (typeof value === 'string' && value.trim()) { + const parsed = Number(value); + return Number.isFinite(parsed) ? parsed : null; + } + return null; + }; + + const getRouteAgentId = (): number | null => { + const rawAgentId = route.params.id; + return parseAgentId(Array.isArray(rawAgentId) ? rawAgentId[0] : rawAgentId); + }; + + const requireRouteAgentId = (): number => { + const resolvedAgentId = getRouteAgentId(); + if (resolvedAgentId === null) { + throw new Error('智能体ID无效,请刷新后重试'); + } + return resolvedAgentId; + }; + // 方法 const goBack = () => { router.push(`/agent/${agentId.value}`); @@ -303,7 +327,7 @@ const loadSessions = async () => { try { - sessions.value = await ChatService.getAgentSessions(parseInt(agentId.value)); + sessions.value = await ChatService.getAgentSessions(requireRouteAgentId()); // 默认选择第一个会话或创建新会话 if (sessions.value.length > 0) { await props.handleSelectSession(sessions.value[0]); @@ -318,7 +342,7 @@ const createNewSession = async () => { try { - const newSession = await ChatService.createSession(parseInt(agentId.value), '新会话'); + const newSession = await ChatService.createSession(requireRouteAgentId(), '新会话'); sessions.value.unshift(newSession); await props.handleSelectSession(newSession); ElMessage.success('新会话创建成功'); @@ -330,7 +354,7 @@ const togglePinSession = async (session: ChatSession) => { try { - await ChatService.pinSession(session.id, !session.isPinned); + await ChatService.pinSession(session.id, requireRouteAgentId(), !session.isPinned); session.isPinned = !session.isPinned; ElMessage.success(session.isPinned ? '会话已置顶' : '会话已取消置顶'); } catch (error) { @@ -346,7 +370,7 @@ cancelButtonText: '取消', type: 'warning', }); - await ChatService.deleteSession(session.id); + await ChatService.deleteSession(session.id, requireRouteAgentId()); props.handleDeleteSessionState(session.id); sessions.value = sessions.value.filter((s: ChatSession) => s.id !== session.id); if (props.handleGetCurrentSession() == session) { @@ -368,7 +392,7 @@ cancelButtonText: '取消', type: 'warning', }); - await ChatService.clearAgentSessions(parseInt(agentId.value)); + await ChatService.clearAgentSessions(requireRouteAgentId()); sessions.value.forEach((session: ChatSession) => { props.handleDeleteSessionState(session.id); }); diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index ff014c772..de188d7c9 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -116,13 +116,24 @@ export interface SessionTrace { const API_BASE_URL = '/api'; +const resolveAgentId = (agentId: number | string): number => { + const resolvedAgentId = typeof agentId === 'number' ? agentId : Number(agentId); + if (!Number.isFinite(resolvedAgentId)) { + throw new Error('智能体ID无效,请刷新后重试'); + } + return resolvedAgentId; +}; + class ChatService { /** * 获取Agent的会话列表 * @param agentId Agent ID */ async getAgentSessions(agentId: number): Promise { - const response = await axios.get(`${API_BASE_URL}/agent/${agentId}/sessions`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get( + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, + ); return response.data; } @@ -133,13 +144,14 @@ class ChatService { * @param userId 用户ID */ async createSession(agentId: number, title?: string, userId?: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); const request = { title, userId, }; const response = await axios.post( - `${API_BASE_URL}/agent/${agentId}/sessions`, + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, request, ); return response.data; @@ -150,7 +162,10 @@ class ChatService { * @param agentId Agent ID */ async clearAgentSessions(agentId: number): Promise { - const response = await axios.delete(`${API_BASE_URL}/agent/${agentId}/sessions`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.delete( + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, + ); return response.data; } @@ -158,21 +173,32 @@ class ChatService { * 获取会话的消息列表 * @param sessionId 会话ID */ - async getSessionMessages(sessionId: string): Promise { + async getSessionMessages(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.get( `${API_BASE_URL}/sessions/${sessionId}/messages`, + { params: { agentId: resolvedAgentId } }, ); return response.data; } - async getSessionTrace(sessionId: string): Promise { - const response = await axios.get(`${API_BASE_URL}/sessions/${sessionId}/trace`); + async getSessionTrace(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get(`${API_BASE_URL}/sessions/${sessionId}/trace`, { + params: { agentId: resolvedAgentId }, + }); return response.data; } - async getAnswerExplain(sessionId: string, runtimeRequestId: string): Promise { + async getAnswerExplain( + sessionId: string, + runtimeRequestId: string, + agentId: number, + ): Promise { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.get( `${API_BASE_URL}/sessions/${sessionId}/answers/${runtimeRequestId}/explain`, + { params: { agentId: resolvedAgentId } }, ); return response.data; } @@ -182,8 +208,9 @@ class ChatService { * @param sessionId 会话ID * @param message 消息对象 */ - async saveMessage(sessionId: string, message: ChatMessage): Promise { + async saveMessage(sessionId: string, agentId: number, message: ChatMessage): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); // 设置会话ID const messageData = { ...message, @@ -193,6 +220,7 @@ class ChatService { const response = await axios.post( `${API_BASE_URL}/sessions/${sessionId}/messages`, messageData, + { params: { agentId: resolvedAgentId } }, ); return response.data; } catch (error) { @@ -208,13 +236,14 @@ class ChatService { * @param sessionId 会话ID * @param isPinned 是否置顶 */ - async pinSession(sessionId: string, isPinned: boolean): Promise { + async pinSession(sessionId: string, agentId: number, isPinned: boolean): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.put( `${API_BASE_URL}/sessions/${sessionId}/pin`, null, { - params: { isPinned }, + params: { agentId: resolvedAgentId, isPinned }, }, ); return response.data; @@ -234,8 +263,9 @@ class ChatService { * @param sessionId 会话ID * @param title 新标题 */ - async renameSession(sessionId: string, title: string): Promise { + async renameSession(sessionId: string, agentId: number, title: string): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); if (!title || title.trim().length === 0) { throw new Error('标题不能为空'); } @@ -244,7 +274,7 @@ class ChatService { `${API_BASE_URL}/sessions/${sessionId}/rename`, null, { - params: { title: title.trim() }, + params: { agentId: resolvedAgentId, title: title.trim() }, }, ); return response.data; @@ -263,9 +293,12 @@ class ChatService { * 删除单个会话 * @param sessionId 会话ID */ - async deleteSession(sessionId: string): Promise { + async deleteSession(sessionId: string, agentId: number): Promise { try { - const response = await axios.delete(`${API_BASE_URL}/sessions/${sessionId}`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.delete(`${API_BASE_URL}/sessions/${sessionId}`, { + params: { agentId: resolvedAgentId }, + }); return response.data; } catch (error) { if (axios.isAxiosError(error) && error.response?.status === 500) { @@ -280,12 +313,13 @@ class ChatService { * @param sessionId 会话ID * @param content 报告内容 */ - async downloadHtmlReport(sessionId: string, content: string): Promise { + async downloadHtmlReport(sessionId: string, agentId: number, content: string): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.post( `${API_BASE_URL}/sessions/${sessionId}/reports/html`, content, - { + { params: { agentId: resolvedAgentId }, responseType: 'blob', // 重要:设置响应类型为blob headers: { 'Content-Type': 'text/plain;charset=utf-8', // 明确设置内容类型和编码 diff --git a/data-agent-frontend/src/services/semanticModel.ts b/data-agent-frontend/src/services/semanticModel.ts index a19699d4a..4776daaf3 100644 --- a/data-agent-frontend/src/services/semanticModel.ts +++ b/data-agent-frontend/src/services/semanticModel.ts @@ -35,6 +35,7 @@ interface SemanticModel { interface SemanticModelAddDto { agentId: number; + datasourceId: number; tableName: string; columnName: string; businessName: string; @@ -44,6 +45,14 @@ interface SemanticModelAddDto { dataType: string; } +interface SemanticModelUpdateDto { + businessName: string; + synonyms: string; + businessDescription: string; + columnComment: string; + dataType: string; +} + interface SemanticModelImportItem { tableName: string; columnName: string; @@ -56,6 +65,7 @@ interface SemanticModelImportItem { interface SemanticModelBatchImportDTO { agentId: number; + datasourceId: number; items: SemanticModelImportItem[]; } @@ -68,6 +78,22 @@ interface BatchImportResult { const API_BASE_URL = '/api/semantic-model'; +const extractApiErrorMessage = (error: unknown, fallback: string): string => { + if (axios.isAxiosError(error)) { + const responseMessage = error.response?.data?.message; + if (typeof responseMessage === 'string' && responseMessage.trim()) { + return responseMessage; + } + if (typeof error.message === 'string' && error.message.trim()) { + return error.message; + } + } + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; +}; + class SemanticModelService { /** * 获取语义模型列表 @@ -104,8 +130,15 @@ class SemanticModelService { * @param model 语义模型 DTO 对象 */ async create(model: SemanticModelAddDto): Promise { - const response = await axios.post(API_BASE_URL, model); - return response.data.success; + try { + const response = await axios.post(API_BASE_URL, model); + if (response.data.success) { + return true; + } + throw new Error(response.data.message || '创建失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, '创建失败')); + } } /** @@ -113,15 +146,18 @@ class SemanticModelService { * @param id 语义模型 ID * @param model 语义模型对象 */ - async update(id: number, model: SemanticModel): Promise { + async update(id: number, model: SemanticModelUpdateDto): Promise { try { const response = await axios.put(`${API_BASE_URL}/${id}`, model); - return response.data.success; + if (response.data.success) { + return true; + } + throw new Error(response.data.message || '更新失败'); } catch (error) { if (axios.isAxiosError(error) && error.response?.status === 404) { return false; } - throw error; + throw new Error(extractApiErrorMessage(error, '更新失败')); } } @@ -175,11 +211,18 @@ class SemanticModelService { * @param dto 批量导入DTO */ async batchImport(dto: SemanticModelBatchImportDTO): Promise { - const response = await axios.post>( - `${API_BASE_URL}/batch-import`, - dto, - ); - return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + try { + const response = await axios.post>( + `${API_BASE_URL}/batch-import`, + dto, + ); + if (response.data.success) { + return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + } + throw new Error(response.data.message || '批量导入失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, '批量导入失败')); + } } /** @@ -187,21 +230,29 @@ class SemanticModelService { * @param file Excel文件 * @param agentId 智能体ID */ - async importExcel(file: File, agentId: number): Promise { + async importExcel(file: File, agentId: number, datasourceId: number): Promise { const formData = new FormData(); formData.append('file', file); formData.append('agentId', agentId.toString()); + formData.append('datasourceId', datasourceId.toString()); - const response = await axios.post>( - `${API_BASE_URL}/import/excel`, - formData, - { - headers: { - 'Content-Type': 'multipart/form-data', + try { + const response = await axios.post>( + `${API_BASE_URL}/import/excel`, + formData, + { + headers: { + 'Content-Type': 'multipart/form-data', + }, }, - }, - ); - return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + ); + if (response.data.success) { + return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + } + throw new Error(response.data.message || 'Excel导入失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, 'Excel导入失败')); + } } /** @@ -228,6 +279,7 @@ export default new SemanticModelService(); export type { SemanticModel, SemanticModelAddDto, + SemanticModelUpdateDto, SemanticModelImportItem, SemanticModelBatchImportDTO, BatchImportResult, diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 61071b3b9..ed47b13c7 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -980,9 +980,41 @@ const agentId = computed(() => route.params.id as string); + const parseAgentId = (value: unknown): number | null => { + if (typeof value === 'number' && Number.isFinite(value)) { + return value; + } + if (typeof value === 'string' && value.trim()) { + const parsed = Number(value); + return Number.isFinite(parsed) ? parsed : null; + } + return null; + }; + + const getRouteAgentId = (): number | null => { + const rawAgentId = route.params.id; + return parseAgentId(Array.isArray(rawAgentId) ? rawAgentId[0] : rawAgentId); + }; + + const getResolvedAgentId = (): number | null => + parseAgentId(currentSession.value?.agentId) ?? parseAgentId(agent.value?.id) ?? getRouteAgentId(); + + const requireResolvedAgentId = (): number => { + const resolvedAgentId = getResolvedAgentId(); + if (resolvedAgentId === null) { + throw new Error('智能体ID无效,请刷新后重试'); + } + return resolvedAgentId; + }; + const loadAgent = async () => { try { - const agentData = await AgentService.get(parseInt(agentId.value)); + const routeAgentId = getRouteAgentId(); + if (routeAgentId === null) { + ElMessage.error('智能体ID无效,请刷新后重试'); + return; + } + const agentData = await AgentService.get(routeAgentId); if (agentData) { agent.value = agentData; } else { @@ -1030,7 +1062,10 @@ answerExplainVisible, pendingClarify, }); - currentMessages.value = await ChatService.getSessionMessages(session.id); + currentMessages.value = await ChatService.getSessionMessages( + session.id, + requireResolvedAgentId(), + ); scrollToBottom(); } catch (error) { ElMessage.error('加载消息失败'); @@ -1090,12 +1125,16 @@ }; try { // 保存用户消息 - const savedMessage = await ChatService.saveMessage(currentSession.value.id, userMessage); + const savedMessage = await ChatService.saveMessage( + currentSession.value.id, + requireResolvedAgentId(), + userMessage, + ); currentMessages.value.push(savedMessage); getSessionState(currentSession.value.id); const request: GraphRequest = { - agentId: agentId.value, + agentId: String(requireResolvedAgentId()), query: requestQuery, humanFeedback: Boolean(activeClarify), humanFeedbackContent: feedbackContent, @@ -1165,7 +1204,7 @@ messageType: 'result-set', metadata: metadataJson, }; - await ChatService.saveMessage(sessionId, aiMessage); + await ChatService.saveMessage(sessionId, requireResolvedAgentId(), aiMessage); return; } } catch (error) { @@ -1181,7 +1220,7 @@ messageType: 'html', metadata: metadataJson, }; - await ChatService.saveMessage(sessionId, aiMessage); + await ChatService.saveMessage(sessionId, requireResolvedAgentId(), aiMessage); }; const sendGraphRequest = async (request: GraphRequest) => { @@ -1394,7 +1433,7 @@ messageType: 'html-report', }; - await ChatService.saveMessage(sessionId, htmlReportMessage) + await ChatService.saveMessage(sessionId, requireResolvedAgentId(), htmlReportMessage) .then(savedMessage => { if (currentSession.value?.id === sessionId) { currentMessages.value.push(savedMessage); @@ -1419,7 +1458,7 @@ messageType: 'markdown-report', }; - await ChatService.saveMessage(sessionId, markdownMessage) + await ChatService.saveMessage(sessionId, requireResolvedAgentId(), markdownMessage) .then(savedMessage => { if (currentSession.value?.id === sessionId) { currentMessages.value.push(savedMessage); @@ -1501,7 +1540,11 @@ return; } try { - await ChatService.downloadHtmlReport(currentSession.value.id, content); + await ChatService.downloadHtmlReport( + currentSession.value.id, + requireResolvedAgentId(), + content, + ); ElMessage.success('HTML报告下载成功'); } catch (error) { console.error('下载HTML报告失败:', error); @@ -1764,6 +1807,7 @@ answerExplain.value = await ChatService.getAnswerExplain( currentSession.value.id, runtimeRequestId, + requireResolvedAgentId(), ); // 保存到会话状态(包括 sessionStorage) saveViewToState(currentSession.value.id, { @@ -2135,7 +2179,10 @@ traceLoading.value = true; traceError.value = ''; try { - sessionTrace.value = await ChatService.getSessionTrace(currentSession.value.id); + sessionTrace.value = await ChatService.getSessionTrace( + currentSession.value.id, + requireResolvedAgentId(), + ); selectedTraceSpanId.value = sessionTrace.value.rootSpans?.[0]?.spanId ?? ''; } catch (error: any) { sessionTrace.value = null; @@ -2178,7 +2225,7 @@ // 如果没有会话,先创建新会话 if (!currentSession.value) { try { - const newSession = await ChatService.createSession(parseInt(agentId.value), '新会话'); + const newSession = await ChatService.createSession(requireResolvedAgentId(), '新会话'); currentSession.value = newSession; ElMessage.success('新会话创建成功'); } catch (error) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index 1be862a06..35958bd86 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -450,7 +450,8 @@ private Map loadTableDocumentMap(ExplorerContext context, List return Collections.emptyMap(); } try { - return schemaService.getTableDocuments(context.datasource().getId(), tableNames) + return schemaService.getTableDocuments(context.agentDatasource().getAgentId().toString(), + context.datasource().getId(), tableNames) .stream() .collect(Collectors.toMap(doc -> normalizeTableName(String.valueOf(doc.getMetadata().get("name"))), doc -> doc, (left, right) -> left, LinkedHashMap::new)); @@ -462,7 +463,8 @@ private Map loadTableDocumentMap(ExplorerContext context, List private Map loadColumnDocumentMap(ExplorerContext context, String tableName) { try { - return schemaService.getColumnDocumentsByTableName(context.datasource().getId(), List.of(tableName)) + return schemaService.getColumnDocumentsByTableName(context.agentDatasource().getAgentId().toString(), + context.datasource().getId(), List.of(tableName)) .stream() .collect(Collectors.toMap(doc -> normalizeColumnName(String.valueOf(doc.getMetadata().get("name"))), doc -> doc, (left, right) -> left, LinkedHashMap::new)); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index 47760d236..943e14e6a 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -44,7 +44,7 @@ public class DomainBusinessKnowledgeToolSupport { }, "knowledgeTypes": { "type": "array", - "description": "可选。限定知识范围。支持 businessTerm、agentKnowledge、document、qa、faq、all。", + "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。兼容旧别名 businessTerm。", "items": { "type": "string" } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java index acd846b8f..576d8bd83 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java @@ -20,10 +20,12 @@ import com.alibaba.cloud.ai.dataagent.enums.BizDataSourceTypeEnum; import com.alibaba.cloud.ai.dataagent.enums.ErrorCodeEnum; import com.alibaba.druid.pool.DruidDataSourceFactory; +import com.alibaba.druid.pool.DruidDataSource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import javax.sql.DataSource; +import java.sql.DriverManager; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -93,7 +95,13 @@ public DataSource createdDataSource(String url, String username, String password log.info("Creating Hive DataSource with custom configuration"); String driver = getDriver(); Map props = new HiveDruidProperties(driver, url, username, password, "stat").toMap(); - return DruidDataSourceFactory.createDataSource(props); + DruidDataSource dataSource = (DruidDataSource) DruidDataSourceFactory.createDataSource(props); + dataSource.setInitialSize(0); + dataSource.setMinIdle(0); + dataSource.setBreakAfterAcquireFailure(true); + dataSource.setConnectionErrorRetryAttempts(2); + dataSource.setTestWhileIdle(false); + return dataSource; } private static final class HiveDruidProperties { @@ -123,14 +131,14 @@ private Map toMap() { props.put(DruidDataSourceFactory.PROP_USERNAME, this.username); props.put(DruidDataSourceFactory.PROP_PASSWORD, this.password); props.put(DruidDataSourceFactory.PROP_FILTERS, this.filters); - props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "5"); - props.put(DruidDataSourceFactory.PROP_MINIDLE, "5"); + props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "0"); + props.put(DruidDataSourceFactory.PROP_MINIDLE, "0"); props.put(DruidDataSourceFactory.PROP_MAXACTIVE, "20"); props.put(DruidDataSourceFactory.PROP_MAXWAIT, "60000"); props.put(DruidDataSourceFactory.PROP_TIMEBETWEENEVICTIONRUNSMILLIS, "60000"); props.put(DruidDataSourceFactory.PROP_MINEVICTABLEIDLETIMEMILLIS, "300000"); props.put(DruidDataSourceFactory.PROP_VALIDATIONQUERY, "SELECT 1"); - props.put(DruidDataSourceFactory.PROP_TESTWHILEIDLE, "true"); + props.put(DruidDataSourceFactory.PROP_TESTWHILEIDLE, "false"); props.put(DruidDataSourceFactory.PROP_TESTONBORROW, "false"); props.put(DruidDataSourceFactory.PROP_TESTONRETURN, "false"); return props; @@ -141,7 +149,8 @@ private Map toMap() { @Override public ErrorCodeEnum ping(DbConfigBO config) { log.info("Hive ping method called, url: {}", config.getUrl()); - try (Connection connection = getConnection(config); Statement stmt = connection.createStatement()) { + try (Connection connection = DriverManager.getConnection(config.getUrl(), config.getUsername(), + config.getPassword()); Statement stmt = connection.createStatement()) { log.info("Hive connection obtained, executing SELECT 1"); ResultSet rs = stmt.executeQuery("SELECT 1"); if (rs.next()) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java index c720a1baa..4f6aba33f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java @@ -117,6 +117,7 @@ public Connection getConnection(DbConfigBO config) { log.warn("Attempt {} to get database connection failed: {}", attempt, e.getMessage()); if (attempt == maxRetries) { + evict(config); log.error("Failed to get database connection after {} attempts, URL: {}", maxRetries, jdbcUrl, e); throw new RuntimeException("Failed to get database connection after " + maxRetries + " attempts", e); @@ -141,15 +142,27 @@ public Connection getConnection(DbConfigBO config) { * @param password the database password * @return the cache key */ - private String generateCacheKey(String url, String username, String password) { + protected String generateCacheKey(String url, String username, String password) { return url + "|" + username + "|" + Objects.hashCode(password); } + @Override + public void evict(DbConfigBO config) { + if (config == null || config.getUrl() == null) { + return; + } + String cacheKey = generateCacheKey(config.getUrl(), config.getUsername(), config.getPassword()); + DataSource dataSource = DATA_SOURCE_CACHE.remove(cacheKey); + if (dataSource instanceof DruidDataSource druidDataSource) { + druidDataSource.close(); + } + } + @Override public void close() { DATA_SOURCE_CACHE.values().forEach(dataSource -> { - if (dataSource instanceof DruidDataSource) { - ((DruidDataSource) dataSource).close(); + if (dataSource instanceof DruidDataSource druidDataSource) { + druidDataSource.close(); } }); DATA_SOURCE_CACHE.clear(); @@ -175,20 +188,23 @@ public DataSource createdDataSource(String url, String username, String password props.put(DruidDataSourceFactory.PROP_URL, url); props.put(DruidDataSourceFactory.PROP_USERNAME, username); props.put(DruidDataSourceFactory.PROP_PASSWORD, password); - props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "5"); - props.put(DruidDataSourceFactory.PROP_MINIDLE, "5"); + props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "0"); + props.put(DruidDataSourceFactory.PROP_MINIDLE, "0"); props.put(DruidDataSourceFactory.PROP_MAXACTIVE, "20"); props.put(DruidDataSourceFactory.PROP_MAXWAIT, "10000"); props.put(DruidDataSourceFactory.PROP_TIMEBETWEENEVICTIONRUNSMILLIS, "60000"); props.put(DruidDataSourceFactory.PROP_FILTERS, filters); DruidDataSource dataSource = (DruidDataSource) DruidDataSourceFactory.createDataSource(props); + dataSource.setInitialSize(0); + dataSource.setMinIdle(0); dataSource.setBreakAfterAcquireFailure(Boolean.TRUE); dataSource.setConnectionErrorRetryAttempts(2); + dataSource.setTestWhileIdle(false); // 记录数据源创建信息 log.info( - "Created new DataSource with optimized parameters - InitialSize: 5, MinIdle: 5, MaxActive: 20, MaxWait: 10000ms"); + "Created new DataSource with optimized parameters - InitialSize: 0, MinIdle: 0, MaxActive: 20, MaxWait: 10000ms"); return dataSource; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java index 5876a275a..7380850b6 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java @@ -41,6 +41,13 @@ public interface DBConnectionPool extends AutoCloseable { */ Connection getConnection(DbConfigBO config); + /** + * Evict cached data source for the given configuration if present. + * @param config the database configuration + */ + default void evict(DbConfigBO config) { + } + boolean supportedDataSourceType(String type); String getConnectionPoolType(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java index 2dc335098..f16395e9d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java @@ -38,6 +38,7 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.*; +import org.springframework.web.server.ResponseStatusException; import java.util.List; import java.util.Map; @@ -125,13 +126,16 @@ public ResponseEntity clearAgentSessions(@PathVariable(value = "id" * Get message list for a session */ @GetMapping("/sessions/{sessionId}/messages") - public ResponseEntity> getSessionMessages(@PathVariable(value = "sessionId") String sessionId) { - List messages = chatMessageService.findVisibleBySessionId(sessionId); + public ResponseEntity> getSessionMessages(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + List messages = chatMessageService.findVisibleBySessionId(sessionId, agentId); return ResponseEntity.ok(messages); } @GetMapping("/sessions/{sessionId}/trace") - public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId") String sessionId) { + public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); return sessionTraceStore.getLatestTrace(sessionId) .>map(ResponseEntity::ok) .orElseGet(() -> ResponseEntity.notFound().build()); @@ -139,18 +143,22 @@ public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId" @GetMapping("/sessions/{sessionId}/answers/{runtimeRequestId}/explain") public ResponseEntity getAnswerExplain(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, @PathVariable(value = "runtimeRequestId") String runtimeRequestId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); return answerTraceExplainStore.getExplain(sessionId, runtimeRequestId) .>map(ResponseEntity::ok) - .or(() -> loadPersistedAnswerExplain(sessionId, runtimeRequestId).map(ResponseEntity::ok)) + .or(() -> loadPersistedAnswerExplain(sessionId, runtimeRequestId, agentId).map(ResponseEntity::ok)) .orElseGet(() -> ResponseEntity.notFound().build()); } - private java.util.Optional loadPersistedAnswerExplain(String sessionId, String runtimeRequestId) { + private java.util.Optional loadPersistedAnswerExplain(String sessionId, String runtimeRequestId, + Long agentId) { if (!StringUtils.hasText(sessionId) || !StringUtils.hasText(runtimeRequestId)) { return java.util.Optional.empty(); } - List snapshots = chatMessageService.findBySessionIdAndMessageType(sessionId, ANSWER_EXPLAIN_MESSAGE_TYPE); + List snapshots = chatMessageService.findBySessionIdAndMessageType(sessionId, + ANSWER_EXPLAIN_MESSAGE_TYPE, agentId); for (ChatMessage snapshot : snapshots) { if (snapshot == null || !StringUtils.hasText(snapshot.getContent())) { continue; @@ -174,6 +182,7 @@ private java.util.Optional loadPersistedAnswerExplain(String sessionId */ @PostMapping("/sessions/{sessionId}/messages") public ResponseEntity saveMessage(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, @RequestBody ChatMessageDTO request) { try { if (request == null) { @@ -187,10 +196,10 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" .metadata(request.getMetadata()) .build(); - ChatMessage savedMessage = chatMessageService.saveMessage(message); + ChatMessage savedMessage = chatMessageService.saveMessage(message, agentId); // Update session activity time - chatSessionService.updateSessionTime(sessionId); + chatSessionService.updateSessionTime(sessionId, agentId); if (shouldGenerateTitle(request, savedMessage)) { sessionTitleService.scheduleTitleGeneration(sessionId, message.getContent()); @@ -198,6 +207,9 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" return ResponseEntity.ok(savedMessage); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Save message error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().build(); @@ -209,12 +221,16 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" */ @PutMapping("/sessions/{sessionId}/pin") public ResponseEntity pinSession(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, @RequestParam(value = "isPinned") Boolean isPinned) { try { - chatSessionService.pinSession(sessionId, isPinned); + chatSessionService.pinSession(sessionId, isPinned, agentId); String message = isPinned ? "会话已置顶" : "会话已取消置顶"; return ResponseEntity.ok(ApiResponse.success(message)); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Pin session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("操作失败")); @@ -226,15 +242,19 @@ public ResponseEntity pinSession(@PathVariable(value = "sessionId") */ @PutMapping("/sessions/{sessionId}/rename") public ResponseEntity renameSession(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, @RequestParam(value = "title") String title) { try { if (!StringUtils.hasText(title)) { return ResponseEntity.badRequest().body(ApiResponse.error("标题不能为空")); } - chatSessionService.renameSession(sessionId, title.trim()); + chatSessionService.renameSession(sessionId, title.trim(), agentId); return ResponseEntity.ok(ApiResponse.success("会话已重命名")); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Rename session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("重命名失败")); @@ -245,11 +265,15 @@ public ResponseEntity renameSession(@PathVariable(value = "sessionI * Delete a single session */ @DeleteMapping("/sessions/{sessionId}") - public ResponseEntity deleteSession(@PathVariable(value = "sessionId") String sessionId) { + public ResponseEntity deleteSession(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { try { - chatSessionService.deleteSession(sessionId); + chatSessionService.deleteSession(sessionId, agentId); return ResponseEntity.ok(ApiResponse.success("会话已删除")); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Delete session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("删除失败")); @@ -261,11 +285,13 @@ public ResponseEntity deleteSession(@PathVariable(value = "sessionI */ @PostMapping("/sessions/{sessionId}/reports/html") public ResponseEntity convertAndDownloadHtml(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, @RequestBody String content) { try { if (!StringUtils.hasText(content)) { return ResponseEntity.badRequest().build(); } + chatSessionService.requireSessionForAgent(sessionId, agentId); log.debug("Download HTML report for session {}", sessionId); StringBuilder htmlContent = new StringBuilder(); htmlContent.append(reportTemplateUtil.getHeader()); @@ -278,6 +304,9 @@ public ResponseEntity convertAndDownloadHtml(@PathVariable(value = "sess headers.setContentDispositionFormData("attachment", filename); return ResponseEntity.ok().headers(headers).body(htmlContent.toString().getBytes(StandardCharsets.UTF_8)); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Download HTML report error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().build(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java index 4ce636377..ad4785497 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java @@ -17,7 +17,9 @@ import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelAddDTO; import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelBatchImportDTO; +import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelUpdateDTO; import com.alibaba.cloud.ai.dataagent.entity.SemanticModel; +import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.service.semantic.SemanticModelService; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; import com.alibaba.cloud.ai.dataagent.vo.BatchImportResult; @@ -83,23 +85,26 @@ public ApiResponse get(@PathVariable(value = "id") Long id) { @PostMapping public ApiResponse create(@RequestBody @Validated SemanticModelAddDTO semanticModelAddDto) { - boolean success = semanticModelService.addSemanticModel(semanticModelAddDto); - if (success) { - return ApiResponse.success("Semantic model created successfully", true); - } - else { + try { + boolean success = semanticModelService.addSemanticModel(semanticModelAddDto); + if (success) { + return ApiResponse.success("Semantic model created successfully", true); + } return ApiResponse.error("Failed to create semantic model"); } + catch (IllegalArgumentException e) { + throw new InvalidInputException(e.getMessage()); + } } @PutMapping("/{id}") - public ApiResponse update(@PathVariable(value = "id") Long id, @RequestBody SemanticModel model) { + public ApiResponse update(@PathVariable(value = "id") Long id, + @RequestBody @Validated SemanticModelUpdateDTO semanticModelUpdateDto) { if (semanticModelService.getById(id) == null) { return ApiResponse.error("Semantic model not found"); } - model.setId(id); - semanticModelService.updateSemanticModel(id, model); - return ApiResponse.success("Semantic model updated successfully", model); + semanticModelService.updateSemanticModel(id, semanticModelUpdateDto); + return ApiResponse.success("Semantic model updated successfully", semanticModelService.getById(id)); } @DeleteMapping("/{id}") @@ -172,8 +177,9 @@ public ResponseEntity downloadTemplate() { @PostMapping(value = "/import/excel", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) public Mono> importExcel(@RequestPart("file") FilePart file, - @RequestPart("agentId") String agentId) { + @RequestPart("agentId") String agentId, @RequestPart("datasourceId") String datasourceId) { Long agentIdLong = Long.parseLong(agentId); + Integer datasourceIdInt = Integer.parseInt(datasourceId); String filename = file.filename(); return DataBufferUtils.join(file.content()).flatMap(dataBuffer -> { @@ -183,7 +189,8 @@ public Mono> importExcel(@RequestPart("file") Fil return Mono.fromCallable(() -> { try (InputStream inputStream = new ByteArrayInputStream(bytes)) { - BatchImportResult result = semanticModelService.importFromExcel(inputStream, filename, agentIdLong); + BatchImportResult result = semanticModelService.importFromExcel(inputStream, filename, agentIdLong, + datasourceIdInt); return ApiResponse.success("Excel导入完成", result); } }).subscribeOn(Schedulers.boundedElastic()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java index 6d4a95038..cde7e64fa 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java @@ -26,6 +26,8 @@ public class SchemaInitRequest implements Serializable { private DbConfigBO dbConfig; + private Long agentId; + private List tables; private Map> visibleColumnsByTable; @@ -38,6 +40,14 @@ public void setDbConfig(DbConfigBO dbConfig) { this.dbConfig = dbConfig; } + public Long getAgentId() { + return agentId; + } + + public void setAgentId(Long agentId) { + this.agentId = agentId; + } + public List getTables() { return tables; } @@ -56,8 +66,8 @@ public void setVisibleColumnsByTable(Map> visibleColumnsByT @Override public String toString() { - return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", tables=" + tables + ", visibleColumnsByTable=" - + visibleColumnsByTable + '}'; + return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", agentId=" + agentId + ", tables=" + tables + + ", visibleColumnsByTable=" + visibleColumnsByTable + '}'; } @Override @@ -67,13 +77,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; SchemaInitRequest that = (SchemaInitRequest) o; - return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(tables, that.tables) + return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(agentId, that.agentId) + && Objects.equals(tables, that.tables) && Objects.equals(visibleColumnsByTable, that.visibleColumnsByTable); } @Override public int hashCode() { - return Objects.hash(dbConfig, tables, visibleColumnsByTable); + return Objects.hash(dbConfig, agentId, tables, visibleColumnsByTable); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java index 44bfdfda9..2a9a42136 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java @@ -33,6 +33,9 @@ public class SemanticModelAddDTO { @NotNull(message = "智能体ID不能为空") private Long agentId; + @NotNull(message = "数据源ID不能为空") + private Integer datasourceId; + /** 关联的表名 */ @NotBlank(message = "表名不能为空") private String tableName; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java index 5d7d083aa..b96041ebc 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java @@ -37,6 +37,9 @@ public class SemanticModelBatchImportDTO { @NotNull(message = "智能体ID不能为空") private Long agentId; + @NotNull(message = "数据源ID不能为空") + private Integer datasourceId; + @NotEmpty(message = "导入数据不能为空") @Valid private List items; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java new file mode 100644 index 000000000..c2447f963 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.dto.schema; + +import jakarta.validation.constraints.NotBlank; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** 语义模型更新 DTO,仅允许编辑业务解释字段。 */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class SemanticModelUpdateDTO { + + @NotBlank(message = "业务名称不能为空") + private String businessName; + + private String synonyms; + + private String businessDescription; + + private String columnComment; + + @NotBlank(message = "数据类型不能为空") + private String dataType; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java index f6cdb0b41..c93e68967 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java @@ -47,6 +47,10 @@ public interface AgentDatasourceMapper { AgentDatasource selectByAgentIdAndDatasourceId(@Param("agentId") Long agentId, @Param("datasourceId") Integer datasourceId); + /** Query associations by datasource ID */ + @Select("SELECT * FROM agent_datasource WHERE datasource_id = #{datasourceId} ORDER BY create_time DESC") + List selectByDatasourceId(@Param("datasourceId") Integer datasourceId); + /** Disable all data sources for an agent */ @Update("UPDATE agent_datasource SET is_active = 0 WHERE agent_id = #{agentId}") int disableAllByAgentId(@Param("agentId") Long agentId); @@ -59,6 +63,9 @@ AgentDatasource selectByAgentIdAndDatasourceId(@Param("agentId") Long agentId, int countActiveByAgentIdExcluding(@Param("agentId") Long agentId, @Param("excludeDatasourceId") Integer excludeDatasourceId); + @Select("SELECT COUNT(*) FROM agent_datasource WHERE agent_id = #{agentId} AND is_active = 1") + int countActiveByAgentId(@Param("agentId") Long agentId); + @Delete("DELETE FROM agent_datasource WHERE datasource_id = #{datasourceId}") int deleteAllByDatasourceId(@Param("datasourceId") Integer datasourceId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java index 4da15fb5f..6cb4e1407 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java @@ -32,6 +32,14 @@ public interface LogicalRelationMapper { @Select("SELECT * FROM logical_relation WHERE id = #{id} AND is_deleted = 0") LogicalRelation selectById(@Param("id") Integer id); + @Select(""" + SELECT * FROM logical_relation + WHERE id = #{id} + AND datasource_id = #{datasourceId} + AND is_deleted = 0 + """) + LogicalRelation selectByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + /** * 根据数据源ID查询逻辑外键列表(未删除的) */ @@ -71,12 +79,51 @@ public interface LogicalRelationMapper { """) int updateById(LogicalRelation logicalRelation); + @Update(""" + + """) + int updateByIdAndDatasourceId(@Param("datasourceId") Integer datasourceId, + @Param("logicalRelation") LogicalRelation logicalRelation); + /** * 逻辑删除外键 */ @Update("UPDATE logical_relation SET is_deleted = 1, updated_time = NOW() WHERE id = #{id}") int deleteById(@Param("id") Integer id); + @Update(""" + UPDATE logical_relation + SET is_deleted = 1, updated_time = NOW() + WHERE id = #{id} + AND datasource_id = #{datasourceId} + AND is_deleted = 0 + """) + int deleteByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + + @Delete("DELETE FROM logical_relation WHERE id = #{id}") + int hardDeleteById(@Param("id") Integer id); + + @Delete(""" + DELETE FROM logical_relation + WHERE id = #{id} + AND datasource_id = #{datasourceId} + """) + int hardDeleteByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + /** * 逻辑删除数据源下的所有逻辑外键 */ @@ -99,4 +146,33 @@ int checkExists(@Param("datasourceId") Integer datasourceId, @Param("sourceTable @Param("sourceColumnName") String sourceColumnName, @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName); + @Select(""" + SELECT COUNT(*) FROM logical_relation + WHERE datasource_id = #{datasourceId} + AND source_table_name = #{sourceTableName} + AND source_column_name = #{sourceColumnName} + AND target_table_name = #{targetTableName} + AND target_column_name = #{targetColumnName} + AND is_deleted = 0 + AND id != #{excludeId} + """) + int checkExistsExcludingId(@Param("datasourceId") Integer datasourceId, + @Param("sourceTableName") String sourceTableName, @Param("sourceColumnName") String sourceColumnName, + @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName, + @Param("excludeId") Integer excludeId); + + @Select(""" + SELECT * FROM logical_relation + WHERE datasource_id = #{datasourceId} + AND source_table_name = #{sourceTableName} + AND source_column_name = #{sourceColumnName} + AND target_table_name = #{targetTableName} + AND target_column_name = #{targetColumnName} + AND is_deleted = 1 + ORDER BY updated_time DESC, id DESC + """) + List selectDeletedByBusinessKey(@Param("datasourceId") Integer datasourceId, + @Param("sourceTableName") String sourceTableName, @Param("sourceColumnName") String sourceColumnName, + @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName); + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java index c2d8ed935..adac16d21 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java @@ -131,16 +131,11 @@ List selectEnabledByAgentIdAndDatasourceId(@Param("agentId") Long ") int removeColumnsOutsideTables(@Param("agentDatasourceId") int agentDatasourceId, @Param("tables") List tables); @@ -46,9 +42,7 @@ int removeColumnsOutsideTables(@Param("agentDatasourceId") int agentDatasourceId @Insert("") + + "(#{row.agentDatasourceId}, #{row.tableName}, #{row.columnName})" + "" + "") int insertColumns(@Param("rows") List rows); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java index adac16d21..d2ee4a295 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java @@ -195,8 +195,8 @@ List selectEnabledByAgentIdAndDatasourceIdAndTableNames(@Param("a LIMIT 1 """) SemanticModel selectByAgentIdAndDatasourceIdAndTableNameAndColumnName(@Param("agentId") Long agentId, - @Param("datasourceId") Integer datasourceId, - @Param("tableName") String tableName, @Param("columnName") String columnName); + @Param("datasourceId") Integer datasourceId, @Param("tableName") String tableName, + @Param("columnName") String columnName); @Delete(""" DELETE FROM semantic_model diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java index abba93af3..e12ad0d41 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java @@ -47,8 +47,8 @@ public class AnswerTraceExplainStore { private final ThreadLocal currentContext = new ThreadLocal<>(); - private final LinkedHashMap> explainsBySession = new LinkedHashMap<>(32, - 0.75f, true); + private final LinkedHashMap> explainsBySession = new LinkedHashMap<>( + 32, 0.75f, true); public void openScope(GraphRequest request) { if (request == null || !StringUtils.hasText(request.getThreadId()) @@ -96,7 +96,8 @@ public void recordSemanticSearch(String query, String summary, List applySemanticSearch(assembly, query, summary, hits)); } - public void recordSemanticSearch(GraphRequest request, String query, String summary, List hits) { + public void recordSemanticSearch(GraphRequest request, String query, String summary, + List hits) { withAssembly(request, assembly -> applySemanticSearch(assembly, query, summary, hits)); } @@ -203,7 +204,8 @@ private void applyRequestContext(ExplainAssembly assembly, GraphRequest request) } } - private void applySemanticSearch(ExplainAssembly assembly, String query, String summary, List hits) { + private void applySemanticSearch(ExplainAssembly assembly, String query, String summary, + List hits) { assembly.toolSteps.add(ToolStepView.builder() .toolName("semantic_model.search") .title("语义匹配") @@ -255,7 +257,8 @@ private void applyKnowledgeSearch(ExplainAssembly assembly, DomainKnowledgeSearc } } if (result.warnings() != null) { - assembly.warnings.addAll(result.warnings().stream().filter(StringUtils::hasText).map(String::trim).toList()); + assembly.warnings + .addAll(result.warnings().stream().filter(StringUtils::hasText).map(String::trim).toList()); } assembly.updatedAt = Instant.now().toEpochMilli(); } @@ -405,6 +408,7 @@ private AnswerTraceExplainView toView() { .updatedAt(updatedAt) .build(); } + } @Data @@ -453,6 +457,7 @@ public static class AnswerTraceExplainView { private List warnings = List.of(); private long updatedAt; + } @Data @@ -473,6 +478,7 @@ public static class SemanticHitView { private Integer score; private String relationHint; + } @Data @@ -493,6 +499,7 @@ public static class KnowledgeHitView { private String source; private String concreteType; + } @Data @@ -511,6 +518,7 @@ public static class ToolStepView { private String datasource; private long timestampEpochMs; + } @Data @@ -528,6 +536,7 @@ public static class ExplainMirrorSummary { private int knowledgeHitCount; private int toolStepCount; + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java index cf941c030..fd0e6d908 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/SessionTraceStore.java @@ -33,8 +33,8 @@ import org.springframework.util.StringUtils; /** - * Cache the latest completed trace for each chat session so the frontend can inspect - * the most recent AgentScope span chain without querying an external tracing backend. + * Cache the latest completed trace for each chat session so the frontend can inspect the + * most recent AgentScope span chain without querying an external tracing backend. */ @Component public class SessionTraceStore implements SpanExporter { @@ -56,9 +56,9 @@ public class SessionTraceStore implements SpanExporter { private static final String META_OMITTED_ATTRIBUTE_COUNT = "_meta.omitted_attribute_count"; private static final List SENSITIVE_ATTRIBUTE_KEY_TOKENS = List.of("password", "passwd", "pwd", "secret", - "api_key", "apikey", "api_token", "apitoken", "access_key", "access_token", "accesstoken", - "refresh_token", "refreshtoken", "id_token", "idtoken", "auth_token", "authtoken", "private_key", - "privatekey", "authorization", "cookie", "credential", "signature"); + "api_key", "apikey", "api_token", "apitoken", "access_key", "access_token", "accesstoken", "refresh_token", + "refreshtoken", "id_token", "idtoken", "auth_token", "authtoken", "private_key", "privatekey", + "authorization", "cookie", "credential", "signature"); private final Object monitor = new Object(); @@ -240,11 +240,7 @@ private TraceView toTraceView() { roots.sort(Comparator.comparingLong(node -> node.span.getStartEpochMs())); List rootViews = roots.stream().map(MutableTreeNode::toImmutable).toList(); SpanView rootSpan = rootViews.isEmpty() ? null : rootViews.get(0); - long startedAt = spansBySpanId.values() - .stream() - .mapToLong(SpanView::getStartEpochMs) - .min() - .orElse(0L); + long startedAt = spansBySpanId.values().stream().mapToLong(SpanView::getStartEpochMs).min().orElse(0L); long endedAt = spansBySpanId.values().stream().mapToLong(SpanView::getEndEpochMs).max().orElse(0L); long durationMs = Math.max(0L, endedAt - startedAt); return new TraceView(sessionId, traceId, runtimeRequestId, agentId, startedAt, endedAt, durationMs, @@ -265,9 +261,9 @@ private MutableTreeNode(SpanView span) { private SpanView toImmutable() { children.sort(Comparator.comparingLong(node -> node.span.getStartEpochMs())); - return new SpanView(span.getName(), span.getSpanId(), span.getParentSpanId(), span.getKind(), span.getStatus(), - span.getStartEpochMs(), span.getEndEpochMs(), span.getDurationMs(), span.getAttributes(), - children.stream().map(MutableTreeNode::toImmutable).toList()); + return new SpanView(span.getName(), span.getSpanId(), span.getParentSpanId(), span.getKind(), + span.getStatus(), span.getStartEpochMs(), span.getEndEpochMs(), span.getDurationMs(), + span.getAttributes(), children.stream().map(MutableTreeNode::toImmutable).toList()); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/agent/AgentStartupInitialization.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/agent/AgentStartupInitialization.java index e04afe893..edef456e7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/agent/AgentStartupInitialization.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/agent/AgentStartupInitialization.java @@ -164,7 +164,8 @@ private boolean isAlreadyInitialized(Long agentId, Integer datasourceId) { return agentVectorStoreService.hasSchemaDocuments(agentIdStr, datasourceIdStr); } catch (Exception e) { - log.error("Failed to check initialization status for agent: {} and datasource: {}, assuming not initialized", + log.error( + "Failed to check initialization status for agent: {} and datasource: {}, assuming not initialized", agentId, datasourceId, e); return false; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java index e0294f185..a68743513 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/AgentDatasourceService.java @@ -41,8 +41,8 @@ default AgentDatasource getCurrentAgentDatasource(Long agentId) { AgentDatasource updateDatasourceTables(Long agentId, Integer datasourceId, List tables); - AgentDatasource updateDatasourceColumns(Long agentId, Integer datasourceId, Map> columnsByTable) - throws Exception; + AgentDatasource updateDatasourceColumns(Long agentId, Integer datasourceId, + Map> columnsByTable) throws Exception; List getVisibleTableColumns(Long agentId, Integer datasourceId, String tableName) throws Exception; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java index d8b3f642e..3560594c3 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java @@ -73,7 +73,8 @@ public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer data // Create database configuration DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); - AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); + AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, + datasourceId); if (agentDatasource == null) { throw new RuntimeException("Agent datasource relation not found with agentId=%s, datasourceId=%s" .formatted(agentId, datasourceId)); @@ -202,7 +203,8 @@ public AgentDatasource updateDatasourceTables(Long agentId, Integer datasourceId normalizedTables = sanitizeRequestedTables(tables, datasourceTableIndex); } catch (Exception ex) { - throw new IllegalArgumentException("Failed to validate datasource tables: %s".formatted(ex.getMessage()), ex); + throw new IllegalArgumentException("Failed to validate datasource tables: %s".formatted(ex.getMessage()), + ex); } if (normalizedTables.isEmpty()) { tablesMapper.removeAllTables(datasource.getId()); @@ -218,8 +220,7 @@ public AgentDatasource updateDatasourceTables(Long agentId, Integer datasourceId @Override @Transactional public AgentDatasource updateDatasourceColumns(Long agentId, Integer datasourceId, - Map> columnsByTable) - throws Exception { + Map> columnsByTable) throws Exception { if (agentId == null || datasourceId == null || columnsByTable == null) { throw new IllegalArgumentException("参数不能为空"); } @@ -280,7 +281,8 @@ private AgentDatasource refreshAgentDatasource(Long agentId, Integer datasourceI } private Map> loadSelectedColumns(int agentDatasourceId) { - List rows = Optional.ofNullable(columnsMapper.getAgentDatasourceColumns(agentDatasourceId)) + List rows = Optional + .ofNullable(columnsMapper.getAgentDatasourceColumns(agentDatasourceId)) .orElse(List.of()); Map> columnsByTable = new LinkedHashMap<>(); for (AgentDatasourceColumn row : rows) { @@ -293,18 +295,20 @@ private Map> loadSelectedColumns(int agentDatasourceId) { return Map.copyOf(columnsByTable); } - private TableResolutionIndex loadAllowedTables(AgentDatasource agentDatasource, Integer datasourceId) throws Exception { + private TableResolutionIndex loadAllowedTables(AgentDatasource agentDatasource, Integer datasourceId) + throws Exception { List datasourceTables = datasourceService.getDatasourceTables(datasourceId); TableResolutionIndex datasourceTableIndex = buildTableResolutionIndex(datasourceTables); - List selectedTables = Optional.ofNullable(tablesMapper.getAgentDatasourceTables(agentDatasource.getId())) + List selectedTables = Optional + .ofNullable(tablesMapper.getAgentDatasourceTables(agentDatasource.getId())) .orElse(List.of()); List visibleTables = selectedTables.isEmpty() ? datasourceTables : sanitizeRequestedTables(selectedTables, datasourceTableIndex, true); return buildTableResolutionIndex(visibleTables); } - private Map> sanitizeColumnsByTable(Integer datasourceId, Map> columnsByTable, - TableResolutionIndex allowedTables) throws Exception { + private Map> sanitizeColumnsByTable(Integer datasourceId, + Map> columnsByTable, TableResolutionIndex allowedTables) throws Exception { Map> sanitized = new LinkedHashMap<>(); for (Map.Entry> entry : columnsByTable.entrySet()) { String requestedTableName = entry.getKey(); @@ -393,14 +397,14 @@ private String resolveTableName(String requestedTableName, TableResolutionIndex if (isQualifiedIdentifier(requestedTableName) && !allowQualifiedFallback) { return null; } - List leafMatches = tableIndex.leafTables().getOrDefault(normalizeLeafIdentifier(requestedTableName), - List.of()); + List leafMatches = tableIndex.leafTables() + .getOrDefault(normalizeLeafIdentifier(requestedTableName), List.of()); if (leafMatches.size() == 1) { return leafMatches.get(0); } if (leafMatches.size() > 1) { - throw new IllegalArgumentException("Table '%s' is ambiguous across datasource tables: %s" - .formatted(requestedTableName, leafMatches)); + throw new IllegalArgumentException( + "Table '%s' is ambiguous across datasource tables: %s".formatted(requestedTableName, leafMatches)); } return null; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java index c356135bf..36ed52955 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java @@ -39,8 +39,8 @@ record KnowledgeHit(String vectorType, String knowledgeId, String title, String } record SearchDiagnostics(String runtimeAgentId, Integer recalledBusinessKnowledgeCount, - Integer recalledBusinessTermCount, Integer recalledAgentKnowledgeCount, boolean businessKnowledgeVectorReady, - boolean businessTermVectorReady, boolean agentKnowledgeVectorReady) { + Integer recalledBusinessTermCount, Integer recalledAgentKnowledgeCount, + boolean businessKnowledgeVectorReady, boolean businessTermVectorReady, boolean agentKnowledgeVectorReady) { } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java index e78aa5566..91e2c1ec0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java @@ -231,8 +231,8 @@ private List searchBusinessKnowledge(String agentId, String query, if (knowledge.getIsRecall() == null || knowledge.getIsRecall() != 1) { continue; } - hits.add(new KnowledgeHit("businessKnowledge", String.valueOf(knowledgeId), - knowledge.getBusinessTerm(), abbreviate(knowledge.getDescription(), MAX_SUMMARY_LENGTH), + hits.add(new KnowledgeHit("businessKnowledge", String.valueOf(knowledgeId), knowledge.getBusinessTerm(), + abbreviate(knowledge.getDescription(), MAX_SUMMARY_LENGTH), abbreviate(document.getText(), MAX_SNIPPET_LENGTH), "businessKnowledge#" + knowledgeId, null)); } return hits; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java index 5bc010360..78e742f5f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/schema/SchemaServiceImpl.java @@ -252,7 +252,8 @@ private List> partitionList(List list, int batchSize) { return partitions; } - protected void storeSchemaDocuments(Long agentId, Integer datasourceId, List columns, List tables) { + protected void storeSchemaDocuments(Long agentId, Integer datasourceId, List columns, + List tables) { // 串行去批写入,并行流的时候有API限速了 List> columnBatches = batchingStrategy.batch(columns); for (List batch : columnBatches) { @@ -290,7 +291,8 @@ protected void clearSchemaDataForDatasource(Long agentId, Integer datasourceId) agentVectorStoreService.deleteDocumentsByMetadata(metadata); } - private void applyVisibleColumnRestrictions(List tables, Map> visibleColumnsByTable) { + private void applyVisibleColumnRestrictions(List tables, + Map> visibleColumnsByTable) { Map> normalizedRestrictions = normalizeVisibleColumnRestrictions(visibleColumnsByTable); if (normalizedRestrictions.isEmpty()) { return; @@ -316,7 +318,8 @@ private void applyVisibleColumnRestrictions(List tables, Map> normalizeVisibleColumnRestrictions(Map> visibleColumnsByTable) { + private Map> normalizeVisibleColumnRestrictions( + Map> visibleColumnsByTable) { Map> normalizedRestrictions = new LinkedHashMap<>(); Optional.ofNullable(visibleColumnsByTable).orElse(Map.of()).forEach((tableName, columns) -> { String normalizedTableName = normalizeIdentifier(tableName); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/semantic/SemanticModelServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/semantic/SemanticModelServiceImpl.java index e8723de6a..41148a082 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/semantic/SemanticModelServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/semantic/SemanticModelServiceImpl.java @@ -222,8 +222,8 @@ public BatchImportResult batchImport(SemanticModelBatchImportDTO dto) { log.error("导入第{}条记录失败: datasourceId={}, tableName={}, columnName={}", i + 1, dto.getDatasourceId(), item.getTableName(), item.getColumnName(), e); result.setFailCount(result.getFailCount() + 1); - result.addError( - String.format("第%d条记录导入失败(%s.%s): %s", i + 1, item.getTableName(), item.getColumnName(), e.getMessage())); + result.addError(String.format("第%d条记录导入失败(%s.%s): %s", i + 1, item.getTableName(), item.getColumnName(), + e.getMessage())); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/vectorstore/AgentVectorStoreServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/vectorstore/AgentVectorStoreServiceImpl.java index a50f68efc..ba19ed3ea 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/vectorstore/AgentVectorStoreServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/vectorstore/AgentVectorStoreServiceImpl.java @@ -267,10 +267,12 @@ public List getDocumentsOnlyByFilter(Filter.Expression filterExpressio @Override public boolean hasDocuments(String agentId) { - return hasDocumentsByMetadata(agentId, Map.of(Constant.AGENT_ID, agentId)); /* - // 类似 MySQL 的 LIMIT 1,只检查是否存在文档 - return hasDocumentsByMetadata(agentId, Map.of(Constant.AGENT_ID, agentId)); - */ + return hasDocumentsByMetadata(agentId, Map.of(Constant.AGENT_ID, + agentId)); /* + * // 类似 MySQL 的 LIMIT 1,只检查是否存在文档 return + * hasDocumentsByMetadata(agentId, Map.of(Constant.AGENT_ID, + * agentId)); + */ } @Override diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DocumentConverterUtil.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DocumentConverterUtil.java index 6c19923fe..ca444b957 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DocumentConverterUtil.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/DocumentConverterUtil.java @@ -35,7 +35,8 @@ @Slf4j public class DocumentConverterUtil { - public static List convertColumnsToDocuments(Long agentId, Integer datasourceId, List tables) { + public static List convertColumnsToDocuments(Long agentId, Integer datasourceId, + List tables) { List documents = new ArrayList<>(); for (TableInfoBO table : tables) { // 使用已经处理过的列数据,避免重复查询 @@ -99,7 +100,8 @@ public static Document convertTableToDocument(Long agentId, Integer datasourceId return new Document(text, metadata); } - public static List convertTablesToDocuments(Long agentId, Integer datasourceId, List tables) { + public static List convertTablesToDocuments(Long agentId, Integer datasourceId, + List tables) { return tables.stream() .map(table -> DocumentConverterUtil.convertTableToDocument(agentId, datasourceId, table)) .collect(Collectors.toList()); From be1156d01c2ec584b1a7809874ca1bfbe427668e Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Tue, 28 Apr 2026 16:43:52 +0800 Subject: [PATCH 12/22] fix: ci --- .../src/components/agent/SemanticsConfig.vue | 731 +++++++++--------- .../src/components/run/ResultSetDisplay.vue | 2 +- data-agent-frontend/src/services/chat.ts | 9 +- data-agent-frontend/src/views/AgentRun.vue | 28 +- 4 files changed, 395 insertions(+), 375 deletions(-) diff --git a/data-agent-frontend/src/components/agent/SemanticsConfig.vue b/data-agent-frontend/src/components/agent/SemanticsConfig.vue index aabeea0d8..8a859eda8 100644 --- a/data-agent-frontend/src/components/agent/SemanticsConfig.vue +++ b/data-agent-frontend/src/components/agent/SemanticsConfig.vue @@ -166,7 +166,10 @@ - + @@ -194,402 +197,400 @@ diff --git a/data-agent-frontend/src/components/run/ResultSetDisplay.vue b/data-agent-frontend/src/components/run/ResultSetDisplay.vue index 9b82a9275..c5ebc1653 100644 --- a/data-agent-frontend/src/components/run/ResultSetDisplay.vue +++ b/data-agent-frontend/src/components/run/ResultSetDisplay.vue @@ -230,7 +230,7 @@ display: flex; justify-content: space-between; align-items: center; - //margin-left: 15px; + /* margin-left: 15px; */ margin-bottom: 12px; padding: 8px 0; border-bottom: 1px solid #ebeef5; diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index de188d7c9..248d446a1 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -208,7 +208,11 @@ class ChatService { * @param sessionId 会话ID * @param message 消息对象 */ - async saveMessage(sessionId: string, agentId: number, message: ChatMessage): Promise { + async saveMessage( + sessionId: string, + agentId: number, + message: ChatMessage, + ): Promise { try { const resolvedAgentId = resolveAgentId(agentId); // 设置会话ID @@ -319,7 +323,8 @@ class ChatService { const response = await axios.post( `${API_BASE_URL}/sessions/${sessionId}/reports/html`, content, - { params: { agentId: resolvedAgentId }, + { + params: { agentId: resolvedAgentId }, responseType: 'blob', // 重要:设置响应类型为blob headers: { 'Content-Type': 'text/plain;charset=utf-8', // 明确设置内容类型和编码 diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index ed47b13c7..652382843 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -254,7 +254,10 @@ riskLevel={{ pendingClarify.riskLevel }}
- {{ pendingClarify.summary || '当前问题存在高歧义,下一条输入将作为补充信息或显式假设回填。' }} + {{ + pendingClarify.summary || + '当前问题存在高歧义,下一条输入将作为补充信息或显式假设回填。' + }}
@@ -997,7 +1002,9 @@ }; const getResolvedAgentId = (): number | null => - parseAgentId(currentSession.value?.agentId) ?? parseAgentId(agent.value?.id) ?? getRouteAgentId(); + parseAgentId(currentSession.value?.agentId) ?? + parseAgentId(agent.value?.id) ?? + getRouteAgentId(); const requireResolvedAgentId = (): number => { const resolvedAgentId = getResolvedAgentId(); @@ -1156,8 +1163,7 @@ } ElMessage.error('未知错误'); console.error(error); - } - finally { + } finally { isSubmittingMessage.value = false; } }; @@ -1433,7 +1439,11 @@ messageType: 'html-report', }; - await ChatService.saveMessage(sessionId, requireResolvedAgentId(), htmlReportMessage) + await ChatService.saveMessage( + sessionId, + requireResolvedAgentId(), + htmlReportMessage, + ) .then(savedMessage => { if (currentSession.value?.id === sessionId) { currentMessages.value.push(savedMessage); @@ -1458,7 +1468,11 @@ messageType: 'markdown-report', }; - await ChatService.saveMessage(sessionId, requireResolvedAgentId(), markdownMessage) + await ChatService.saveMessage( + sessionId, + requireResolvedAgentId(), + markdownMessage, + ) .then(savedMessage => { if (currentSession.value?.id === sessionId) { currentMessages.value.push(savedMessage); From 96ccc4736116696796e7c25da7ccffcb66bb6f02 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Tue, 28 Apr 2026 18:19:53 +0800 Subject: [PATCH 13/22] fix: remove something --- .claude/code-review.md | 328 ------------------ .../impl/AgentDatasourceServiceImpl.java | 2 +- 2 files changed, 1 insertion(+), 329 deletions(-) delete mode 100644 .claude/code-review.md diff --git a/.claude/code-review.md b/.claude/code-review.md deleted file mode 100644 index e56073d9f..000000000 --- a/.claude/code-review.md +++ /dev/null @@ -1,328 +0,0 @@ -# 代码审查报告 - -## 📋 概览 - -本次提交主要涉及以下改动: -1. **前端**:Trace 对话框优化、数据来源持久化、UI 样式改进 -2. **后端**:工具调用增强、trace 数据结构优化 - ---- - -## ✅ 优点 - -### 1. Trace 对话框优化 -- ✅ **树形展示**:增加了缩进量(32px),层级关系更清晰 -- ✅ **视觉优化**:添加了渐变背景、阴影效果、hover 动画 -- ✅ **消息去重**:实现了智能去重逻辑,避免重复显示工具调用 -- ✅ **截断处理**:对被截断的 JSON 数据进行了容错处理 - -### 2. 数据来源持久化 -- ✅ **sessionStorage 持久化**:刷新页面后数据来源可以恢复 -- ✅ **会话隔离**:每个会话的状态独立保存 -- ✅ **状态同步**:切换会话时正确保存和恢复状态 - -### 3. 代码质量 -- ✅ **类型安全**:使用 TypeScript 接口定义状态结构 -- ✅ **错误处理**:添加了 try-catch 保护 -- ✅ **注释完善**:关键函数都有中文注释 - ---- - -## ⚠️ 问题与建议 - -### 🔴 严重问题 - -#### 1. **parsedTraceConversations 逻辑过于简化** -**位置**:`AgentRun.vue:2461-2530` - -**问题**: -```typescript -// 只处理 agentscope.function.input 和 agentscope.function.output -const inputEntry = row.attributeEntries.find(e => e.key === 'agentscope.function.input'); -const outputEntry = row.attributeEntries.find(e => e.key === 'agentscope.function.output'); - -if (inputEntry && outputEntry) { - // ... 创建消息 - return [{ ... }]; -} - -// 如果没有找到 agentscope.function 属性,返回空 -return []; -``` - -**影响**: -- ❌ 如果 span 中没有 `agentscope.function.input/output`,将不显示任何消息 -- ❌ 其他类型的消息(如 `gen_ai.*`、`tool.*` 等)会被完全忽略 -- ❌ 用户可能看到空白的 trace 详情面板 - -**建议**: -```typescript -// 优先处理 agentscope.function,但如果不存在则回退到默认逻辑 -if (inputEntry && outputEntry) { - // ... 处理 agentscope.function - return [{ ... }]; -} - -// 回退到默认逻辑 -const keyword = normalizedTraceSearchKeyword.value; -return dedupeTraceConversationGroups( - row.attributeEntries.flatMap(entry => extractTraceConversationGroupsFromEntry(entry)), -) - .map(group => { ... }) - .filter(group => group.messages.length > 0); -``` - ---- - -#### 2. **删除了人工反馈功能但没有清理相关代码** -**位置**:`AgentRun.vue` - -**问题**: -- ❌ 删除了 `HumanFeedback` 组件的引用和使用 -- ❌ 但保留了 `handleNl2sqlOnlyChange` 函数(已无用) -- ❌ 保留了 `showHumanFeedback`、`lastRequest` 等变量声明(未使用) - -**建议**: -完全清理人工反馈相关代码: -```typescript -// 删除这些未使用的变量和函数 -const handleNl2sqlOnlyChange = (value: boolean) => { ... }; // 删除 -const showHumanFeedback = ref(false); // 删除 -const lastRequest = ref(null); // 删除 -``` - ---- - -### 🟡 中等问题 - -#### 3. **sessionStorage 可能超出存储限制** -**位置**:`sessionStateManager.ts:66-82` - -**问题**: -```typescript -function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { - try { - const persistable: PersistableState = { - nodeBlocks: state.nodeBlocks, // 可能很大 - answerExplain: state.answerExplain, // 可能很大 - // ... - }; - sessionStorage.setItem(key, JSON.stringify(persistable)); - } catch (error) { - console.error('保存会话状态失败:', error); - } -} -``` - -**影响**: -- ⚠️ `nodeBlocks` 和 `answerExplain` 可能包含大量数据 -- ⚠️ sessionStorage 通常限制为 5-10MB -- ⚠️ 超出限制时会静默失败,用户不知道状态未保存 - -**建议**: -1. 添加存储大小检查 -2. 只保存必要的字段 -3. 对大数据进行压缩或截断 -```typescript -function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { - try { - const persistable: PersistableState = { - // 只保存最后 10 个 nodeBlocks - nodeBlocks: state.nodeBlocks.slice(-10), - answerExplain: state.answerExplain, - // ... - }; - - const json = JSON.stringify(persistable); - const sizeInMB = new Blob([json]).size / (1024 * 1024); - - if (sizeInMB > 4) { - console.warn(`会话状态过大 (${sizeInMB.toFixed(2)}MB),跳过保存`); - return; - } - - sessionStorage.setItem(key, json); - } catch (error) { - console.error('保存会话状态失败:', error); - // 提示用户 - ElMessage.warning('会话状态保存失败,刷新页面后可能丢失部分数据'); - } -} -``` - ---- - -#### 4. **去重逻辑可能过于激进** -**位置**:`AgentRun.vue:2050-2060` - -**问题**: -```typescript -const getTraceMessageDedupFingerprint = (message: ParsedTraceMessage) => { - if (message.kind === 'tool-call' || message.kind === 'tool-result') { - // 不包含 title,只比较 content 和 details - return [ - message.kind, - stringifyTraceSemanticPayload(message.content), - stringifyTraceSemanticPayload(message.details), - ].join('|'); - } - return getTraceMessageFingerprint(message); -}; -``` - -**影响**: -- ⚠️ 如果两个不同的工具调用参数相同,会被认为是重复的 -- ⚠️ 例如:连续两次调用 `GET_TABLE_SCHEMA` 查询同一张表 - -**建议**: -考虑添加时间戳或调用 ID 到指纹中: -```typescript -const getTraceMessageDedupFingerprint = (message: ParsedTraceMessage) => { - if (message.kind === 'tool-call' || message.kind === 'tool-result') { - return [ - message.kind, - message.id, // 添加消息 ID - stringifyTraceSemanticPayload(message.content), - stringifyTraceSemanticPayload(message.details), - ].join('|'); - } - return getTraceMessageFingerprint(message); -}; -``` - ---- - -#### 5. **CSS 样式过于复杂** -**位置**:`AgentRun.vue` 样式部分 - -**问题**: -- ⚠️ 大量的渐变、阴影、动画效果 -- ⚠️ 可能影响性能,特别是在大量元素时 -- ⚠️ 维护成本高 - -**建议**: -1. 考虑使用 CSS 变量统一管理颜色和尺寸 -2. 减少不必要的渐变和阴影 -3. 使用 `will-change` 优化动画性能 - -```css -:root { - --trace-primary-color: #409eff; - --trace-border-color: #e8f1fa; - --trace-hover-shadow: 0 8px 24px rgba(64, 158, 255, 0.15); -} - -.trace-row { - border: 2px solid var(--trace-border-color); - transition: all 0.3s ease; - will-change: transform, box-shadow; -} - -.trace-row:hover { - box-shadow: var(--trace-hover-shadow); - transform: translateY(-2px); -} -``` - ---- - -### 🟢 轻微问题 - -#### 6. **缺少加载状态的用户反馈** -**位置**:`AgentRun.vue:1690-1713` - -**问题**: -```typescript -const loadAnswerExplainByRuntimeRequestId = async (runtimeRequestId: string) => { - answerExplainVisible.value = true; - answerExplainLoading.value = true; - // ... 加载数据 -}; -``` - -**建议**: -添加加载提示: -```typescript -const loadAnswerExplainByRuntimeRequestId = async (runtimeRequestId: string) => { - answerExplainVisible.value = true; - answerExplainLoading.value = true; - - const loadingMessage = ElMessage.info({ - message: '正在加载数据来源...', - duration: 0, - }); - - try { - // ... 加载数据 - } finally { - loadingMessage.close(); - answerExplainLoading.value = false; - } -}; -``` - ---- - -#### 7. **Magic Numbers** -**位置**:多处 - -**问题**: -```typescript -:style="{ paddingLeft: `${row.depth * 32 + 20}px` }" // 32 和 20 是什么? -if (outputEntry.value.length > 10000) { ... } // 10000 是什么? -``` - -**建议**: -使用常量: -```typescript -const TRACE_INDENT_SIZE = 32; -const TRACE_BASE_PADDING = 20; -const MAX_OUTPUT_LENGTH = 10000; - -:style="{ paddingLeft: `${row.depth * TRACE_INDENT_SIZE + TRACE_BASE_PADDING}px` }" -if (outputEntry.value.length > MAX_OUTPUT_LENGTH) { ... } -``` - ---- - -## 📊 统计 - -- **修改文件数**:24 个 -- **前端文件**:5 个 -- **后端文件**:19 个 -- **新增文件**:2 个 -- **删除文件**:1 个 - ---- - -## 🎯 总体评价 - -**评分**:7.5/10 - -**优点**: -- ✅ 功能实现完整,UI 优化明显 -- ✅ 代码结构清晰,类型安全 -- ✅ 持久化方案合理 - -**需要改进**: -- ❌ `parsedTraceConversations` 逻辑过于简化,需要添加回退逻辑 -- ⚠️ sessionStorage 存储大小需要控制 -- ⚠️ 清理未使用的代码 - ---- - -## 🔧 建议的修复优先级 - -1. **高优先级**:修复 `parsedTraceConversations` 的回退逻辑 -2. **中优先级**:添加 sessionStorage 大小限制 -3. **低优先级**:清理未使用的代码、优化 CSS - ---- - -## ✅ 可以提交吗? - -**建议**:⚠️ **修复高优先级问题后再提交** - -**理由**: -- `parsedTraceConversations` 的问题可能导致部分 trace 无法显示 -- 其他问题不影响核心功能,可以后续优化 diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java index 3560594c3..2b17d504b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/AgentDatasourceServiceImpl.java @@ -250,7 +250,7 @@ public List getVisibleTableColumns(Long agentId, Integer datasourceId, S } AgentDatasource agentDatasource = agentDatasourceMapper.selectByAgentIdAndDatasourceId(agentId, datasourceId); if (agentDatasource == null) { - throw new IllegalArgumentException("鏈壘鍒板搴旂殑鏁版嵁婧愬叧鑱旇褰?"); + throw new IllegalArgumentException("未找到对应的数据源关联记录"); } TableResolutionIndex allowedTables = loadAllowedTables(agentDatasource, datasourceId); String actualTableName = resolveTableName(tableName, allowedTables, false); From e60254bf86e5fbabae188ad1e2413338ae98dd83 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Tue, 28 Apr 2026 19:15:54 +0800 Subject: [PATCH 14/22] fix: use chinese replace en --- ...60\344\270\255\346\226\207\345\214\226.md" | 52 +++++++++++++++++++ .../SpringToolCallbackAgentAdapter.java | 2 +- .../impl/AiAgentRuntimeServiceImpl.java | 8 +-- .../session/AgentSessionRegistry.java | 2 +- .../template/ManagedAgentRegistry.java | 2 +- .../datasource/DatasourceExplorerService.java | 12 ++--- .../DatasourceExplorerToolProvider.java | 2 +- .../DomainBusinessKnowledgeToolProvider.java | 8 +-- .../DomainBusinessKnowledgeToolSupport.java | 2 +- .../semantic/SemanticModelSearchService.java | 14 ++--- .../semantic/SemanticModelToolProvider.java | 10 ++-- .../semantic/SemanticModelToolSupport.java | 2 +- .../BuiltinCurrentTimeSkillToolProvider.java | 4 +- .../tool/sqlguard/SqlGuardToolProvider.java | 26 +++++----- .../sqlguard/SqlVerifyExplainService.java | 36 ++++++------- 15 files changed, 117 insertions(+), 65 deletions(-) create mode 100644 ".claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" diff --git "a/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" "b/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" new file mode 100644 index 000000000..0d915af57 --- /dev/null +++ "b/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" @@ -0,0 +1,52 @@ +## 项目上下文摘要(工具描述中文化) +生成时间:2026-04-28 + +### 1. 相似实现分析 +- **实现1**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java:131-152 + - 模式:由 provider 动态构造 `ToolDefinition.description` + - 可复用:中文分点式说明风格,明确适用范围、顺序和禁用场景 + - 需注意:描述会直接影响 Agent 工具选择与调用顺序 + +- **实现2**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java:32-50 + - 模式:`INPUT_SCHEMA` 采用中文字段说明,强调“必填/可选/何时不必传” + - 可复用:schema 文案风格可直接作为其他工具输入说明参考 + - 需注意:和 provider 的 description 要保持边界一致 + +- **实现3**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java:37-63 + - 模式:support 负责统一 `ToolDefinition` 的 input schema + - 可复用:中文化时优先改 provider 的 DESCRIPTION,support 仅修正残留英文术语 + - 需注意:不要改动工具行为,只改面向 Agent 的说明文本 + +### 2. 项目约定 +- **命名约定**: Java 类、常量、方法名保持现有英文命名;说明文本统一简体中文 +- **文件组织**: provider 负责工具注册与 description;support 负责 input schema 和 callback 封装 +- **导入顺序**: 保持现有 import 顺序,不做无关调整 +- **代码风格**: 使用 Java 文本块 `"""` 保存多行描述,延续现有缩进和换行风格 + +### 3. 可复用组件清单 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java`:中文工具描述样板 +- `data-agent-management/src/main/resources/prompts/commonagent.md`:工具边界和调用顺序的权威规则 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java`:中文 schema 样板 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java`:中文 schema 样板 + +### 4. 测试策略 +- **测试框架**: 本次先做静态验证,不新增测试 +- **验证方式**: 搜索所有注册工具 description / input schema,确认英文描述已清理且工具边界与 `commonagent.md` 一致 +- **参考文件**: `commonagent.md`、各 ToolProvider / ToolSupport +- **覆盖要求**: 至少覆盖 `sql_guard.check`、`semantic_model.search`、`domain_business_knowledge.search` 以及已存在中文样板的对齐检查 + +### 5. 依赖和集成点 +- **外部依赖**: Spring AI `ToolDefinition` +- **内部依赖**: `AgentScopedToolProvider`、各 `ToolSupport` +- **集成方式**: provider 将 description 和 inputSchema 注入 `ToolDefinition.builder()` +- **配置来源**: `data-agent-management/src/main/resources/prompts/commonagent.md` 定义工具路由规则 + +### 6. 技术选型理由 +- **为什么用这个方案**: 直接修改 provider/support 文案即可完成目标,影响面准确且不改变运行逻辑 +- **优势**: 风险低、改动集中、与现有工具注册结构一致 +- **劣势和风险**: 只改文案不跑行为测试;需确保中文说明与现有路由规则完全一致,避免误导模型 + +### 7. 关键风险点 +- **边界条件**: `sql_guard.check` 同时支持 `SQL_VERIFY` 与 `DATA_PROFILE`,描述必须避免把 profile 说成默认步骤 +- **一致性风险**: provider description、schema description、`commonagent.md` 三处边界不能冲突 +- **维护风险**: 若遗漏某个注册工具,前后端展示和模型行为提示会出现中英混杂 diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java index 4620c454f..d4ed94b20 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java @@ -78,7 +78,7 @@ private ToolResultBlock invoke(ToolCallParam toolCallParam) throws Exception { } catch (Exception ex) { log.error("Spring AI tool callback execution failed. tool={}", getName(), ex); - return ToolResultBlock.error(ex.getMessage() == null ? "Tool execution failed." : ex.getMessage()) + return ToolResultBlock.error(ex.getMessage() == null ? "工具执行失败。" : ex.getMessage()) .withIdAndName(toolCallParam.getToolUseBlock().getId(), getName()); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 48a72c53d..879d52b24 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -194,7 +194,7 @@ private void emitError(Sinks.Many> sink, Grap } log.error("AgentScope runtime failed, threadId={}", threadId, error); if (sessionRegistry.isActive(threadId, runtimeRequestId)) { - String message = error.getMessage() == null ? "AgentScope runtime failed." : error.getMessage(); + String message = error.getMessage() == null ? "AgentScope 运行失败。" : error.getMessage(); sink.tryEmitNext(ServerSentEvent.builder(GraphNodeResponse.error(request.getAgentId(), threadId, message)) .event(STREAM_EVENT_ERROR) .build()); @@ -379,13 +379,13 @@ private String buildAnswerExplainMetadata(GraphRequest request) throws Exception private void validateModelConfig(ModelConfigDTO modelConfig) { if (modelConfig == null) { - throw new IllegalStateException("No active CHAT model configured. Please configure it in the dashboard."); + throw new IllegalStateException("当前未配置可用的 CHAT 模型,请先在控制台完成配置。"); } if (!StringUtils.hasText(modelConfig.getApiKey())) { - throw new IllegalStateException("Active CHAT model apiKey is empty."); + throw new IllegalStateException("当前活动 CHAT 模型的 apiKey 为空。"); } if (!StringUtils.hasText(modelConfig.getModelName())) { - throw new IllegalStateException("Active CHAT model modelName is empty."); + throw new IllegalStateException("当前活动 CHAT 模型的 modelName 为空。"); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java index 3b551852c..076687c3a 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java @@ -76,7 +76,7 @@ public void finish(String threadId, String runtimeRequestId) { private RequestExecutionState getOrCreateState(String threadId, String runtimeRequestId) { if (threadId == null || threadId.isBlank() || runtimeRequestId == null || runtimeRequestId.isBlank()) { - throw new IllegalArgumentException("threadId and runtimeRequestId must not be blank"); + throw new IllegalArgumentException("threadId 和 runtimeRequestId 不能为空"); } ConcurrentHashMap states = requestStatesByThreadId.computeIfAbsent(threadId, key -> new ConcurrentHashMap<>()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java index de142e16c..8c705a0f9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java @@ -35,7 +35,7 @@ public ManagedAgentRegistry(List managedAgents) { public ManagedAgent getRequired() { ManagedAgent managedAgent = this.agentsByType.get(normalize(CommonAgent.AGENT_TYPE)); if (managedAgent == null) { - throw new IllegalStateException("ManagedAgent registry missing type: " + CommonAgent.AGENT_TYPE); + throw new IllegalStateException("ManagedAgent 注册表缺少类型:" + CommonAgent.AGENT_TYPE); } return managedAgent; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index 38529189a..c4ad3d9aa 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -120,7 +120,7 @@ public DatasourceExplorerResult execute(String agentId, DatasourceExplorerReques public DatasourceExplorerResult execute(String agentId, DatasourceExplorerRequest request, @Nullable GraphRequest graphRequest) throws Exception { if (request == null || request.getAction() == null) { - throw new IllegalArgumentException("Datasource explorer request.action 不能为空"); + throw new IllegalArgumentException("数据源探索请求必须提供 action"); } ExplorerContext context = resolveContext(agentId); return switch (request.getAction()) { @@ -294,7 +294,7 @@ private ExplorerContext resolveContext(String agentId) throws Exception { Datasource datasource = agentDatasource.getDatasource() != null ? agentDatasource.getDatasource() : datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); if (datasource == null) { - throw new IllegalStateException("Active datasource not found for agent " + agentId); + throw new IllegalStateException("当前 Agent 未找到活动数据源:" + agentId); } DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); Accessor accessor = accessorFactory.getAccessorByDbConfig(dbConfig); @@ -342,7 +342,7 @@ private List loadPhysicalRelations(Accessor accessor, DbConfig private Long parseAgentId(String agentId) { if (!StringUtils.isNumeric(agentId)) { - throw new IllegalArgumentException("Datasource explorer 当前仅支持数值型 agentId"); + throw new IllegalArgumentException("数据源探索当前仅支持数值型 agentId"); } return Long.valueOf(agentId); } @@ -830,7 +830,7 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(exactMatches.get(0)); } if (exactMatches.size() > 1) { - throw new IllegalArgumentException("Table '%s' maps to multiple visible tables: %s".formatted(tableName, + throw new IllegalArgumentException("表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); } if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { @@ -841,14 +841,14 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(leafMatches.get(0)); } if (leafMatches.size() > 1) { - throw new IllegalArgumentException("Table '%s' is ambiguous across visible tables: %s".formatted(tableName, + throw new IllegalArgumentException("表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); } return Optional.empty(); } private IllegalArgumentException buildInvisibleTableException(ExplorerContext context, String tableName) { - return new IllegalArgumentException("Table '%s' is not visible for current agent. Visible tables: %s" + return new IllegalArgumentException("表 '%s' 对当前 Agent 不可见。当前可见表:%s" .formatted(tableName, String.join(", ", context.visibleTables()))); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index 536a5044e..10e1dedcb 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -188,7 +188,7 @@ public String call(String toolInput, ToolContext toolContext) { .writeValueAsString(datasourceExplorerService.execute(agentId, request, graphRequest)); } catch (Exception ex) { - throw new IllegalStateException("Datasource explorer tool failed: " + ex.getMessage(), ex); + throw new IllegalStateException("数据源探索工具执行失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java index 13ddc6291..981f85b56 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java @@ -29,10 +29,10 @@ public class DomainBusinessKnowledgeToolProvider implements AgentScopedToolProvi private static final String TOOL_NAME = "domain_business_knowledge.search"; private static final String DESCRIPTION = """ - Search the current agent's recalled business terms, FAQs, QA entries, and embedded documents on demand. - Use this tool when you need internal business definitions, metric rules, SOPs, historical cases, or terminology clarification. - Call it only when the answer depends on domain knowledge instead of general reasoning. - Do not use this tool for database table names, column names, field types, enum values, schema relations, field comments, or other table-structure interpretation questions. Those should go to datasource explorer first and semantic_model.search if supplemental semantic hints are needed. + 按需检索当前 Agent 已召回的业务术语、FAQ、问答条目和嵌入文档。 + 当回答依赖内部业务定义、指标口径、SOP、历史案例或领域术语澄清时,才使用本工具。 + 只有当答案确实依赖领域知识,而不是通用推理或数据库物理结构本身时,才调用本工具。 + 不要把本工具用于数据库表名、列名、字段类型、枚举值、schema 关系、字段注释或其他表结构解释问题;这些问题应先交给 datasource explorer,如仍需补充语义,再考虑 `semantic_model.search`。 """; private final DomainBusinessKnowledgeToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index 11c7abbee..e10d8dca0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -132,7 +132,7 @@ public String call(String toolInput, ToolContext toolContext) { .writeValueAsString(domainKnowledgeSearchService.search(agentId, request, graphRequest)); } catch (Exception ex) { - throw new IllegalStateException("Failed to search domain business knowledge: " + ex.getMessage(), ex); + throw new IllegalStateException("领域业务知识检索失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java index 209c7bb92..d81f7b6fe 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java @@ -54,7 +54,7 @@ public SemanticModelSearchResult search(String agentId, SemanticModelSearchReque @Nullable GraphRequest graphRequest) { if (!StringUtils.hasText(agentId)) { return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search requires a numeric agent id."); + "semantic_model.search 需要数值型 agentId 参数。"); } Long parsedAgentId; try { @@ -62,7 +62,7 @@ public SemanticModelSearchResult search(String agentId, SemanticModelSearchReque } catch (NumberFormatException ex) { return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search requires a numeric agent id."); + "semantic_model.search 需要数值型 agentId 参数。"); } return search(parsedAgentId, request, graphRequest); } @@ -75,17 +75,17 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest @Nullable GraphRequest graphRequest) { String query = request == null ? null : request.getQuery(); if (!StringUtils.hasText(query)) { - throw new IllegalArgumentException("query is required for semantic_model.search"); + throw new IllegalArgumentException("semantic_model.search 需要 query 参数"); } AgentDatasource activeDatasource = resolveActiveDatasource(agentId); if (activeDatasource == null || activeDatasource.getDatasourceId() == null) { - return emptyResult(query, "No active datasource is available for semantic_model.search."); + return emptyResult(query, "当前没有可用于 semantic_model.search 的活动数据源。"); } TableSearchScope scope = resolveTableSearchScope(activeDatasource, request == null ? null : request.getTableNames()); if (scope.isScoped() && CollectionUtils.isEmpty(scope.getTableNames())) { return emptyResult(query, - "Requested tables are outside the active datasource visibility scope for semantic_model.search."); + "请求中指定的表超出了当前活动数据源对 semantic_model.search 的可见范围。"); } List candidates = scope.isUnbounded() ? semanticModelService.getEnabledByAgentIdAndDatasourceId(agentId, activeDatasource.getDatasourceId()) @@ -93,7 +93,7 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest activeDatasource.getDatasourceId(), scope.getTableNames()); if (CollectionUtils.isEmpty(candidates)) { return emptyResult(query, - "No enabled semantic model entries matched this agent/table scope. Use datasource explorer for physical schema details."); + "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理 schema 请改用 datasource explorer 查看。"); } List scoredHits = candidates.stream() @@ -112,7 +112,7 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest if (scoredHits.isEmpty()) { return emptyResult(query, - "No supplemental semantic hints matched the query. If datasource explorer already answers the schema question, do not call semantic_model.search."); + "没有匹配到补充语义提示;如果 datasource explorer 已能回答 schema 问题,就不要额外调用 semantic_model.search。"); } List hits = scoredHits.stream().map(this::toHit).toList(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java index 5beb1cad7..dfcfdb0f9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java @@ -29,11 +29,11 @@ public class SemanticModelToolProvider implements AgentScopedToolProvider { private static final String TOOL_NAME = "semantic_model.search"; private static final String DESCRIPTION = """ - Supplemental semantic hint tool for table/column understanding only. - Use this tool when the user asks what a table/column means, asks for a business-friendly name, asks for enum meaning, asks for field usage notes, or asks for relation hints that are not explicitly stored in the physical schema. - Typical examples: "token名称类型", "status字段什么意思", "这个字段有哪些别名", "这两个表可能怎么关联". - Use datasource explorer first for physical schema, column lists, data types already in the database, previews, and readonly SQL. - Do not use this tool for SQL execution, schema discovery already covered by datasource explorer, or business definitions/metric rules/SOPs that belong to domain_business_knowledge.search. + 仅用于补充理解表和字段语义的辅助工具。 + 当用户在询问某张表或某个字段的含义、业务友好名称、枚举含义、字段使用备注,或数据库物理 schema 中未显式存储的关系提示时,才使用本工具。 + 典型问题包括:“token 名称类型”“status 字段什么意思”“这个字段有哪些别名”“这两个表可能怎么关联”。 + 数据库里的物理 schema、字段列表、字段类型、样例预览和只读 SQL,应优先使用 datasource explorer 获取。 + 不要把本工具用于 SQL 执行、datasource explorer 已能覆盖的 schema 探索,或属于 `domain_business_knowledge.search` 的业务定义、指标口径和 SOP 检索。 """; private final SemanticModelToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java index f15ccd7de..ccc4a8eb5 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java @@ -101,7 +101,7 @@ public String call(String toolInput, ToolContext toolContext) { .writeValueAsString(semanticModelSearchService.search(agentId, request, graphRequest)); } catch (Exception ex) { - throw new IllegalStateException("Failed to search semantic model hints: " + ex.getMessage(), ex); + throw new IllegalStateException("语义模型提示检索失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java index 7d70cdd15..63d7bea33 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java @@ -101,10 +101,10 @@ public String call(String toolInput) { .toString(); } catch (IllegalArgumentException ex) { - throw new IllegalStateException("Invalid timezone or time format: " + ex.getMessage(), ex); + throw new IllegalStateException("时区或时间格式无效:" + ex.getMessage(), ex); } catch (Exception ex) { - throw new IllegalStateException("Failed to get current time: " + ex.getMessage(), ex); + throw new IllegalStateException("获取当前时间失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java index d34858eda..e293375a1 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -38,7 +38,7 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { "action": { "type": "string", "enum": ["SQL_VERIFY", "DATA_PROFILE"], - "description": "可选。默认 SQL_VERIFY。SQL_VERIFY 用于候选 SQL 的结构与意图校验;DATA_PROFILE 用于查看字段值域、空值率、distinct、top values 与样例。" + "description": "可选。默认 SQL_VERIFY。SQL_VERIFY 用于校验候选 SQL 是否真正符合用户意图;DATA_PROFILE 仅用于在少量关键候选字段语义仍不明确,且这种不确定性会影响过滤、分组、排序、时间窗口或指标写法时,补充查看字段值分布。" }, "query": { "type": "string", @@ -57,11 +57,11 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { "items": { "type": "string" }, - "description": "DATA_PROFILE 时可选。要分析的字段列表;不传时默认取该表前几个可见字段。" + "description": "DATA_PROFILE 时可选。优先只传需要诊断的少量关键字段;不传时默认取该表前几个可见字段。" }, "limit": { "type": "integer", - "description": "DATA_PROFILE 时可选。样例值和 top values 的返回上限,默认 5,最大 20。" + "description": "DATA_PROFILE 时可选。样例值和高频值的返回上限,默认 5,最大 20。" }, "tableSchemas": { "type": "object", @@ -80,14 +80,14 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { """; private static final String DESCRIPTION = """ - Unified SQL guard tool for SQL-backed answers. - Action SQL_VERIFY: check whether the candidate SQL really matches the user's intent before execution or final answer. - Action DATA_PROFILE: inspect column value distribution only when a small set of candidate columns remain semantically ambiguous after schema inspection, and that ambiguity would materially change filters, grouping, ordering, time windows, or metric logic. - Do not call DATA_PROFILE as a default preflight step for every query. Skip it when the user request and schema already make the relevant columns obvious. - When using DATA_PROFILE, prefer focused columnNames instead of profiling an entire table. - For SQL_VERIFY, if verification fails, read isAligned=false plus problems, ruleChecks and fixSuggestions, then rewrite SQL yourself and call sql_guard.check again. - For DATA_PROFILE, use the returned columnProfiles to understand null ratio, distinct count, top values, samples, and whether a field looks categorical, numeric, or temporal. - Always pass fresh top-level parameters for the current action. Do not pass previous sql_guard.check output back into the tool. + 统一 SQL 守卫工具,供所有基于 SQL 的回答使用。 + 1. `action=SQL_VERIFY`:在执行 SQL 或基于 SQL 生成最终回答前,检查候选 SQL 是否真正符合用户意图。 + 2. `action=DATA_PROFILE`:只有在完成 schema 检查后,仍有少量关键候选字段语义不明确,且这种不确定性会实质影响过滤、分组、排序、时间窗口或指标写法时,才用于补充查看字段值分布。 + 3. 不要把 DATA_PROFILE 当作每次查询的默认前置步骤;如果用户问题、schema 和列名已经足够明确,就直接跳过。 + 4. 使用 DATA_PROFILE 时,优先传少量关键 `columnNames`,不要对整张表做无差别 profile。 + 5. 如果 SQL_VERIFY 返回 `isAligned=false`,请读取 `problems`、`ruleChecks` 和 `fixSuggestions`,自行改写 SQL 后再次调用 `sql_guard.check`。 + 6. 如果使用 DATA_PROFILE,请重点读取返回的 `columnProfiles`,理解空值率、去重计数、高频值、样例值,以及字段更像枚举、数值还是时间字段。 + 7. 每次调用都要传当前动作需要的最新顶层参数,不要把上一轮 `sql_guard.check` 的输出对象原样回传给工具。 """; private final ObjectMapper objectMapper; @@ -152,12 +152,12 @@ private String execute(String toolInput, ToolContext toolContext) { SqlGuardCheckResult result = switch (action) { case "DATA_PROFILE" -> sqlVerifyExplainService.inspectProfile(agentId, request); case "SQL_VERIFY" -> sqlVerifyExplainService.explain(request); - default -> throw new IllegalArgumentException("Unsupported sql_guard.check action: " + action); + default -> throw new IllegalArgumentException("不支持的 sql_guard.check 动作:" + action); }; return objectMapper.writeValueAsString(result); } catch (Exception ex) { - throw new IllegalStateException("Failed to execute sql_guard.check: " + ex.getMessage(), ex); + throw new IllegalStateException("sql_guard.check 执行失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java index 35e24f1dc..62dd32805 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -206,7 +206,7 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest request) { String tableName = StringUtils.trimToEmpty(request == null ? null : request.getTableName()); if (StringUtils.isBlank(tableName)) { - throw new IllegalArgumentException("sql_guard.check with action=DATA_PROFILE requires tableName"); + throw new IllegalArgumentException("sql_guard.check 在 action=DATA_PROFILE 时必须提供 tableName"); } ProfileContext context = resolveProfileContext(agentId); String actualTableName = resolveVisibleTableName(context, tableName); @@ -214,7 +214,7 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r List visibleColumns = applyVisibleColumnRestrictions(context, actualTableName, availableColumns); if (visibleColumns.isEmpty()) { throw new IllegalArgumentException( - "Table '%s' has no visible columns for current agent".formatted(actualTableName)); + "表 '%s' 在当前 Agent 下没有可见字段".formatted(actualTableName)); } List columnsToInspect = resolveColumnsToInspect(request, actualTableName, visibleColumns); int sampleLimit = normalizeProfileLimit(request == null ? null : request.getLimit()); @@ -223,7 +223,7 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r List> columnProfiles = columnsToInspect.stream() .map(column -> buildColumnProfile(context, actualTableName, column, totalRows, sampleLimit)) .toList(); - String summary = "Profiled %d columns from '%s' using visible columns only.".formatted(columnProfiles.size(), + String summary = "仅基于可见字段对表 '%s' 的 %d 个字段完成 profile 分析。".formatted(columnProfiles.size(), actualTableName); return SqlGuardCheckResult.builder() .action(ACTION_DATA_PROFILE) @@ -235,21 +235,21 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r .usedTables(List.of(actualTableName)) .columnProfiles(columnProfiles) .fixSuggestions( - List.of("Use categorical fields with concentrated topValues as filters or GROUP BY candidates.", - "Use numeric/date fields with min/max ranges as metric, trend, or time-window candidates.")) + List.of("可优先把高频值集中的分类字段用作过滤条件或 GROUP BY 候选字段。", + "可优先把具备 min/max 范围的数值或时间字段用作指标、趋势或时间窗口候选字段。")) .build(); } private ProfileContext resolveProfileContext(String agentId) { if (!StringUtils.isNumeric(agentId)) { - throw new IllegalArgumentException("sql_guard.check DATA_PROFILE only supports numeric agentId"); + throw new IllegalArgumentException("sql_guard.check 的 DATA_PROFILE 当前仅支持数值型 agentId"); } Long numericAgentId = Long.valueOf(agentId); AgentDatasource agentDatasource = agentDatasourceService.getCurrentAgentDatasource(numericAgentId); Datasource datasource = agentDatasource.getDatasource() != null ? agentDatasource.getDatasource() : datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); if (datasource == null) { - throw new IllegalStateException("Active datasource not found for agent " + agentId); + throw new IllegalStateException("当前 Agent 未找到活动数据源:" + agentId); } DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); Accessor accessor = accessorFactory.getAccessorByDbConfig(dbConfig); @@ -260,7 +260,7 @@ private ProfileContext resolveProfileContext(String agentId) { : explicitSelectedTables; } catch (Exception ex) { - throw new IllegalStateException("Failed to load visible tables for datasource %s: %s" + throw new IllegalStateException("加载数据源 %s 的可见表失败:%s" .formatted(datasource.getId(), ex.getMessage()), ex); } Map> visibleTablesByName = indexTables(visibleTables, false); @@ -289,7 +289,7 @@ private List loadTableColumns(ProfileContext context, String table } catch (Exception ex) { throw new IllegalStateException( - "Failed to load columns for table '%s': %s".formatted(tableName, ex.getMessage()), ex); + "加载表 '%s' 的字段失败:%s".formatted(tableName, ex.getMessage()), ex); } } @@ -321,7 +321,7 @@ private List resolveColumnsToInspect(SqlGuardCheckRequest request, for (String requestedColumn : requestedColumns) { ColumnInfoBO column = columnsByName.get(normalizeColumnName(requestedColumn)); if (column == null) { - throw new IllegalArgumentException("Column '%s' is not visible in table '%s' for current agent" + throw new IllegalArgumentException("字段 '%s' 在表 '%s' 中对当前 Agent 不可见" .formatted(requestedColumn, tableName)); } resolvedColumns.add(column); @@ -411,16 +411,16 @@ private List buildProfileHints(ColumnInfoBO column, Double nullRatio, Lo List> topValues) { List hints = new ArrayList<>(); if (Boolean.TRUE.equals(isLikelyCategorical(column, distinctCount, totalRows, topValues))) { - hints.add("Likely categorical field; suitable for filter or GROUP BY."); + hints.add("该字段很可能是枚举或分类字段,适合用于过滤条件或 GROUP BY。"); } if (supportsMinMax(column)) { - hints.add("Likely ordered field; suitable for range filter, metric, or trend axis."); + hints.add("该字段很可能具备顺序语义,适合用于范围过滤、指标计算或趋势轴。"); } if (nullRatio != null && nullRatio >= 0.5D) { - hints.add("High null ratio; be careful when using it as a hard filter."); + hints.add("该字段空值比例较高,作为强过滤条件时需要谨慎。"); } if (hints.isEmpty()) { - hints.add("Inspect samples and top values before deciding whether to use it in SQL."); + hints.add("请先结合样例值和高频值判断,再决定是否将该字段写入 SQL。"); } return hints; } @@ -478,7 +478,7 @@ private ResultSetBO executeSql(ProfileContext context, String sql) { return resultSet; } catch (Exception ex) { - throw new IllegalStateException("Failed to execute profile SQL: " + ex.getMessage(), ex); + throw new IllegalStateException("执行 profile SQL 失败:" + ex.getMessage(), ex); } } @@ -555,7 +555,7 @@ private long parseLong(String value) { private String resolveVisibleTableName(ProfileContext context, String tableName) { return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) .orElseThrow( - () -> new IllegalArgumentException("Table '%s' is not visible for current agent. Visible tables: %s" + () -> new IllegalArgumentException("表 '%s' 对当前 Agent 不可见。当前可见表:%s" .formatted(tableName, String.join(", ", context.visibleTables())))); } @@ -567,7 +567,7 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(exactMatches.get(0)); } if (exactMatches.size() > 1) { - throw new IllegalArgumentException("Table '%s' maps to multiple visible tables: %s".formatted(tableName, + throw new IllegalArgumentException("表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); } if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { @@ -578,7 +578,7 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(leafMatches.get(0)); } if (leafMatches.size() > 1) { - throw new IllegalArgumentException("Table '%s' is ambiguous across visible tables: %s".formatted(tableName, + throw new IllegalArgumentException("表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); } return Optional.empty(); From 16fa09bdaa5f1f71d759bf4eb3ff17eea447055a Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Tue, 28 Apr 2026 20:06:02 +0800 Subject: [PATCH 15/22] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f52b4e24d..6c7520620 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ data-agent-management/vectorstore/* # spring-ai-alibaba source spring-ai-alibaba-1.1.0.0/ .spec-workflow +.claude \ No newline at end of file From 9a2d7e7c69324813502ce25a8b6a66dfb232fd38 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Tue, 28 Apr 2026 22:21:57 +0800 Subject: [PATCH 16/22] feat: add harness tools --- .../impl/AiAgentRuntimeServiceImpl.java | 4 - .../dataagent/agentscope/tool/ToolError.java | 41 +++++ .../agentscope/tool/ToolErrorCode.java | 32 ++++ .../datasource/DatasourceExplorerAction.java | 2 +- .../datasource/DatasourceExplorerRequest.java | 3 - .../datasource/DatasourceExplorerResult.java | 21 +-- .../datasource/DatasourceExplorerService.java | 153 +----------------- .../DatasourceExplorerToolProvider.java | 79 +++++---- .../DomainBusinessKnowledgeToolProvider.java | 2 +- .../DomainBusinessKnowledgeToolSupport.java | 30 +++- .../semantic/SemanticModelSearchResult.java | 2 +- .../semantic/SemanticModelSearchService.java | 22 ++- .../semantic/SemanticModelToolProvider.java | 6 +- .../semantic/SemanticModelToolSupport.java | 28 +++- .../tool/sqlguard/SqlGuardCheckRequest.java | 7 - .../tool/sqlguard/SqlGuardCheckResult.java | 21 +-- .../tool/sqlguard/SqlGuardToolProvider.java | 57 ++++--- .../sqlguard/SqlVerifyExplainService.java | 60 +------ .../AnswerTraceExplainStore.java | 78 ++------- .../DomainKnowledgeSearchService.java | 8 +- .../DomainKnowledgeSearchServiceImpl.java | 10 +- .../src/main/resources/prompts/commonagent.md | 2 +- 22 files changed, 256 insertions(+), 412 deletions(-) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 879d52b24..e7577f445 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -337,10 +337,6 @@ private void mirrorExplainSummary(Span rootSpan, GraphRequest request) { if (StringUtils.hasText(summary.getDatasource())) { rootSpan.setAttribute("dataagent.answer.explain.datasource", summary.getDatasource()); } - if (summary.getUsedTables() != null && !summary.getUsedTables().isEmpty()) { - rootSpan.setAttribute("dataagent.answer.explain.used_tables", - String.join(",", summary.getUsedTables())); - } }); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java new file mode 100644 index 000000000..de59baf91 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java @@ -0,0 +1,41 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Builder; +import lombok.Value; + +@Value +@Builder +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolError { + + ToolErrorCode code; + + String message; + + Boolean retryable; + + public static ToolError of(ToolErrorCode code, String message) { + return ToolError.builder().code(code).message(message).build(); + } + + public static ToolError retryable(ToolErrorCode code, String message) { + return ToolError.builder().code(code).message(message).retryable(true).build(); + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java new file mode 100644 index 000000000..5e732b4f5 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool; + +public enum ToolErrorCode { + + INVALID_INPUT, + + UNSUPPORTED_ACTION, + + DATASOURCE_UNAVAILABLE, + + TABLE_NOT_VISIBLE, + + COLUMN_NOT_VISIBLE, + + EXECUTION_FAILED + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java index 18c473a9d..f338c1c3b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java @@ -40,7 +40,7 @@ public static DatasourceExplorerAction fromValue(String value) { return Arrays.stream(values()) .filter(action -> action.name().equalsIgnoreCase(value)) .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Unsupported datasource explorer action: " + value)); + .orElseThrow(() -> new IllegalArgumentException("不支持的数据源探索动作:" + value)); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java index 65ca23a79..eac1b7e4d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java @@ -15,7 +15,6 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.datasource; -import java.util.List; import lombok.Data; @Data @@ -27,8 +26,6 @@ public class DatasourceExplorerRequest { private String tableName; - private List tableNames; - private String sql; private Integer limit; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java index 1052386dd..b12857581 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java @@ -31,6 +31,8 @@ public class DatasourceExplorerResult { private String summary; + private Boolean searchReady; + @Builder.Default private List> tables = new ArrayList<>(); @@ -45,23 +47,4 @@ public class DatasourceExplorerResult { private String sql; - private String sqlExplanation; - - @Builder.Default - private List usedTables = new ArrayList<>(); - - @Builder.Default - private List usedColumns = new ArrayList<>(); - - @Builder.Default - private Map permissions = new java.util.LinkedHashMap<>(); - - @Builder.Default - private Map stats = new java.util.LinkedHashMap<>(); - - @Builder.Default - private List nextSuggestedActions = new ArrayList<>(); - - private boolean truncated; - } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index c4ad3d9aa..b8c0f7b1b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -35,7 +35,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; @@ -146,11 +145,7 @@ private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceE .toList(); return capture(baseResult(context, DatasourceExplorerAction.LIST_TABLES, "共发现 %d 张可见表".formatted(tables.size())) .tables(tables) - .usedTables(tables.stream().map(table -> String.valueOf(table.get("name"))).toList()) - .permissions(buildPermissionSummary(context, context.visibleTables())) - .stats(buildStatsSummary(0, limit, context.visibleTables().size() > limit, "not_collected")) - .nextSuggestedActions(List.of("get_table_schema", "find_tables", "search")) - .truncated(context.visibleTables().size() > limit) + .searchReady(!tables.isEmpty()) .build(), graphRequest); } @@ -169,12 +164,7 @@ private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceE String summary = query.isEmpty() ? "未提供筛选词,返回当前可见表列表" : "针对关键词“%s”匹配到 %d 张表".formatted(request.getQuery(), matchedTables.size()); return capture(baseResult(context, DatasourceExplorerAction.FIND_TABLES, summary).tables(matchedTables) - .usedTables(matchedTables.stream().map(table -> String.valueOf(table.get("name"))).toList()) - .permissions(buildPermissionSummary(context, - matchedTables.stream().map(table -> String.valueOf(table.get("name"))).toList())) - .stats(buildStatsSummary(0, limit, matchedTables.size() >= limit, "not_collected")) - .nextSuggestedActions(List.of("get_table_schema", "get_related_tables", "search")) - .truncated(matchedTables.size() >= limit) + .searchReady(!matchedTables.isEmpty()) .build(), graphRequest); } @@ -199,11 +189,7 @@ private DatasourceExplorerResult getTableSchema(ExplorerContext context, Datasou .tables(List.of(tableEntry)) .columns(columnEntries) .relations(relationEntries) - .usedTables(List.of(tableName)) - .usedColumns(listColumnNames(columnEntries)) - .permissions(buildPermissionSummary(context, List.of(tableName))) - .stats(buildStatsSummary(0, null, false, "not_collected")) - .nextSuggestedActions(List.of("search", "get_related_tables", "get_table_schema")) + .searchReady(true) .build(), graphRequest); } @@ -227,10 +213,7 @@ private DatasourceExplorerResult getRelatedTables(ExplorerContext context, Datas "表“%s”共找到 %d 张关联表".formatted(tableName, tableEntries.size())) .tables(tableEntries) .relations(relationEntries) - .usedTables(tableEntries.stream().map(table -> String.valueOf(table.get("name"))).toList()) - .permissions(buildPermissionSummary(context, relatedTables)) - .stats(buildStatsSummary(0, null, false, "not_collected")) - .nextSuggestedActions(List.of("get_table_schema", "search", "get_related_tables")) + .searchReady(!relationEntries.isEmpty()) .build(), graphRequest); } @@ -248,15 +231,7 @@ private DatasourceExplorerResult previewRows(ExplorerContext context, Datasource .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) .sql(sql) - .sqlExplanation(buildSqlExplanation(DatasourceExplorerAction.PREVIEW_ROWS, sql, List.of(tableName), - resultSet.getColumn(), limit, resultSet.getData().size(), resultSet.getData().size() >= limit)) - .usedTables(List.of(tableName)) - .usedColumns(resultSet.getColumn()) - .permissions(buildPermissionSummary(context, List.of(tableName))) - .stats(buildStatsSummary(resultSet.getData().size(), limit, resultSet.getData().size() >= limit, - "not_collected")) - .nextSuggestedActions(List.of("get_table_schema", "search")) - .truncated(resultSet.getData().size() >= limit) + .searchReady(true) .build(), graphRequest); } @@ -274,17 +249,7 @@ private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplo .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) .sql(guardedQuery.sql()) - .sqlExplanation(buildSqlExplanation(DatasourceExplorerAction.SEARCH, guardedQuery.sql(), - guardedQuery.referencedTables().stream().toList(), resultSet.getColumn(), limit, - resultSet.getData().size(), resultSet.getData().size() >= limit)) - .usedTables(guardedQuery.referencedTables().stream().toList()) - .usedColumns(guardedQuery.allowedResultHeaders() == null || guardedQuery.allowedResultHeaders().isEmpty() - ? resultSet.getColumn() : guardedQuery.allowedResultHeaders().stream().toList()) - .permissions(buildPermissionSummary(context, guardedQuery.referencedTables())) - .stats(buildStatsSummary(resultSet.getData().size(), limit, resultSet.getData().size() >= limit, - "not_collected")) - .nextSuggestedActions(List.of("get_table_schema", "find_tables", "search")) - .truncated(resultSet.getData().size() >= limit) + .searchReady(true) .build(), graphRequest); } @@ -428,9 +393,6 @@ private String requireSingleTableName(DatasourceExplorerRequest request) { if (StringUtils.isNotBlank(request.getTableName())) { return request.getTableName().trim(); } - if (request.getTableNames() != null && request.getTableNames().size() == 1) { - return StringUtils.trimToEmpty(request.getTableNames().get(0)); - } throw new IllegalArgumentException("当前 action 必须提供 tableName"); } @@ -912,109 +874,6 @@ private DatasourceExplorerResult capture(DatasourceExplorerResult result, @Nulla return result; } - private Map buildPermissionSummary(ExplorerContext context, Collection relevantTables) { - Map permissions = new LinkedHashMap<>(); - permissions.put("tableScopeMode", - context.explicitSelectedTables().isEmpty() ? "datasource_visible_tables" : "agent_selected_tables"); - if (!context.explicitSelectedTables().isEmpty()) { - permissions.put("selectedTables", context.explicitSelectedTables()); - } - List resolvedTables = Optional.ofNullable(relevantTables) - .orElse(List.of()) - .stream() - .filter(StringUtils::isNotBlank) - .map(String::trim) - .distinct() - .toList(); - if (!resolvedTables.isEmpty()) { - permissions.put("relevantTables", resolvedTables); - } - List> columnRestrictions = resolvedTables.stream().map(tableName -> { - List visibleColumns = context.visibleColumnsByTable().get(normalizeTableName(tableName)); - if (visibleColumns == null) { - return null; - } - Map restriction = new LinkedHashMap<>(); - restriction.put("tableName", tableName); - restriction.put("allowedColumns", visibleColumns); - return restriction; - }).filter(Objects::nonNull).toList(); - if (!columnRestrictions.isEmpty()) { - permissions.put("columnRestrictions", columnRestrictions); - } - return permissions; - } - - private Map buildStatsSummary(int returnedRows, Integer limit, boolean truncated, - String scanStatus) { - Map stats = new LinkedHashMap<>(); - stats.put("returnedRows", returnedRows); - if (limit != null) { - stats.put("limitApplied", limit); - } - stats.put("truncated", truncated); - stats.put("rowScanStatus", scanStatus); - stats.put("rowsScanned", null); - return stats; - } - - private String buildSqlExplanation(DatasourceExplorerAction action, String sql, Collection usedTables, - Collection usedColumns, Integer limit, int returnedRows, boolean truncated) { - if (!StringUtils.isNotBlank(sql)) { - return null; - } - List clauses = new ArrayList<>(); - List tables = Optional.ofNullable(usedTables) - .orElse(List.of()) - .stream() - .filter(StringUtils::isNotBlank) - .toList(); - if (!tables.isEmpty()) { - clauses.add("查询的数据来源是 " + String.join("、", tables)); - } - if (WHERE_PATTERN.matcher(sql).find()) { - clauses.add("SQL 带有 WHERE 条件,会先筛选符合条件的数据"); - } - if (GROUP_BY_PATTERN.matcher(sql).find()) { - clauses.add("SQL 带有 GROUP BY,会按维度分组后再做统计"); - } - if (ORDER_BY_PATTERN.matcher(sql).find()) { - clauses.add("SQL 带有 ORDER BY,会对结果做排序"); - } - if (hasLimit(sql) || limit != null) { - clauses.add("结果条数受 LIMIT/TOP 之类的上限控制"); - } - List columns = Optional.ofNullable(usedColumns) - .orElse(List.of()) - .stream() - .filter(StringUtils::isNotBlank) - .limit(8) - .toList(); - if (!columns.isEmpty()) { - clauses.add("最终展示的字段主要是 " + String.join("、", columns)); - } - clauses.add("本次返回了 %d 行结果".formatted(returnedRows)); - if (truncated) { - clauses.add("由于命中了返回上限,当前看到的可能只是部分结果"); - } - if (action == DatasourceExplorerAction.PREVIEW_ROWS) { - clauses.add("这条 SQL 主要用于预览样例数据,不是正式分析统计"); - } - return String.join("。", clauses) + "。"; - } - - private List listColumnNames(List> columnEntries) { - if (columnEntries == null) { - return List.of(); - } - return columnEntries.stream() - .map(entry -> entry.get("name")) - .filter(Objects::nonNull) - .map(String::valueOf) - .filter(StringUtils::isNotBlank) - .toList(); - } - private record ExplorerContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, Accessor accessor, List visibleTables, Set visibleTableNameSet, Map> visibleTablesByName, Map> visibleTablesByLeafName, diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index 10e1dedcb..5392af374 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -18,6 +18,8 @@ import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.Datasource; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; @@ -50,35 +52,26 @@ public class DatasourceExplorerToolProvider implements AgentScopedToolProvider { "PREVIEW_ROWS", "SEARCH" ], - "description": "探索动作。默认顺序:LIST_TABLES/FIND_TABLES -> GET_TABLE_SCHEMA -> GET_RELATED_TABLES -> SEARCH。PREVIEW_ROWS 仅在用户明确要求看样例行,或关键字段语义仍不确定且会影响 SQL 写法时才使用" + "description": "探索动作。" }, "query": { "type": "string", - "description": "用于 FIND_TABLES 的检索关键词" + "description": "FIND_TABLES 时必填,用于按关键词查找表。" }, "tableName": { "type": "string", - "description": "目标表名。GET_TABLE_SCHEMA / GET_RELATED_TABLES / PREVIEW_ROWS 必填。不要把 PREVIEW_ROWS 当作默认前置动作" - }, - "tableNames": { - "type": "array", - "items": { - "type": "string" - }, - "description": "可选表名列表。当前版本主要使用单表" + "description": "GET_TABLE_SCHEMA、GET_RELATED_TABLES、PREVIEW_ROWS 时必填的目标表名。" }, "sql": { "type": "string", - "description": "SEARCH 动作需要的只读 SQL。仅允许 SELECT/WITH" + "description": "SEARCH 时必填的只读 SQL,仅允许 SELECT/WITH。" }, "limit": { "type": "integer", - "description": "返回行数上限,默认 20,最大 200" + "description": "可选返回上限,默认 20,最大 200。" } }, - "required": [ - "action" - ] + "required": ["action"] } """; @@ -103,7 +96,7 @@ public Map getToolCallbacks(String agentId) { .inputSchema(INPUT_SCHEMA) .build(); return Map.of(toolName, new AgentBoundDatasourceExplorerToolCallback(agentId, toolDefinition, - datasourceExplorerService, objectMapper)); + datasourceExplorerService, objectMapper)); } private AgentDatasource resolveActiveDatasource(String agentId) { @@ -135,18 +128,14 @@ private String buildDescription(Datasource datasource, AgentDatasource agentData .formatted(selectedTables.size(), String.join(", ", selectedTables.stream().limit(8).toList())); return """ 数据源'%s'(%s)的统一探索工具。 - 可用于查看表列表、查看表结构、查看统一关系、按需预览样例数据,以及执行只读 SQL 查询。 + 可用于查看表列表、查找表、查看单表结构、查看关系、按需预览样例数据,以及执行只读 SQL 查询。 约束说明: 1. 只能访问当前 Agent 的活动数据源。 2. SEARCH 仅允许执行只读 SQL。 - 3. GET_TABLE_SCHEMA 和 GET_RELATED_TABLES 返回的 relations 字段,会合并数据库物理外键与已配置的逻辑关系。 - 4. 做表关系推断和 Join 规划时,应优先使用 relations 字段。 - 5. 表元数据里的 foreignKeys 字段仅为兼容保留,Agent 推理时优先使用 relations。 - 6. 默认调用顺序:LIST_TABLES/FIND_TABLES -> GET_TABLE_SCHEMA -> GET_RELATED_TABLES -> SEARCH。 - 7. PREVIEW_ROWS 只是条件动作,不是默认前置步骤。只有在用户明确要求看样例行,或 schema、列名和现有语义仍不足以判断关键字段实际语义,且这种不确定性会影响 SQL 写法时,才调用 PREVIEW_ROWS。 - 8. 如果 schema、relations、列名和已知语义已经足够支持写 SQL,就直接进入 SQL_VERIFY/SEARCH,不要为了“先确认数据质量”再额外预览样例。 - 9. 不要根据可见值推断隐藏字段。例如不要从邮箱前缀、ID、编码或别名推断用户名或真实姓名。 - 10. %s + 3. 如果只需要定位表,优先使用 LIST_TABLES 或 FIND_TABLES。 + 4. 如果需要写 SQL,先获取表结构和关系,再决定是否执行 SEARCH。 + 5. PREVIEW_ROWS 不是默认前置动作,只有样例值会实质影响 SQL 写法时才使用。 + 6. %s """ .formatted(datasource.getName(), datasource.getType(), visibleTables); } @@ -183,12 +172,46 @@ public String call(String toolInput) { public String call(String toolInput, ToolContext toolContext) { try { DatasourceExplorerRequest request = objectMapper.readValue(toolInput, DatasourceExplorerRequest.class); + validateRequest(request); GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper - .writeValueAsString(datasourceExplorerService.execute(agentId, request, graphRequest)); + return objectMapper.writeValueAsString(datasourceExplorerService.execute(agentId, request, graphRequest)); + } + catch (Exception ex) { + throw new IllegalStateException( + objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "数据源探索工具执行失败:" + ex.getMessage())), + ex); + } + } + + private void validateRequest(DatasourceExplorerRequest request) { + if (request == null || request.getAction() == null) { + throw new IllegalArgumentException( + objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, "数据源探索工具需要 action 参数"))); + } + switch (request.getAction()) { + case FIND_TABLES -> requireText(request.getQuery(), "FIND_TABLES 需要 query 参数"); + case GET_TABLE_SCHEMA, GET_RELATED_TABLES, PREVIEW_ROWS -> + requireText(request.getTableName(), request.getAction().name() + " 需要 tableName 参数"); + case SEARCH -> requireText(request.getSql(), "SEARCH 需要 sql 参数"); + case LIST_TABLES -> { + } + default -> throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.UNSUPPORTED_ACTION, "不支持的数据源探索动作:" + request.getAction()))); + } + } + + private void requireText(String value, String message) { + if (!StringUtils.isNotBlank(value)) { + throw new IllegalArgumentException(objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, message))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); } catch (Exception ex) { - throw new IllegalStateException("数据源探索工具执行失败:" + ex.getMessage(), ex); + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java index 981f85b56..7c2cc99bc 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java @@ -32,7 +32,7 @@ public class DomainBusinessKnowledgeToolProvider implements AgentScopedToolProvi 按需检索当前 Agent 已召回的业务术语、FAQ、问答条目和嵌入文档。 当回答依赖内部业务定义、指标口径、SOP、历史案例或领域术语澄清时,才使用本工具。 只有当答案确实依赖领域知识,而不是通用推理或数据库物理结构本身时,才调用本工具。 - 不要把本工具用于数据库表名、列名、字段类型、枚举值、schema 关系、字段注释或其他表结构解释问题;这些问题应先交给 datasource explorer,如仍需补充语义,再考虑 `semantic_model.search`。 + 不要把本工具用于数据库表名、列名、字段类型、枚举值、表关系、字段注释或其他表结构解释问题;这些问题应先交给数据源探索工具,如仍需补充语义,再考虑 `semantic_model.search`。 """; private final DomainBusinessKnowledgeToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index e10d8dca0..431161fd4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -17,6 +17,8 @@ import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchRequest; import com.fasterxml.jackson.databind.JsonNode; @@ -110,6 +112,7 @@ public String call(String toolInput, ToolContext toolContext) { JsonNode jsonNode = StringUtils.hasText(toolInput) ? objectMapper.readTree(toolInput) : objectMapper.createObjectNode(); String query = jsonNode.path("query").asText(""); + validateQuery(query); List knowledgeTypes = new ArrayList<>(); JsonNode knowledgeTypesNode = jsonNode.path("knowledgeTypes"); if (knowledgeTypesNode.isArray()) { @@ -122,17 +125,34 @@ public String call(String toolInput, ToolContext toolContext) { Integer topK = jsonNode.has("topK") && jsonNode.get("topK").canConvertToInt() ? jsonNode.get("topK").asInt() : null; Double similarityThreshold = jsonNode.has("similarityThreshold") - && jsonNode.get("similarityThreshold").isNumber() - ? jsonNode.get("similarityThreshold").asDouble() : null; + && jsonNode.get("similarityThreshold").isNumber() ? jsonNode.get("similarityThreshold").asDouble() + : null; DomainKnowledgeSearchRequest request = new DomainKnowledgeSearchRequest(query, knowledgeTypes.isEmpty() ? null : List.copyOf(knowledgeTypes), topK, similarityThreshold); GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper - .writeValueAsString(domainKnowledgeSearchService.search(agentId, request, graphRequest)); + return objectMapper.writeValueAsString(domainKnowledgeSearchService.search(agentId, request, graphRequest)); } catch (Exception ex) { - throw new IllegalStateException("领域业务知识检索失败:" + ex.getMessage(), ex); + throw new IllegalStateException(objectToJson( + ToolError.of(ToolErrorCode.EXECUTION_FAILED, "domain_business_knowledge.search 执行失败:" + ex.getMessage())), + ex); + } + } + + private void validateQuery(String query) { + if (!StringUtils.hasText(query)) { + throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.INVALID_INPUT, "domain_business_knowledge.search 需要 query 参数"))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java index c1212d3ac..4624590b4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java @@ -24,7 +24,7 @@ @Builder public class SemanticModelSearchResult { - private String query; + private String resolution; private String summary; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java index d81f7b6fe..6cd2efd68 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java @@ -93,7 +93,7 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest activeDatasource.getDatasourceId(), scope.getTableNames()); if (CollectionUtils.isEmpty(candidates)) { return emptyResult(query, - "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理 schema 请改用 datasource explorer 查看。"); + "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理表结构请改用数据源探索工具查看。"); } List scoredHits = candidates.stream() @@ -112,29 +112,25 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest if (scoredHits.isEmpty()) { return emptyResult(query, - "没有匹配到补充语义提示;如果 datasource explorer 已能回答 schema 问题,就不要额外调用 semantic_model.search。"); + "没有匹配到补充语义提示;如果数据源探索工具已能回答物理表结构问题,就不要额外调用 semantic_model.search。"); } List hits = scoredHits.stream().map(this::toHit).toList(); + String summary = "共匹配到 %d 条补充语义提示。这些结果只用于补充理解表和字段语义,不能替代数据源探索工具的物理结构探索。" + .formatted(hits.size()); if (graphRequest != null) { answerTraceExplainStore.recordSemanticSearch(graphRequest, query, - "Found %d supplemental semantic hints".formatted(hits.size()), hits); + "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); } else { answerTraceExplainStore.recordSemanticSearch(query, - "Found %d supplemental semantic hints".formatted(hits.size()), hits); - } - return SemanticModelSearchResult.builder() - .query(query) - .summary( - "Found %d supplemental semantic hints. These are auxiliary explanations for table/column understanding, not a replacement for datasource exploration." - .formatted(hits.size())) - .hits(hits) - .build(); + "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); + } + return SemanticModelSearchResult.builder().summary(summary).hits(hits).resolution("matched").build(); } private SemanticModelSearchResult emptyResult(String query, String summary) { - return SemanticModelSearchResult.builder().query(query).summary(summary).build(); + return SemanticModelSearchResult.builder().summary(summary).resolution("no_match").build(); } private AgentDatasource resolveActiveDatasource(Long agentId) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java index dfcfdb0f9..dbd85e1d0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java @@ -30,10 +30,10 @@ public class SemanticModelToolProvider implements AgentScopedToolProvider { private static final String DESCRIPTION = """ 仅用于补充理解表和字段语义的辅助工具。 - 当用户在询问某张表或某个字段的含义、业务友好名称、枚举含义、字段使用备注,或数据库物理 schema 中未显式存储的关系提示时,才使用本工具。 + 当用户在询问某张表或某个字段的含义、业务友好名称、枚举含义、字段使用备注,或数据库物理表结构中未显式存储的关系提示时,才使用本工具。 典型问题包括:“token 名称类型”“status 字段什么意思”“这个字段有哪些别名”“这两个表可能怎么关联”。 - 数据库里的物理 schema、字段列表、字段类型、样例预览和只读 SQL,应优先使用 datasource explorer 获取。 - 不要把本工具用于 SQL 执行、datasource explorer 已能覆盖的 schema 探索,或属于 `domain_business_knowledge.search` 的业务定义、指标口径和 SOP 检索。 + 数据库里的物理表结构、字段列表、字段类型、样例预览和只读 SQL,应优先使用数据源探索工具获取。 + 不要把本工具用于 SQL 执行、数据源探索工具已能覆盖的表结构探索,或属于 `domain_business_knowledge.search` 的业务定义、指标口径和 SOP 检索。 """; private final SemanticModelToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java index ccc4a8eb5..e1eda9b31 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java @@ -17,6 +17,8 @@ import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import org.springframework.ai.chat.model.ToolContext; @@ -39,7 +41,7 @@ public class SemanticModelToolSupport { }, "tableNames": { "type": "array", - "description": "可选。将检索范围限制在这些表内;如果 datasource explorer 已能定位表结构,则不必传该工具。", + "description": "可选。将检索范围限制在这些表内;如果数据源探索工具已能定位表结构,则不必传该工具。", "items": { "type": "string" } @@ -96,12 +98,30 @@ public String call(String toolInput, ToolContext toolContext) { SemanticModelSearchRequest request = StringUtils.hasText(toolInput) ? objectMapper.readValue(toolInput, SemanticModelSearchRequest.class) : new SemanticModelSearchRequest(); + validateRequest(request); GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper - .writeValueAsString(semanticModelSearchService.search(agentId, request, graphRequest)); + return objectMapper.writeValueAsString(semanticModelSearchService.search(agentId, request, graphRequest)); } catch (Exception ex) { - throw new IllegalStateException("语义模型提示检索失败:" + ex.getMessage(), ex); + throw new IllegalStateException( + objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "semantic_model.search 执行失败:" + ex.getMessage())), + ex); + } + } + + private void validateRequest(SemanticModelSearchRequest request) { + if (request == null || !StringUtils.hasText(request.getQuery())) { + throw new IllegalArgumentException( + objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, "semantic_model.search 需要 query 参数"))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java index bb1d2a156..1b508dd97 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java @@ -15,7 +15,6 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; -import com.fasterxml.jackson.databind.JsonNode; import java.util.List; import lombok.Data; import org.apache.commons.lang3.StringUtils; @@ -37,12 +36,6 @@ class SqlGuardCheckRequest { private Integer limit; - private JsonNode tableSchemas; - - private JsonNode semanticHits; - - private JsonNode businessKnowledgeHits; - String normalizedAction() { return StringUtils.defaultIfBlank(action, "SQL_VERIFY").trim().toUpperCase(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java index 7f25716e2..81e6e868f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java @@ -17,49 +17,34 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import lombok.Builder; -import lombok.Data; - import java.util.ArrayList; import java.util.List; import java.util.Map; +import lombok.Builder; +import lombok.Data; @Data @Builder @JsonInclude(JsonInclude.Include.NON_NULL) class SqlGuardCheckResult { - private String action; - - private String query; - - private String sql; + private String decision; private String tableName; private String summary; - private String explainedIntent; - @JsonProperty("isAligned") private Boolean isAligned; private Long totalRows; - private Integer inspectedColumnCount; - @Builder.Default private List problems = new ArrayList<>(); @Builder.Default private List fixSuggestions = new ArrayList<>(); - @Builder.Default - private List usedTables = new ArrayList<>(); - - @Builder.Default - private List usedMetrics = new ArrayList<>(); - @Builder.Default private List ruleChecks = new ArrayList<>(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java index e293375a1..4ba1b3c2b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -18,6 +18,8 @@ import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Map; import org.springframework.ai.chat.model.ToolContext; @@ -38,7 +40,7 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { "action": { "type": "string", "enum": ["SQL_VERIFY", "DATA_PROFILE"], - "description": "可选。默认 SQL_VERIFY。SQL_VERIFY 用于校验候选 SQL 是否真正符合用户意图;DATA_PROFILE 仅用于在少量关键候选字段语义仍不明确,且这种不确定性会影响过滤、分组、排序、时间窗口或指标写法时,补充查看字段值分布。" + "description": "可选。默认 SQL_VERIFY。" }, "query": { "type": "string", @@ -46,7 +48,7 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { }, "sql": { "type": "string", - "description": "SQL_VERIFY 时必填。当前准备执行或准备返回给用户的候选 SQL。" + "description": "SQL_VERIFY 时必填。待校验 SQL。" }, "tableName": { "type": "string", @@ -57,23 +59,11 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { "items": { "type": "string" }, - "description": "DATA_PROFILE 时可选。优先只传需要诊断的少量关键字段;不传时默认取该表前几个可见字段。" + "description": "DATA_PROFILE 时可选。优先只传少量关键字段。" }, "limit": { "type": "integer", - "description": "DATA_PROFILE 时可选。样例值和高频值的返回上限,默认 5,最大 20。" - }, - "tableSchemas": { - "type": "object", - "description": "可选。把 datasource explorer 的 schema 结果原样传入,帮助 SQL 校验识别时间列、维度列与表关系。" - }, - "semanticHits": { - "type": "object", - "description": "可选。把 semantic_model.search 的结果原样传入。" - }, - "businessKnowledgeHits": { - "type": "object", - "description": "可选。把 domain_business_knowledge.search 的结果原样传入。" + "description": "DATA_PROFILE 时可选。返回上限,默认 5,最大 20。" } } } @@ -87,7 +77,6 @@ public class SqlGuardToolProvider implements AgentScopedToolProvider { 4. 使用 DATA_PROFILE 时,优先传少量关键 `columnNames`,不要对整张表做无差别 profile。 5. 如果 SQL_VERIFY 返回 `isAligned=false`,请读取 `problems`、`ruleChecks` 和 `fixSuggestions`,自行改写 SQL 后再次调用 `sql_guard.check`。 6. 如果使用 DATA_PROFILE,请重点读取返回的 `columnProfiles`,理解空值率、去重计数、高频值、样例值,以及字段更像枚举、数值还是时间字段。 - 7. 每次调用都要传当前动作需要的最新顶层参数,不要把上一轮 `sql_guard.check` 的输出对象原样回传给工具。 """; private final ObjectMapper objectMapper; @@ -149,15 +138,36 @@ private String execute(String toolInput, ToolContext toolContext) { ? objectMapper.readValue(toolInput, SqlGuardCheckRequest.class) : new SqlGuardCheckRequest(); enrichRequestFromToolContext(request, toolContext); String action = request.normalizedAction(); + validateRequest(request, action); SqlGuardCheckResult result = switch (action) { case "DATA_PROFILE" -> sqlVerifyExplainService.inspectProfile(agentId, request); case "SQL_VERIFY" -> sqlVerifyExplainService.explain(request); - default -> throw new IllegalArgumentException("不支持的 sql_guard.check 动作:" + action); + default -> throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.UNSUPPORTED_ACTION, "不支持的 sql_guard.check 动作:" + action))); }; return objectMapper.writeValueAsString(result); } catch (Exception ex) { - throw new IllegalStateException("sql_guard.check 执行失败:" + ex.getMessage(), ex); + throw new IllegalStateException( + objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "sql_guard.check 执行失败:" + ex.getMessage())), + ex); + } + } + + private void validateRequest(SqlGuardCheckRequest request, String action) { + if ("DATA_PROFILE".equals(action)) { + requireText(request.getTableName(), "DATA_PROFILE 需要 tableName 参数"); + return; + } + if ("SQL_VERIFY".equals(action)) { + requireText(request.getQuery(), "SQL_VERIFY 需要 query 参数"); + requireText(request.getSql(), "SQL_VERIFY 需要 sql 参数"); + } + } + + private void requireText(String value, String message) { + if (!StringUtils.hasText(value)) { + throw new IllegalArgumentException(objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, message))); } } @@ -172,6 +182,15 @@ private void enrichRequestFromToolContext(SqlGuardCheckRequest request, ToolCont request.setHumanFeedbackContent(graphRequest.getHumanFeedbackContent()); } + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; + } + } + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java index 62dd32805..66f0c28c9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -26,7 +26,6 @@ import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; import com.alibaba.cloud.ai.dataagent.util.SqlUtil; -import com.fasterxml.jackson.databind.JsonNode; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -138,12 +137,9 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { } catch (IllegalArgumentException ex) { return SqlGuardCheckResult.builder() - .action(ACTION_SQL_VERIFY) - .query(query) - .sql(sql) + .decision("revise_sql") .isAligned(false) .summary("SQL 无法通过语法解析,无法继续做结构和意图一致性校验。") - .explainedIntent(buildIntentExplanation(intent)) .problems(List.of(SqlGuardProblem.builder() .code("SQL_PARSE_ERROR") .title("SQL 语法解析失败") @@ -189,16 +185,11 @@ public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { fixSuggestions.add("当前规则校验通过;如要进一步提高置信度,可继续核对执行结果与最终答案解释。"); } return SqlGuardCheckResult.builder() - .action(ACTION_SQL_VERIFY) - .query(query) - .sql(sql) + .decision(aligned ? "safe_to_execute" : "revise_sql") .isAligned(aligned) .summary(summary) - .explainedIntent(buildIntentExplanation(intent)) .problems(problems) .fixSuggestions(List.copyOf(fixSuggestions)) - .usedTables(shape.usedTables()) - .usedMetrics(shape.usedMetrics()) .ruleChecks(ruleChecks) .build(); } @@ -226,13 +217,10 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r String summary = "仅基于可见字段对表 '%s' 的 %d 个字段完成 profile 分析。".formatted(columnProfiles.size(), actualTableName); return SqlGuardCheckResult.builder() - .action(ACTION_DATA_PROFILE) - .query(request == null ? null : request.getQuery()) + .decision("inspect_columns") .tableName(actualTableName) .summary(summary) .totalRows(totalRows) - .inspectedColumnCount(columnProfiles.size()) - .usedTables(List.of(actualTableName)) .columnProfiles(columnProfiles) .fixSuggestions( List.of("可优先把高频值集中的分类字段用作过滤条件或 GROUP BY 候选字段。", @@ -1110,47 +1098,7 @@ private boolean detectTimeBucket(String normalizedSql, Set knownTimeColu } private Set extractKnownTimeColumns(SqlGuardCheckRequest request) { - Set columns = new LinkedHashSet<>(); - if (request == null) { - return columns; - } - collectTimeColumns(request.getTableSchemas(), columns); - collectTimeColumns(request.getSemanticHits(), columns); - collectTimeColumns(request.getBusinessKnowledgeHits(), columns); - return columns; - } - - private void collectTimeColumns(JsonNode node, Set columns) { - if (node == null || node.isNull()) { - return; - } - if (node.isTextual()) { - String value = node.asText(); - if (isLikelyTimeColumn(value)) { - columns.add(value.toLowerCase(Locale.ROOT)); - } - return; - } - if (node.isArray()) { - for (JsonNode item : node) { - collectTimeColumns(item, columns); - } - return; - } - if (!node.isObject()) { - return; - } - node.fields().forEachRemaining(entry -> { - String fieldName = entry.getKey(); - JsonNode value = entry.getValue(); - if (value != null && value.isTextual() - && ("name".equalsIgnoreCase(fieldName) || "columnName".equalsIgnoreCase(fieldName) - || "fieldName".equalsIgnoreCase(fieldName) || "column".equalsIgnoreCase(fieldName)) - && isLikelyTimeColumn(value.asText())) { - columns.add(value.asText().toLowerCase(Locale.ROOT)); - } - collectTimeColumns(value, columns); - }); + return new LinkedHashSet<>(); } private boolean isLikelyTimeColumn(String value) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java index e12ad0d41..5b087b85b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java @@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import java.time.Instant; import java.util.ArrayList; -import java.util.Collection; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; @@ -149,7 +148,6 @@ public Optional getExplain(String sessionId, String runt public Optional getMirrorSummary(String sessionId, String runtimeRequestId) { return getExplain(sessionId, runtimeRequestId).map(explain -> ExplainMirrorSummary.builder() .datasource(explain.getDatasource()) - .usedTables(explain.getUsedTables()) .semanticHitCount(explain.getSemanticHits() == null ? 0 : explain.getSemanticHits().size()) .knowledgeHitCount(explain.getKnowledgeHits() == null ? 0 : explain.getKnowledgeHits().size()) .toolStepCount(explain.getToolSteps() == null ? 0 : explain.getToolSteps().size()) @@ -237,7 +235,7 @@ private void applyKnowledgeSearch(ExplainAssembly assembly, DomainKnowledgeSearc .toolName("domain_business_knowledge.search") .title("业务知识检索") .summary("检索到 %d 条知识命中".formatted(result.hits() == null ? 0 : result.hits().size())) - .detail(result.query()) + .detail(result.resolution()) .timestampEpochMs(Instant.now().toEpochMilli()) .build()); if (result.hits() != null) { @@ -270,13 +268,6 @@ private void applyDatasourceResult(ExplainAssembly assembly, DatasourceExplorerR if (StringUtils.hasText(result.getSql())) { assembly.sql = result.getSql(); } - if (StringUtils.hasText(result.getSqlExplanation())) { - assembly.sqlExplanation = result.getSqlExplanation(); - } - mergeOrdered(assembly.usedTables, result.getUsedTables()); - mergeOrdered(assembly.usedColumns, result.getUsedColumns()); - mergeMap(assembly.permissions, result.getPermissions()); - mergeMap(assembly.stats, result.getStats()); assembly.toolSteps.add(ToolStepView.builder() .toolName("datasource.explorer") .title(result.getAction()) @@ -289,13 +280,13 @@ private void applyDatasourceResult(ExplainAssembly assembly, DatasourceExplorerR } private void applyClarifyAssessment(ExplainAssembly assembly, QueryClarifyAssessment assessment) { - assembly.stats.put("riskLevel", assessment.riskLevel().value()); - assembly.stats.put("clarifyRequired", assessment.clarifyRequired()); - assembly.stats.put("missingDimensions", assessment.missingDimensions()); - assembly.stats.put("followUpQuestions", assessment.followUpQuestions()); - assembly.stats.put("suggestedAssumptions", assessment.suggestedAssumptions()); + assembly.clarify.put("riskLevel", assessment.riskLevel().value()); + assembly.clarify.put("clarifyRequired", assessment.clarifyRequired()); + assembly.clarify.put("missingDimensions", assessment.missingDimensions()); + assembly.clarify.put("followUpQuestions", assessment.followUpQuestions()); + assembly.clarify.put("suggestedAssumptions", assessment.suggestedAssumptions()); if (StringUtils.hasText(assessment.feedbackContent())) { - assembly.stats.put("humanFeedbackContent", assessment.feedbackContent()); + assembly.clarify.put("humanFeedbackContent", assessment.feedbackContent()); } assembly.toolSteps.add(ToolStepView.builder() .toolName("query_clarify.check") @@ -310,29 +301,6 @@ private void applyClarifyAssessment(ExplainAssembly assembly, QueryClarifyAssess assembly.updatedAt = Instant.now().toEpochMilli(); } - private void mergeOrdered(Set target, Collection source) { - if (source == null) { - return; - } - for (String item : source) { - if (StringUtils.hasText(item)) { - target.add(item.trim()); - } - } - } - - private void mergeMap(Map target, Map source) { - if (source == null) { - return; - } - source.forEach((key, value) -> { - if (!StringUtils.hasText(key) || value == null) { - return; - } - target.put(key, value); - }); - } - private void evictOverflowLocked() { while (explainsBySession.size() > MAX_SESSION_COUNT) { String eldestSessionId = explainsBySession.keySet().iterator().next(); @@ -367,21 +335,13 @@ private static final class ExplainAssembly { private String sql; - private String sqlExplanation; - private final List semanticHits = new ArrayList<>(); private final List knowledgeHits = new ArrayList<>(); private final List toolSteps = new ArrayList<>(); - private final Set usedTables = new LinkedHashSet<>(); - - private final Set usedColumns = new LinkedHashSet<>(); - - private final Map permissions = new LinkedHashMap<>(); - - private final Map stats = new LinkedHashMap<>(); + private final Map clarify = new LinkedHashMap<>(); private final Set warnings = new LinkedHashSet<>(); @@ -396,14 +356,10 @@ private AnswerTraceExplainView toView() { .answer(answer) .datasource(datasource) .sql(sql) - .sqlExplanation(sqlExplanation) .semanticHits(List.copyOf(semanticHits)) .knowledgeHits(List.copyOf(knowledgeHits)) .toolSteps(List.copyOf(toolSteps)) - .usedTables(List.copyOf(usedTables)) - .usedColumns(List.copyOf(usedColumns)) - .permissions(new LinkedHashMap<>(permissions)) - .stats(new LinkedHashMap<>(stats)) + .clarify(new LinkedHashMap<>(clarify)) .warnings(List.copyOf(warnings)) .updatedAt(updatedAt) .build(); @@ -430,8 +386,6 @@ public static class AnswerTraceExplainView { private String sql; - private String sqlExplanation; - @Builder.Default private List semanticHits = List.of(); @@ -442,16 +396,7 @@ public static class AnswerTraceExplainView { private List toolSteps = List.of(); @Builder.Default - private List usedTables = List.of(); - - @Builder.Default - private List usedColumns = List.of(); - - @Builder.Default - private Map permissions = Map.of(); - - @Builder.Default - private Map stats = Map.of(); + private Map clarify = Map.of(); @Builder.Default private List warnings = List.of(); @@ -528,9 +473,6 @@ public static class ExplainMirrorSummary { private String datasource; - @Builder.Default - private List usedTables = List.of(); - private int semanticHitCount; private int knowledgeHitCount; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java index 36ed52955..7d84e9508 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java @@ -30,17 +30,11 @@ record DomainKnowledgeSearchRequest(String query, List knowledgeTypes, I Double similarityThreshold) { } - record DomainKnowledgeSearchResult(String query, List appliedKnowledgeTypes, List hits, - List warnings, SearchDiagnostics diagnostics) { + record DomainKnowledgeSearchResult(List hits, List warnings, String resolution) { } record KnowledgeHit(String vectorType, String knowledgeId, String title, String summary, String snippet, String source, String concreteType) { } - record SearchDiagnostics(String runtimeAgentId, Integer recalledBusinessKnowledgeCount, - Integer recalledBusinessTermCount, Integer recalledAgentKnowledgeCount, - boolean businessKnowledgeVectorReady, boolean businessTermVectorReady, boolean agentKnowledgeVectorReady) { - } - } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java index 91e2c1ec0..343bc2691 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java @@ -26,7 +26,6 @@ import com.alibaba.cloud.ai.dataagent.service.vectorstore.AgentVectorStoreService; import com.alibaba.cloud.ai.dataagent.service.vectorstore.DynamicFilterService; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchRequest; -import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.SearchDiagnostics; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchResult; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.KnowledgeHit; import java.util.ArrayList; @@ -127,12 +126,9 @@ public DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchR warnings.add("未检索到匹配的业务知识,请缩短问题或换一种业务说法重试。"); } - SearchDiagnostics diagnostics = new SearchDiagnostics(agentId, businessTermDiagnostics.recalledCount(), - businessTermDiagnostics.recalledCount(), agentKnowledgeDiagnostics.recalledCount(), - businessTermDiagnostics.vectorReady(), businessTermDiagnostics.vectorReady(), - agentKnowledgeDiagnostics.vectorReady()); - DomainKnowledgeSearchResult result = new DomainKnowledgeSearchResult(query, - List.copyOf(options.appliedKnowledgeTypes()), List.copyOf(hits), List.copyOf(warnings), diagnostics); + String resolution = hits.isEmpty() ? "no_match" : "matched"; + DomainKnowledgeSearchResult result = new DomainKnowledgeSearchResult(List.copyOf(hits), List.copyOf(warnings), + resolution); if (graphRequest != null) { answerTraceExplainStore.recordKnowledgeSearch(graphRequest, result); } diff --git a/data-agent-management/src/main/resources/prompts/commonagent.md b/data-agent-management/src/main/resources/prompts/commonagent.md index 163476e9a..01cd588f4 100644 --- a/data-agent-management/src/main/resources/prompts/commonagent.md +++ b/data-agent-management/src/main/resources/prompts/commonagent.md @@ -33,7 +33,7 @@ 6. 如果你已经准备了候选 SQL,且答案将基于 SQL 返回给用户,在执行 SQL 前先调用 `sql_guard.check`,传 `action=SQL_VERIFY`。 必传:`query`、`sql`。 - 可选:`tableSchemas`、`semanticHits`、`businessKnowledgeHits`。 + 不要传旧的透传字段;当前校验只基于本轮顶层 `query` 和候选 `sql`。 7. `sql_guard.check` 是统一 SQL 工具。 `action=SQL_VERIFY`:只做结构与意图校验,不负责自动修复、不负责执行报错修复、也不负责结果回看。 From 5a5493e0e0ec2c47b1bc633cce0bb05ea9542231 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 29 Apr 2026 12:51:20 +0800 Subject: [PATCH 17/22] refactor: rename and enhance from --- data-agent-frontend/src/services/chat.ts | 22 +- data-agent-frontend/src/services/graph.ts | 10 +- .../src/services/sessionStateManager.ts | 14 +- data-agent-frontend/src/views/AgentRun.vue | 346 ++++++++++++++---- .../{GraphRequest.java => AgentRequest.java} | 2 +- .../runtime/AgentRuntimeEventPublisher.java | 4 +- .../runtime/AgentRuntimeExtensionFactory.java | 6 +- .../runtime/AgentScopeHookFactory.java | 4 +- .../runtime/AgentScopeStreamingHook.java | 4 +- .../agentscope/runtime/HumanFeedbackHook.java | 6 +- .../SpringToolCallbackAgentAdapter.java | 8 +- .../runtime/ToolContextRequestResolver.java | 14 +- .../agentscope/service/AgentService.java | 6 +- .../impl/AiAgentRuntimeServiceImpl.java | 54 +-- .../datasource/DatasourceExplorerResult.java | 19 + .../datasource/DatasourceExplorerService.java | 143 +++++++- .../DatasourceExplorerToolProvider.java | 6 +- .../DomainBusinessKnowledgeToolSupport.java | 6 +- .../semantic/SemanticModelSearchService.java | 12 +- .../semantic/SemanticModelToolSupport.java | 6 +- .../tool/sqlguard/SqlGuardToolProvider.java | 8 +- ...phNodeResponse.java => AgentResponse.java} | 10 +- .../dataagent/controller/ChatController.java | 35 ++ .../controller/DataAgentController.java | 24 +- .../AnswerTraceExplainStore.java | 105 +++++- .../DomainKnowledgeSearchService.java | 5 +- .../DomainKnowledgeSearchServiceImpl.java | 12 +- .../service/langfuse/LangfuseService.java | 4 +- 28 files changed, 692 insertions(+), 203 deletions(-) rename data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/{GraphRequest.java => AgentRequest.java} (97%) rename data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/{GraphNodeResponse.java => AgentResponse.java} (84%) diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index 248d446a1..b360a1845 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -76,14 +76,17 @@ export interface AnswerTraceExplain { answer?: string; datasource?: string; sql?: string; - sqlExplanation?: string; + decisionReason?: string; + resultScope?: string; + usedTables: string[]; + usedColumns: string[]; + relationEvidence: Record[]; + toolDecisionReasons: string[]; + resultScopeDetails: string[]; semanticHits: AnswerTraceSemanticHit[]; knowledgeHits: AnswerTraceKnowledgeHit[]; toolSteps: AnswerTraceToolStep[]; - usedTables: string[]; - usedColumns: string[]; - permissions: Record; - stats: Record; + clarify?: Record; warnings: string[]; updatedAt: number; } @@ -190,6 +193,15 @@ class ChatService { return response.data; } + async getLatestAnswerExplain(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get( + `${API_BASE_URL}/sessions/${sessionId}/answers/latest/explain`, + { params: { agentId: resolvedAgentId } }, + ); + return response.data; + } + async getAnswerExplain( sessionId: string, runtimeRequestId: string, diff --git a/data-agent-frontend/src/services/graph.ts b/data-agent-frontend/src/services/graph.ts index 1ac3ffb0b..f2a4fe00c 100644 --- a/data-agent-frontend/src/services/graph.ts +++ b/data-agent-frontend/src/services/graph.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -export interface GraphRequest { +export interface AgentRequest { agentId: string; threadId?: string; runtimeRequestId?: string; @@ -35,7 +35,7 @@ export interface ClarifyMetadata { summary?: string; } -export interface GraphNodeResponse { +export interface AgentResponse { agentId: string; threadId: string; nodeName: string; @@ -68,8 +68,8 @@ class GraphService { * @returns 关闭连接的函数 */ async streamSearch( - request: GraphRequest, - onMessage: (response: GraphNodeResponse) => Promise, + request: AgentRequest, + onMessage: (response: AgentResponse) => Promise, onError?: (error: Error) => Promise, onComplete?: () => Promise, ): Promise<() => void> { @@ -111,7 +111,7 @@ class GraphService { return; } try { - const nodeResponse: GraphNodeResponse = JSON.parse(event.data); + const nodeResponse: AgentResponse = JSON.parse(event.data); console.log( `Node: ${nodeResponse.nodeName}, message: ${nodeResponse.text}, type: ${nodeResponse.textType}`, ); diff --git a/data-agent-frontend/src/services/sessionStateManager.ts b/data-agent-frontend/src/services/sessionStateManager.ts index d6e946ced..9584a0d9d 100644 --- a/data-agent-frontend/src/services/sessionStateManager.ts +++ b/data-agent-frontend/src/services/sessionStateManager.ts @@ -15,7 +15,7 @@ */ import { ref, Ref } from 'vue'; -import { GraphNodeResponse, GraphRequest } from '@/services/graph.ts'; +import { AgentResponse, AgentRequest } from '@/services/graph.ts'; import { AnswerTraceExplain } from '@/services/chat.ts'; export interface PendingClarifyState { @@ -29,10 +29,10 @@ export interface PendingClarifyState { export interface SessionRuntimeState { isStreaming: boolean; - nodeBlocks: GraphNodeResponse[][]; + nodeBlocks: AgentResponse[][]; persistedBlockCount: number; closeStream: (() => void) | null; - lastRequest: GraphRequest | null; + lastRequest: AgentRequest | null; pendingClarify: PendingClarifyState | null; htmlReportContent: string; htmlReportSize: number; @@ -43,9 +43,9 @@ export interface SessionRuntimeState { // 可持久化的状态字段(不包括函数和临时状态) interface PersistableState { - nodeBlocks: GraphNodeResponse[][]; + nodeBlocks: AgentResponse[][]; persistedBlockCount: number; - lastRequest: GraphRequest | null; + lastRequest: AgentRequest | null; pendingClarify: PendingClarifyState | null; htmlReportContent: string; htmlReportSize: number; @@ -160,7 +160,7 @@ export function useSessionStateManager() { sessionId: string, viewState: { isStreaming: Ref; - nodeBlocks: Ref; + nodeBlocks: Ref; answerExplain?: Ref; answerExplainVisible?: Ref; pendingClarify?: Ref; @@ -187,7 +187,7 @@ export function useSessionStateManager() { sessionId: string, viewState: { isStreaming: Ref; - nodeBlocks: Ref; + nodeBlocks: Ref; answerExplain?: Ref; answerExplainVisible?: Ref; pendingClarify?: Ref; diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 652382843..db35ca350 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -634,8 +634,82 @@
+
+
澄清来源
+
+ 以下信息说明系统在正式查库前如何判断问题是否存在歧义,以及是否需要你补充信息。 +
+
+
+ 澄清结论 + {{ answerExplain.clarify.summary }} +
+
+ 风险等级 + {{ answerExplain.clarify.riskLevel }} +
+
+ 是否需要澄清 + + {{ answerExplain.clarify.clarifyRequired ? '是' : '否' }} + +
+
+ 是否阻止直接查库 + + {{ answerExplain.clarify.shouldBlockExecution ? '是' : '否' }} + +
+
+ 缺失维度 + + {{ asExplainStringList(answerExplain.clarify.missingDimensions).join('、') }} + +
+
+ 追问建议 + + {{ asExplainStringList(answerExplain.clarify.followUpQuestions).join(';') }} + +
+
+ 建议假设 + + {{ asExplainStringList(answerExplain.clarify.suggestedAssumptions).join(';') }} + +
+
+ 人工补充 + + {{ answerExplain.clarify.humanFeedbackContent }} + +
+
+
+
-
问题理解
+
语义模型来源
+
+ 以下命中来自语义模型召回,用于补充表、字段和业务语义理解。 +
-
-
SQL 解释
-
- {{ answerExplain.sqlExplanation }} -
-
{{ answerExplain.sql }}
+
+
执行 SQL
+
{{ answerExplain.sql }}
@@ -701,26 +769,86 @@
使用字段 - - {{ answerExplain.usedColumns.join('、') }} - + {{ answerExplain.usedColumns.join('、') }} +
+
+ 工具决策来源 + {{ answerExplain.decisionReason }}
- {{ key }} - {{ formatExplainValue(value) }} + 工具决策细节 + +
    +
  • + {{ reason }} +
  • +
+
+
+
+ 结果裁剪来源 + {{ answerExplain.resultScope }}
- {{ key }} - {{ formatExplainValue(value) }} + 结果裁剪细节 + +
    +
  • + {{ detail }} +
  • +
+
+
+
+ 语义模型命中 + {{ answerExplain.semanticHits.length }} 条
+
+ RAG / 知识命中 + {{ answerExplain.knowledgeHits.length }} 条 +
+
+ + +
+
关联来源
+
+ 以下信息说明本轮多表查询优先依据哪些物理外键或逻辑关系完成关联。 +
+
+
+
+ {{ relation.sourceTable }}.{{ relation.sourceColumn }} → {{ relation.targetTable }}.{{ relation.targetColumn }} +
+
+ {{ relation.sourceType || '-' }} + {{ relation.relationType }} + 数据库声明 + 逻辑关系 +
+
+ {{ relation.description }} +
+
@@ -749,7 +877,10 @@
-
知识来源
+
RAG / 知识来源
+
+ 以下命中来自业务知识库、FAQ、文档切片等 RAG 召回结果,会直接影响最终回答。 +
([]); + const nodeBlocks = ref([]); const options = ref({ markdownIt: { linkify: true, @@ -1073,6 +1204,7 @@ session.id, requireResolvedAgentId(), ); + await preloadSessionLatestObservability(session.id); scrollToBottom(); } catch (error) { ElMessage.error('加载消息失败'); @@ -1080,6 +1212,25 @@ } }; + const preloadSessionLatestObservability = async (sessionId: string) => { + if (!currentSession.value || currentSession.value.id !== sessionId) { + return; + } + try { + await loadLatestTrace(); + } catch (error) { + console.warn('预加载最新 trace 失败:', error); + } + if (!currentSession.value || currentSession.value.id !== sessionId) { + return; + } + try { + await loadLatestAnswerExplain({ visible: false }); + } catch (error) { + console.warn('预加载最新数据来源失败:', error); + } + }; + const buildPendingClarifyState = ( metadata: ClarifyMetadata & { originalQuery: string }, ): PendingClarifyState => { @@ -1140,7 +1291,7 @@ currentMessages.value.push(savedMessage); getSessionState(currentSession.value.id); - const request: GraphRequest = { + const request: AgentRequest = { agentId: String(requireResolvedAgentId()), query: requestQuery, humanFeedback: Boolean(activeClarify), @@ -1155,7 +1306,7 @@ pendingClarify.value = null; getSessionState(currentSession.value.id).pendingClarify = null; - await sendGraphRequest(request); + await sendAgentRequest(request); } catch (error) { if (activeClarify && currentSession.value) { pendingClarify.value = activeClarify; @@ -1169,8 +1320,8 @@ }; const buildExplainMetadataForNode = ( - request: GraphRequest, - node: GraphNodeResponse[], + request: AgentRequest, + node: AgentResponse[], ): MessageExplainMetadata | null => { if (!request.runtimeRequestId || !node.length) { return null; @@ -1190,8 +1341,8 @@ const saveAssistantNodeMessage = async ( sessionId: string, - node: GraphNodeResponse[], - request?: GraphRequest | null, + node: AgentResponse[], + request?: AgentRequest | null, ): Promise => { if (!node || !node.length) { return; @@ -1229,7 +1380,7 @@ await ChatService.saveMessage(sessionId, requireResolvedAgentId(), aiMessage); }; - const sendGraphRequest = async (request: GraphRequest) => { + const sendAgentRequest = async (request: AgentRequest) => { const sessionId = currentSession.value!.id; currentSession.value!.title; const sessionState = getSessionState(sessionId); @@ -1245,7 +1396,7 @@ // 重置报告状态 resetReportState(sessionState, request); - const saveNodeMessage = (node: GraphNodeResponse[]): Promise => { + const saveNodeMessage = (node: AgentResponse[]): Promise => { return saveAssistantNodeMessage(sessionId, node, request).catch(error => { console.error('保存AI消息失败:', error); }); @@ -1265,7 +1416,7 @@ const closeStream = await GraphService.streamSearch( request, - (response: GraphNodeResponse) => { + (response: AgentResponse) => { if (response.error) { ElMessage.error(`处理错误: ${response.text}`); return; @@ -1295,7 +1446,7 @@ } // 创建新的节点块 - const newBlock: GraphNodeResponse = { + const newBlock: AgentResponse = { ...response, text: response.text, }; @@ -1309,8 +1460,8 @@ sessionState.htmlReportSize = sessionState.htmlReportContent.length; // 更新显示:当前已经收集了多少字节的报告 - const reportNode: GraphNodeResponse[] = sessionState.nodeBlocks.find( - (block: GraphNodeResponse[]) => + const reportNode: AgentResponse[] = sessionState.nodeBlocks.find( + (block: AgentResponse[]) => block.length > 0 && block[0].nodeName === 'ReportGeneratorNode' && block[0].textType === 'HTML', @@ -1329,8 +1480,8 @@ // 处理Markdown报告 else if (response.textType === 'MARK_DOWN') { sessionState.markdownReportContent += response.text; - const reportNode: GraphNodeResponse[] = sessionState.nodeBlocks.find( - (block: GraphNodeResponse[]) => + const reportNode: AgentResponse[] = sessionState.nodeBlocks.find( + (block: AgentResponse[]) => block.length > 0 && block[0].nodeName === 'ReportGeneratorNode' && block[0].textType === 'MARK_DOWN', @@ -1353,7 +1504,7 @@ pendingSavePromises.push(savePromise); } // 创建新的节点块 - const newBlock: GraphNodeResponse = { + const newBlock: AgentResponse = { ...response, text: response.text, }; @@ -1372,7 +1523,7 @@ } // 创建新的节点块 - const newBlock: GraphNodeResponse = { + const newBlock: AgentResponse = { ...response, text: response.text, }; @@ -1382,14 +1533,14 @@ } else { // 继续当前节点的内容 if (currentBlockIndex >= 0 && sessionState.nodeBlocks[currentBlockIndex]) { - const newBlock: GraphNodeResponse = { + const newBlock: AgentResponse = { ...response, text: response.text, }; sessionState.nodeBlocks[currentBlockIndex].push(newBlock); } else { // 创建新的节点块 - const newBlock: GraphNodeResponse = { + const newBlock: AgentResponse = { ...response, text: response.text, }; @@ -1595,7 +1746,7 @@ }; // 生成节点容器的HTML代码 - const generateNodeHtml = (node: GraphNodeResponse[]) => { + const generateNodeHtml = (node: AgentResponse[]) => { const content = formatNodeContent(node); return ` @@ -1606,7 +1757,7 @@ `; }; - const formatNodeContent = (node: GraphNodeResponse[]) => { + const formatNodeContent = (node: AgentResponse[]) => { let content = ''; for (let idx = 0; idx < node.length; idx++) { @@ -1713,7 +1864,7 @@ }; // 重置报告状态 - const resetReportState = (sessionState: SessionRuntimeState, request: GraphRequest) => { + const resetReportState = (sessionState: SessionRuntimeState, request: AgentRequest) => { sessionState.isStreaming = true; sessionState.nodeBlocks = []; sessionState.persistedBlockCount = 0; @@ -1795,20 +1946,40 @@ return null; }); - const latestExplainRuntimeRequestId = computed(() => { + const latestExplainRuntimeRequestId = computed(() => sessionTrace.value?.runtimeRequestId ?? null); + + const loadLatestAnswerExplain = async (options?: { visible?: boolean }) => { if (!currentSession.value) { - return null; + return; } - const sessionState = getSessionState(currentSession.value.id); - if (sessionState.lastRequest?.runtimeRequestId) { - return sessionState.lastRequest.runtimeRequestId; + if (options?.visible ?? true) { + answerExplainVisible.value = true; } - const message = latestExplainMessage.value; - if (!message) { - return null; + answerExplainLoading.value = true; + answerExplainError.value = ''; + try { + answerExplain.value = await ChatService.getLatestAnswerExplain( + currentSession.value.id, + requireResolvedAgentId(), + ); + saveViewToState(currentSession.value.id, { + isStreaming, + nodeBlocks, + answerExplain, + answerExplainVisible, + }); + } catch (error: any) { + answerExplain.value = null; + if (error?.response?.status === 404) { + answerExplainError.value = '当前会话还没有可查看的数据来源。'; + } else { + answerExplainError.value = '加载数据来源失败,请稍后重试。'; + } + console.error('加载最新 answer explain 失败:', error); + } finally { + answerExplainLoading.value = false; } - return parseMessageMetadata(message)?.runtimeRequestId ?? null; - }); + }; const loadAnswerExplainByRuntimeRequestId = async (runtimeRequestId: string) => { if (!currentSession.value) { @@ -1844,12 +2015,12 @@ }; const openLatestAnswerExplain = async () => { - const runtimeRequestId = latestExplainRuntimeRequestId.value; - if (!runtimeRequestId) { + if (!currentSession.value) { ElMessage.warning('当前会话还没有可查看的数据来源'); return; } - await loadAnswerExplainByRuntimeRequestId(runtimeRequestId); + answerExplainVisible.value = true; + await Promise.all([loadLatestTrace(), loadLatestAnswerExplain()]); }; const formatExplainValue = (value: unknown) => { @@ -1867,21 +2038,48 @@ } }; + const asExplainStringList = (value: unknown): string[] => { + if (!Array.isArray(value)) { + return []; + } + return value + .filter((item): item is string => typeof item === 'string' && item.trim().length > 0) + .map((item) => item.trim()); + }; + const summarizeExplainExecution = (explain: AnswerTraceExplain | null) => { if (!explain) { return '当前还没有可展示的执行说明。'; } - if (explain.datasource || explain.usedTables.length > 0 || explain.usedColumns.length > 0) { - return '本轮回答访问了结构化数据源,下面展示的是实际使用到的表、字段、权限范围和统计信息。'; + const hasDatasourceEvidence = + Boolean(explain.datasource || explain.sql) || + explain.usedTables.length > 0 || + explain.usedColumns.length > 0; + const hasSemanticEvidence = explain.semanticHits.length > 0; + const hasKnowledgeEvidence = explain.knowledgeHits.length > 0; + if (hasDatasourceEvidence && hasSemanticEvidence && hasKnowledgeEvidence) { + return '本轮回答同时使用了结构化数据源、语义模型召回和 RAG/知识召回,下面展示的是完整来源。'; + } + if (hasDatasourceEvidence && hasSemanticEvidence) { + return '本轮回答同时使用了结构化数据源和语义模型召回,下面展示的是 SQL 来源与语义来源。'; + } + if (hasDatasourceEvidence && hasKnowledgeEvidence) { + return '本轮回答同时使用了结构化数据源和 RAG/知识召回,下面展示的是 SQL 来源与知识来源。'; + } + if (hasDatasourceEvidence) { + return '本轮回答访问了结构化数据源,下面展示的是实际执行过的数据源步骤、SQL、使用表和使用字段。'; + } + if (hasKnowledgeEvidence && hasSemanticEvidence) { + return '本轮回答没有直接查库,但同时命中了语义模型和 RAG/知识召回,回答受这些来源共同影响。'; } - if (explain.knowledgeHits.length > 0) { - return '本轮回答没有访问数据库,但命中了知识检索结果,回答受这些知识片段影响。'; + if (hasKnowledgeEvidence) { + return '本轮回答没有直接查库,但命中了 RAG/知识召回结果,回答受这些知识片段影响。'; } - if (explain.semanticHits.length > 0) { + if (hasSemanticEvidence) { return '本轮回答没有直接查库,但命中了语义模型,用来帮助系统理解你的问题和业务字段。'; } - if (explain.toolSteps.length > 0) { - return '本轮回答没有形成可展示的数据源明细,但系统执行过工具步骤,详细过程可在下方执行过程里查看。'; + if (explain.toolSteps.length > 0 || (explain.clarify && Object.keys(explain.clarify).length > 0)) { + return '本轮回答没有形成可展示的查库明细,但系统执行过澄清或其他工具步骤,详细过程可在下方查看。'; } return '本轮回答没有访问数据库、知识库或其他可回放工具,当前结果主要来自模型直接生成。'; }; @@ -2278,7 +2476,7 @@ // 保存已接收的节点消息 if (sessionState.nodeBlocks && sessionState.nodeBlocks.length > 0) { - const saveNodeMessage = (node: GraphNodeResponse[]): Promise => { + const saveNodeMessage = (node: AgentResponse[]): Promise => { return saveAssistantNodeMessage(sessionId, node, sessionState.lastRequest).catch( error => { console.error('保存AI消息失败:', error); @@ -2383,7 +2581,7 @@ }; // 从节点块中提取 Markdown 内容 - const getMarkdownContentFromNode = (node: GraphNodeResponse[]): string => { + const getMarkdownContentFromNode = (node: AgentResponse[]): string => { if (!node || node.length === 0) { return ''; } @@ -2472,6 +2670,7 @@ formatTraceTime, formatTraceOffset, formatExplainValue, + asExplainStringList, summarizeExplainExecution, isStructuredTraceValue, formatStructuredTraceValue, @@ -3189,6 +3388,17 @@ line-height: 1.8; } + .answer-explain-inline-list { + margin: 0; + padding-left: 20px; + color: #1e3a5f; + line-height: 1.8; + } + + .answer-explain-inline-list li + li { + margin-top: 6px; + } + .answer-explain-warning-list { margin: 0; padding-left: 20px; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java similarity index 97% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java index 4027c9ce6..0fbe4d7c1 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java @@ -24,7 +24,7 @@ @AllArgsConstructor @NoArgsConstructor @Builder -public class GraphRequest { +public class AgentRequest { private String agentId; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java index 5d9fa2332..2d185b905 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java @@ -15,11 +15,11 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; @FunctionalInterface public interface AgentRuntimeEventPublisher { - void publish(GraphNodeResponse response); + void publish(AgentResponse response); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java index 95e8b75eb..67956f934 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRuntimeExtensions; import io.agentscope.core.hook.Hook; import io.agentscope.core.memory.Memory; @@ -42,8 +42,8 @@ public class AgentRuntimeExtensionFactory { private final AgentScopeSkillBoxFactory skillBoxFactory; - public AgentRuntimeExtensions create(GraphRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher, - Map toolCallbacks) { + public AgentRuntimeExtensions create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher, + Map toolCallbacks) { Toolkit toolkit = toolkitFactory.buildToolkit(toolCallbacks); SkillBox skillBox = skillBoxFactory.create(request.getAgentId(), toolkit); Memory memory = memoryFactory.create(request.getThreadId()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java index 2e733b251..fb5b0236d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; @@ -36,7 +36,7 @@ public class AgentScopeHookFactory { private final AgentSessionRegistry sessionRegistry; - public List create(GraphRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher) { + public List create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher) { List hooks = new ArrayList<>(); if (eventPublisher != null) { hooks.add(new AgentScopeStreamingHook(request.getAgentId(), request.getThreadId(), request.isNl2sqlOnly(), diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java index 123ff1438..278ef2ea5 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import com.alibaba.cloud.ai.dataagent.enums.TextType; import io.agentscope.core.hook.ActingChunkEvent; import io.agentscope.core.hook.Hook; @@ -76,7 +76,7 @@ private void emit(String nodeName, String text) { if (text == null || text.isBlank()) { return; } - eventPublisher.publish(GraphNodeResponse.builder() + eventPublisher.publish(AgentResponse.builder() .agentId(agentId) .threadId(threadId) .nodeName(nodeName) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java index 33d3d4328..92ed531b7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import io.agentscope.core.hook.Hook; import io.agentscope.core.hook.HookEvent; import io.agentscope.core.hook.PreReasoningEvent; @@ -42,7 +42,7 @@ private HumanFeedbackHook(boolean pauseAfterPlanning, boolean replayRequested, S this.feedbackDirective = feedbackDirective; } - public static HumanFeedbackHook from(GraphRequest request) { + public static HumanFeedbackHook from(AgentRequest request) { if (request.isNl2sqlOnly()) { return null; } @@ -74,7 +74,7 @@ public Mono onEvent(T event) { return Mono.just(event); } - private static String buildDirective(GraphRequest request) { + private static String buildDirective(AgentRequest request) { StringBuilder builder = new StringBuilder("Human review directive:"); if (request.isRejectedPlan()) { builder.append("\n- The previous plan was rejected. Re-plan before continuing."); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java index d4ed94b20..0cd03c32c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.agentscope.core.message.TextBlock; @@ -88,9 +88,9 @@ private ToolContext toToolContext(ToolCallParam toolCallParam) { if (toolCallParam.getContext() != null) { contextMap.put("agentScopeContext", toolCallParam.getContext()); ToolExecutionContext toolExecutionContext = toolCallParam.getContext(); - GraphRequest graphRequest = toolExecutionContext.get("graphRequest", GraphRequest.class); - if (graphRequest != null) { - contextMap.put("graphRequest", graphRequest); + AgentRequest agentRequest = toolExecutionContext.get("agentRequest", AgentRequest.class); + if (agentRequest != null) { + contextMap.put("graphRequest", agentRequest); } AgentRuntimeRequestMetadata requestMetadata = toolExecutionContext.get(AgentRuntimeRequestMetadata.class); if (requestMetadata != null) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java index df617bd75..9f671994e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import io.agentscope.core.tool.ToolExecutionContext; import java.util.Map; import org.springframework.ai.chat.model.ToolContext; @@ -28,22 +28,22 @@ private ToolContextRequestResolver() { } @Nullable - public static GraphRequest resolveGraphRequest(@Nullable ToolContext toolContext) { + public static AgentRequest resolveGraphRequest(@Nullable ToolContext toolContext) { if (toolContext == null || toolContext.getContext() == null) { return null; } Map context = toolContext.getContext(); Object graphRequest = context.get("graphRequest"); - if (graphRequest instanceof GraphRequest request) { + if (graphRequest instanceof AgentRequest request) { return request; } Object agentScopeContext = context.get("agentScopeContext"); if (agentScopeContext instanceof ToolExecutionContext toolExecutionContext) { - GraphRequest request = toolExecutionContext.get("graphRequest", GraphRequest.class); + AgentRequest request = toolExecutionContext.get("graphRequest", AgentRequest.class); if (request != null) { return request; } - GraphRequest metadataRequest = fromMetadata(toolExecutionContext.get(AgentRuntimeRequestMetadata.class)); + AgentRequest metadataRequest = fromMetadata(toolExecutionContext.get(AgentRuntimeRequestMetadata.class)); if (metadataRequest != null) { return metadataRequest; } @@ -56,12 +56,12 @@ public static GraphRequest resolveGraphRequest(@Nullable ToolContext toolContext } @Nullable - private static GraphRequest fromMetadata(@Nullable AgentRuntimeRequestMetadata metadata) { + private static AgentRequest fromMetadata(@Nullable AgentRuntimeRequestMetadata metadata) { if (metadata == null || !StringUtils.hasText(metadata.threadId()) || !StringUtils.hasText(metadata.runtimeRequestId())) { return null; } - return GraphRequest.builder() + return AgentRequest.builder() .agentId(metadata.agentId()) .threadId(metadata.threadId()) .runtimeRequestId(metadata.runtimeRequestId()) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java index 764d3b3c8..f05393e70 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java @@ -15,8 +15,8 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.service; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import org.springframework.http.codec.ServerSentEvent; import reactor.core.publisher.Sinks; @@ -24,7 +24,7 @@ public interface AgentService { String nl2sql(String naturalQuery, String agentId); - void graphStreamProcess(Sinks.Many> sink, GraphRequest graphRequest); + void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest); void stopStreamProcessing(String threadId, String runtimeRequestId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index e7577f445..1234337b7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.service.impl; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeEventPublisher; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeExtensionFactory; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService; @@ -28,7 +28,7 @@ import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRuntimeExtensions; import com.alibaba.cloud.ai.dataagent.agentscope.template.ManagedAgent; import com.alibaba.cloud.ai.dataagent.agentscope.template.ManagedAgentRegistry; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import com.alibaba.cloud.ai.dataagent.constant.AgentRuntimeConstant; import com.alibaba.cloud.ai.dataagent.enums.ModelType; import com.alibaba.cloud.ai.dataagent.enums.TextType; @@ -116,7 +116,7 @@ public class AiAgentRuntimeServiceImpl implements AgentService { @Override public String nl2sql(String naturalQuery, String agentId) { log.info("NL2SQL runtime invoked for agentId={}", agentId); - GraphRequest request = GraphRequest.builder().agentId(agentId).query(naturalQuery).nl2sqlOnly(true).build(); + AgentRequest request = AgentRequest.builder().agentId(agentId).query(naturalQuery).nl2sqlOnly(true).build(); initializeRuntimeRequest(request); sessionRegistry.register(request.getThreadId(), request.getRuntimeRequestId()); try { @@ -128,10 +128,10 @@ public String nl2sql(String naturalQuery, String agentId) { } @Override - public void graphStreamProcess(Sinks.Many> sink, GraphRequest graphRequest) { - initializeRuntimeRequest(graphRequest); - String threadId = graphRequest.getThreadId(); - String runtimeRequestId = graphRequest.getRuntimeRequestId(); + public void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest) { + initializeRuntimeRequest(agentRequest); + String threadId = agentRequest.getThreadId(); + String runtimeRequestId = agentRequest.getRuntimeRequestId(); StreamTextTracker streamTextTracker = new StreamTextTracker(); sessionRegistry.register(threadId, runtimeRequestId); AgentRuntimeEventPublisher eventPublisher = response -> { @@ -145,11 +145,11 @@ public void graphStreamProcess(Sinks.Many> si sink.tryEmitNext(ServerSentEvent.builder(response).event(STREAM_EVENT_MESSAGE).build()); }; - Mono.fromCallable(() -> executeAgent(graphRequest, eventPublisher)) + Mono.fromCallable(() -> executeAgent(agentRequest, eventPublisher)) .doFinally(signalType -> sessionRegistry.finish(threadId, runtimeRequestId)) .subscribeOn(Schedulers.boundedElastic()) - .subscribe(result -> emitSuccess(sink, graphRequest, result, streamTextTracker), - error -> emitError(sink, graphRequest, error)); + .subscribe(result -> emitSuccess(sink, agentRequest, result, streamTextTracker), + error -> emitError(sink, agentRequest, error)); } @Override @@ -157,15 +157,15 @@ public void stopStreamProcessing(String threadId, String runtimeRequestId) { sessionRegistry.markCancelled(threadId, runtimeRequestId); } - private void emitSuccess(Sinks.Many> sink, GraphRequest request, String result, - StreamTextTracker streamTextTracker) { + private void emitSuccess(Sinks.Many> sink, AgentRequest request, String result, + StreamTextTracker streamTextTracker) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); if (!sessionRegistry.isActive(threadId, runtimeRequestId)) { return; } if (shouldEmitFinalResponse(result, streamTextTracker)) { - GraphNodeResponse response = GraphNodeResponse.builder() + AgentResponse response = AgentResponse.builder() .agentId(request.getAgentId()) .threadId(threadId) .nodeName(RUNTIME_NODE_NAME) @@ -174,7 +174,7 @@ private void emitSuccess(Sinks.Many> sink, Gr .build(); sink.tryEmitNext(ServerSentEvent.builder(response).event(STREAM_EVENT_MESSAGE).build()); } - sink.tryEmitNext(ServerSentEvent.builder(GraphNodeResponse.complete(request.getAgentId(), threadId)) + sink.tryEmitNext(ServerSentEvent.builder(AgentResponse.complete(request.getAgentId(), threadId)) .event(STREAM_EVENT_COMPLETE) .build()); sink.tryEmitComplete(); @@ -184,7 +184,7 @@ private boolean shouldEmitFinalResponse(String result, StreamTextTracker streamT return StringUtils.hasText(result) && !streamTextTracker.containsFinalAnswer(result); } - private void emitError(Sinks.Many> sink, GraphRequest request, Throwable error) { + private void emitError(Sinks.Many> sink, AgentRequest request, Throwable error) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); if (sessionRegistry.isCancelled(threadId, runtimeRequestId)) { @@ -195,18 +195,18 @@ private void emitError(Sinks.Many> sink, Grap log.error("AgentScope runtime failed, threadId={}", threadId, error); if (sessionRegistry.isActive(threadId, runtimeRequestId)) { String message = error.getMessage() == null ? "AgentScope 运行失败。" : error.getMessage(); - sink.tryEmitNext(ServerSentEvent.builder(GraphNodeResponse.error(request.getAgentId(), threadId, message)) + sink.tryEmitNext(ServerSentEvent.builder(AgentResponse.error(request.getAgentId(), threadId, message)) .event(STREAM_EVENT_ERROR) .build()); sink.tryEmitComplete(); } } - private String executeAgent(GraphRequest request) { + private String executeAgent(AgentRequest request) { return executeAgent(request, null); } - private void initializeRuntimeRequest(GraphRequest request) { + private void initializeRuntimeRequest(AgentRequest request) { if (!StringUtils.hasText(request.getThreadId())) { request.setThreadId(UUID.randomUUID().toString()); } @@ -215,7 +215,7 @@ private void initializeRuntimeRequest(GraphRequest request) { } } - private String executeAgent(GraphRequest request, AgentRuntimeEventPublisher eventPublisher) { + private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eventPublisher) { sessionRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); answerTraceExplainStore.openScope(request); Span rootSpan = startRuntimeSpan(request); @@ -283,7 +283,7 @@ && isInterruptedCancellation(ex)) { } } - private Span startRuntimeSpan(GraphRequest request) { + private Span startRuntimeSpan(AgentRequest request) { Span span = tracer.spanBuilder(ROOT_SPAN_NAME).startSpan(); span.setAttribute(SessionTraceStore.ATTR_THREAD_ID, request.getThreadId()); span.setAttribute(SessionTraceStore.ATTR_RUNTIME_REQUEST_ID, request.getRuntimeRequestId()); @@ -302,8 +302,8 @@ private void recordRuntimeFailure(Span rootSpan, Throwable throwable) { rootSpan.recordException(throwable); } - private String blockForClarification(GraphRequest request, AgentRuntimeEventPublisher eventPublisher, Span rootSpan, - QueryClarifyAssessment clarifyAssessment) { + private String blockForClarification(AgentRequest request, AgentRuntimeEventPublisher eventPublisher, Span rootSpan, + QueryClarifyAssessment clarifyAssessment) { String clarifyText = clarifyAssessment.userMessage(); answerTraceExplainStore.recordFinalAnswer(clarifyText); persistAnswerExplainSnapshot(request); @@ -312,7 +312,7 @@ private String blockForClarification(GraphRequest request, AgentRuntimeEventPubl if (eventPublisher != null) { Map metadata = new LinkedHashMap<>(clarifyAssessment.toMetadata()); metadata.put("originalQuery", request.getQuery()); - eventPublisher.publish(GraphNodeResponse.builder() + eventPublisher.publish(AgentResponse.builder() .agentId(request.getAgentId()) .threadId(request.getThreadId()) .nodeName(RUNTIME_NODE_NAME) @@ -324,7 +324,7 @@ private String blockForClarification(GraphRequest request, AgentRuntimeEventPubl return clarifyText; } - private void mirrorExplainSummary(Span rootSpan, GraphRequest request) { + private void mirrorExplainSummary(Span rootSpan, AgentRequest request) { if (rootSpan == null || request == null) { return; } @@ -340,7 +340,7 @@ private void mirrorExplainSummary(Span rootSpan, GraphRequest request) { }); } - private void persistAnswerExplainSnapshot(GraphRequest request) { + private void persistAnswerExplainSnapshot(AgentRequest request) { if (request == null || !StringUtils.hasText(request.getThreadId()) || !StringUtils.hasText(request.getRuntimeRequestId())) { return; @@ -365,7 +365,7 @@ private void persistAnswerExplainSnapshot(GraphRequest request) { }); } - private String buildAnswerExplainMetadata(GraphRequest request) throws Exception { + private String buildAnswerExplainMetadata(AgentRequest request) throws Exception { Map metadata = new LinkedHashMap<>(); metadata.put("kind", "answer-explain"); metadata.put("runtimeRequestId", request.getRuntimeRequestId()); @@ -404,7 +404,7 @@ private String resolveManagedSystemPrompt(Agent agent, String requestAgentId) { return agent.getPrompt(); } - private String buildUserPrompt(GraphRequest request) { + private String buildUserPrompt(AgentRequest request) { return request.getQuery() == null ? "" : request.getQuery(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java index b12857581..0f83c29ed 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java @@ -47,4 +47,23 @@ public class DatasourceExplorerResult { private String sql; + @Builder.Default + private List usedTables = new ArrayList<>(); + + @Builder.Default + private List usedColumns = new ArrayList<>(); + + @Builder.Default + private List> relationEvidence = new ArrayList<>(); + + @Builder.Default + private List toolDecisionReasons = new ArrayList<>(); + + @Builder.Default + private List resultScopeDetails = new ArrayList<>(); + + private String resultScope; + + private String decisionReason; + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index b8c0f7b1b..9f4239986 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -19,7 +19,7 @@ import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ForeignKeyInfoBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ResultSetBO; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; import com.alibaba.cloud.ai.dataagent.connector.accessor.AccessorFactory; @@ -117,7 +117,7 @@ public DatasourceExplorerResult execute(String agentId, DatasourceExplorerReques } public DatasourceExplorerResult execute(String agentId, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) throws Exception { + @Nullable AgentRequest graphRequest) throws Exception { if (request == null || request.getAction() == null) { throw new IllegalArgumentException("数据源探索请求必须提供 action"); } @@ -133,7 +133,7 @@ public DatasourceExplorerResult execute(String agentId, DatasourceExplorerReques } private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest graphRequest) { int limit = normalizeLimit(request.getLimit()); Map tableDocumentMap = loadTableDocumentMap(context, context.visibleTables()); List> tables = context.visibleTables() @@ -150,7 +150,7 @@ private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceE } private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest graphRequest) { int limit = normalizeLimit(request.getLimit()); String query = StringUtils.trimToEmpty(request.getQuery()).toLowerCase(Locale.ROOT); Map tableDocumentMap = loadTableDocumentMap(context, context.visibleTables()); @@ -169,7 +169,7 @@ private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceE } private DatasourceExplorerResult getTableSchema(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) throws Exception { + @Nullable AgentRequest graphRequest) throws Exception { String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); List columns = context.accessor() .showColumns(context.dbConfig(), @@ -195,7 +195,7 @@ private DatasourceExplorerResult getTableSchema(ExplorerContext context, Datasou } private DatasourceExplorerResult getRelatedTables(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest graphRequest) { String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); List relations = filterRelations(context, tableName); List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); @@ -218,7 +218,7 @@ private DatasourceExplorerResult getRelatedTables(ExplorerContext context, Datas } private DatasourceExplorerResult previewRows(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) throws Exception { + @Nullable AgentRequest graphRequest) throws Exception { String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); int limit = normalizeLimit(request.getLimit()); String sql = SqlUtil.buildSelectSql(context.dbConfig().getDialectType(), @@ -231,12 +231,18 @@ private DatasourceExplorerResult previewRows(ExplorerContext context, Datasource .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) .sql(sql) + .usedTables(List.of(tableName)) + .usedColumns(toExplainColumns(resultSet)) + .resultScope("预览结果仅包含当前可见字段,并受 limit 限制。") + .resultScopeDetails(buildPreviewResultScopeDetails(tableName, resultSet, limit)) + .decisionReason("当前通过 PREVIEW_ROWS 直接预览单表数据,用于快速确认表内容。") + .toolDecisionReasons(buildPreviewDecisionReasons(tableName, limit)) .searchReady(true) .build(), graphRequest); } private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplorerRequest request, - @Nullable GraphRequest graphRequest) throws Exception { + @Nullable AgentRequest graphRequest) throws Exception { String rawSql = StringUtils.trimToNull(request.getSql()); if (rawSql == null) { throw new IllegalArgumentException("search action 必须提供 sql"); @@ -244,11 +250,21 @@ private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplo int limit = normalizeLimit(request.getLimit()); SqlGuardedQuery guardedQuery = guardReadonlySql(context, rawSql, limit); ResultSetBO resultSet = filterResultSet(executeSql(context, guardedQuery.sql()), guardedQuery); + List> relationEvidence = collectRelationEvidence(context, guardedQuery.referencedTables()); + List usedTables = toExplainTables(guardedQuery.referencedTables(), context.visibleTablesByName()); + List usedColumns = toExplainColumns(resultSet); return capture(baseResult(context, DatasourceExplorerAction.SEARCH, ("已执行只读查询,返回 %d 行结果".formatted(resultSet.getData().size())) + HIDDEN_FIELD_INFERENCE_WARNING) .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) .sql(guardedQuery.sql()) + .usedTables(usedTables) + .usedColumns(usedColumns) + .relationEvidence(relationEvidence) + .resultScope(buildResultScopeSummary(resultSet, limit)) + .resultScopeDetails(buildSearchResultScopeDetails(resultSet, limit, guardedQuery.allowedResultHeaders())) + .decisionReason(buildDecisionReason(guardedQuery.referencedTables(), relationEvidence)) + .toolDecisionReasons(buildSearchDecisionReasons(usedTables, relationEvidence, limit)) .searchReady(true) .build(), graphRequest); } @@ -631,6 +647,115 @@ private List> toColumnHeaders(ResultSetBO resultSet) { return resultSet.getColumn().stream().map(column -> Map.of("name", column)).toList(); } + private List toExplainColumns(ResultSetBO resultSet) { + return Optional.ofNullable(resultSet.getColumn()) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + } + + private List toExplainTables(Set referencedTables, Map> visibleTablesByName) { + if (referencedTables == null || referencedTables.isEmpty()) { + return List.of(); + } + return referencedTables.stream().map(tableName -> { + List candidates = visibleTablesByName.getOrDefault(normalizeTableName(tableName), List.of()); + return candidates.isEmpty() ? tableName : candidates.get(0); + }).distinct().toList(); + } + + private List> collectRelationEvidence(ExplorerContext context, Set referencedTables) { + if (referencedTables == null || referencedTables.size() < 2) { + return List.of(); + } + Set normalizedReferencedTables = referencedTables.stream() + .map(this::normalizeTableName) + .collect(Collectors.toCollection(LinkedHashSet::new)); + return context.unifiedRelations() + .stream() + .filter(relation -> normalizedReferencedTables.contains(normalizeTableName(relation.sourceTable())) + && normalizedReferencedTables.contains(normalizeTableName(relation.targetTable()))) + .map(this::toRelationEntry) + .toList(); + } + + private String buildResultScopeSummary(ResultSetBO resultSet, int limit) { + int rowCount = resultSet.getData() == null ? 0 : resultSet.getData().size(); + int columnCount = resultSet.getColumn() == null ? 0 : resultSet.getColumn().size(); + if (rowCount >= limit) { + return "当前结果仅展示前 %d 行、%d 个返回字段;字段范围已经过可见性裁剪。".formatted(limit, columnCount); + } + return "当前结果展示 %d 行、%d 个返回字段;字段范围已经过可见性裁剪。".formatted(rowCount, columnCount); + } + + private String buildDecisionReason(Set referencedTables, List> relationEvidence) { + List usedTables = toExplainTables(referencedTables, Map.of()); + if (!relationEvidence.isEmpty()) { + return "本轮选择执行 SQL,是因为问题需要跨表查数;表间关联优先依据已配置的物理外键或逻辑关系。"; + } + if (usedTables.size() > 1) { + return "本轮选择执行 SQL,是因为问题需要联合多张表获取结构化结果。"; + } + if (usedTables.size() == 1) { + return "本轮选择执行 SQL,是因为问题需要直接从目标表提取结构化结果。"; + } + return "本轮选择执行 SQL,是因为问题需要结构化数据结果来支撑回答。"; + } + + private List buildPreviewDecisionReasons(String tableName, int limit) { + return List.of("本轮直接选择 PREVIEW_ROWS,是为了快速确认单表“%s”的样例数据。".formatted(tableName), + "预览查询会自动附带 limit=%d,避免一次返回过多行。".formatted(limit)); + } + + private List buildSearchDecisionReasons(List usedTables, + List> relationEvidence, int limit) { + List reasons = new ArrayList<>(); + reasons.add("本轮选择执行 SQL,是因为回答需要结构化结果来支撑结论。"); + if (!usedTables.isEmpty()) { + reasons.add("实际命中的表有:%s。".formatted(String.join("、", usedTables))); + } + if (usedTables.size() > 1) { + reasons.add("由于问题涉及多张表,系统需要联合查询后再生成答案。"); + } + if (!relationEvidence.isEmpty()) { + reasons.add("多表关联优先依据已配置的物理外键或逻辑关系,而不是临时猜测关联条件。"); + } + reasons.add("查询结果默认限制为最多 %d 行,避免一次返回过多数据。".formatted(limit)); + return List.copyOf(reasons); + } + + private List buildPreviewResultScopeDetails(String tableName, ResultSetBO resultSet, int limit) { + List details = new ArrayList<>(); + details.add("当前展示的是表“%s”的预览结果,不代表完整数据集。".formatted(tableName)); + details.add("本次共返回 %d 行、%d 个字段。".formatted(resultSet.getData() == null ? 0 : resultSet.getData().size(), + resultSet.getColumn() == null ? 0 : resultSet.getColumn().size())); + details.add("预览结果仅包含当前 Agent 可见字段。"); + details.add("预览查询使用了 limit=%d。".formatted(limit)); + details.add("禁止根据未返回字段推断隐藏信息。"); + return List.copyOf(details); + } + + private List buildSearchResultScopeDetails(ResultSetBO resultSet, int limit, Set allowedHeaders) { + List details = new ArrayList<>(); + int rowCount = resultSet.getData() == null ? 0 : resultSet.getData().size(); + int columnCount = resultSet.getColumn() == null ? 0 : resultSet.getColumn().size(); + details.add("当前结果共返回 %d 行、%d 个字段。".formatted(rowCount, columnCount)); + details.add("查询结果最多展示 %d 行。".formatted(limit)); + if (allowedHeaders != null && !allowedHeaders.isEmpty()) { + details.add("最终仅保留 SQL select 列表中显式声明的返回字段。"); + details.add("允许返回的字段有:%s。".formatted(String.join("、", allowedHeaders))); + } + else { + details.add("当前 SQL 未触发额外的结果列过滤。"); + } + details.add("所有结果仍受字段可见性约束限制。"); + details.add("禁止根据未返回字段推断隐藏信息。"); + return List.copyOf(details); + } + private List> toRows(ResultSetBO resultSet) { return resultSet.getData().stream().map(row -> { Map mappedRow = new LinkedHashMap<>(); @@ -864,7 +989,7 @@ private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(Expl .summary(summary); } - private DatasourceExplorerResult capture(DatasourceExplorerResult result, @Nullable GraphRequest graphRequest) { + private DatasourceExplorerResult capture(DatasourceExplorerResult result, @Nullable AgentRequest graphRequest) { if (graphRequest != null) { answerTraceExplainStore.recordDatasourceResult(graphRequest, result); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index 5392af374..a157b2f9c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.datasource; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; @@ -173,8 +173,8 @@ public String call(String toolInput, ToolContext toolContext) { try { DatasourceExplorerRequest request = objectMapper.readValue(toolInput, DatasourceExplorerRequest.class); validateRequest(request); - GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(datasourceExplorerService.execute(agentId, request, graphRequest)); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper.writeValueAsString(datasourceExplorerService.execute(agentId, request, agentRequest)); } catch (Exception ex) { throw new IllegalStateException( diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index 431161fd4..2a217054f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.knowledge; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; @@ -130,8 +130,8 @@ public String call(String toolInput, ToolContext toolContext) { DomainKnowledgeSearchRequest request = new DomainKnowledgeSearchRequest(query, knowledgeTypes.isEmpty() ? null : List.copyOf(knowledgeTypes), topK, similarityThreshold); - GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(domainKnowledgeSearchService.search(agentId, request, graphRequest)); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper.writeValueAsString(domainKnowledgeSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { throw new IllegalStateException(objectToJson( diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java index 6cd2efd68..023fd20bf 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.semantic; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.SemanticModel; import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; @@ -51,7 +51,7 @@ public SemanticModelSearchResult search(String agentId, SemanticModelSearchReque } public SemanticModelSearchResult search(String agentId, SemanticModelSearchRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest agentRequest) { if (!StringUtils.hasText(agentId)) { return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); @@ -64,7 +64,7 @@ public SemanticModelSearchResult search(String agentId, SemanticModelSearchReque return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); } - return search(parsedAgentId, request, graphRequest); + return search(parsedAgentId, request, agentRequest); } public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest request) { @@ -72,7 +72,7 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest } public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest agentRequest) { String query = request == null ? null : request.getQuery(); if (!StringUtils.hasText(query)) { throw new IllegalArgumentException("semantic_model.search 需要 query 参数"); @@ -118,8 +118,8 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest List hits = scoredHits.stream().map(this::toHit).toList(); String summary = "共匹配到 %d 条补充语义提示。这些结果只用于补充理解表和字段语义,不能替代数据源探索工具的物理结构探索。" .formatted(hits.size()); - if (graphRequest != null) { - answerTraceExplainStore.recordSemanticSearch(graphRequest, query, + if (agentRequest != null) { + answerTraceExplainStore.recordSemanticSearch(agentRequest, query, "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); } else { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java index e1eda9b31..417b66490 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.semantic; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; @@ -99,8 +99,8 @@ public String call(String toolInput, ToolContext toolContext) { ? objectMapper.readValue(toolInput, SemanticModelSearchRequest.class) : new SemanticModelSearchRequest(); validateRequest(request); - GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(semanticModelSearchService.search(agentId, request, graphRequest)); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper.writeValueAsString(semanticModelSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { throw new IllegalStateException( diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java index 4ba1b3c2b..0c2f51e58 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; @@ -175,11 +175,11 @@ private void enrichRequestFromToolContext(SqlGuardCheckRequest request, ToolCont if (request == null || StringUtils.hasText(request.getHumanFeedbackContent())) { return; } - GraphRequest graphRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - if (graphRequest == null) { + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + if (agentRequest == null) { return; } - request.setHumanFeedbackContent(graphRequest.getHumanFeedbackContent()); + request.setHumanFeedbackContent(agentRequest.getHumanFeedbackContent()); } private String objectToJson(Object value) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java similarity index 84% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java index 5c6cb8483..7c2d7709e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java @@ -26,7 +26,7 @@ @AllArgsConstructor @NoArgsConstructor @Builder -public class GraphNodeResponse { +public class AgentResponse { private String agentId; @@ -46,8 +46,8 @@ public class GraphNodeResponse { @Builder.Default private boolean complete = false; - public static GraphNodeResponse error(String agentId, String threadId, String text) { - return GraphNodeResponse.builder() + public static AgentResponse error(String agentId, String threadId, String text) { + return AgentResponse.builder() .agentId(agentId) .threadId(threadId) .text(text) @@ -56,8 +56,8 @@ public static GraphNodeResponse error(String agentId, String threadId, String te .build(); } - public static GraphNodeResponse complete(String agentId, String threadId) { - return GraphNodeResponse.builder() + public static AgentResponse complete(String agentId, String threadId) { + return AgentResponse.builder() .agentId(agentId) .threadId(threadId) .complete(true) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java index 4b36a64f7..c1894509a 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java @@ -141,6 +141,16 @@ public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId" .orElseGet(() -> ResponseEntity.notFound().build()); } + @GetMapping("/sessions/{sessionId}/answers/latest/explain") + public ResponseEntity getLatestAnswerExplain(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); + return answerTraceExplainStore.getLatestExplain(sessionId) + .>map(ResponseEntity::ok) + .or(() -> loadLatestPersistedAnswerExplain(sessionId, agentId).map(ResponseEntity::ok)) + .orElseGet(() -> ResponseEntity.notFound().build()); + } + @GetMapping("/sessions/{sessionId}/answers/{runtimeRequestId}/explain") public ResponseEntity getAnswerExplain(@PathVariable(value = "sessionId") String sessionId, @RequestParam(value = "agentId") Long agentId, @@ -152,6 +162,31 @@ public ResponseEntity getAnswerExplain(@PathVariable(value = "sessionId") Str .orElseGet(() -> ResponseEntity.notFound().build()); } + private java.util.Optional loadLatestPersistedAnswerExplain(String sessionId, Long agentId) { + List snapshots = chatMessageService.findBySessionIdAndMessageType(sessionId, + ANSWER_EXPLAIN_MESSAGE_TYPE, agentId); + JsonNode latestExplainNode = null; + long latestUpdatedAt = Long.MIN_VALUE; + for (ChatMessage snapshot : snapshots) { + if (snapshot == null || !StringUtils.hasText(snapshot.getContent())) { + continue; + } + try { + JsonNode explainNode = objectMapper.readTree(snapshot.getContent()); + long updatedAt = explainNode.path("updatedAt").asLong(Long.MIN_VALUE); + if (latestExplainNode == null || updatedAt >= latestUpdatedAt) { + latestExplainNode = explainNode; + latestUpdatedAt = updatedAt; + } + } + catch (Exception ex) { + log.warn("Failed to parse persisted answer explain snapshot. sessionId={}, messageId={}", sessionId, + snapshot.getId(), ex); + } + } + return java.util.Optional.ofNullable(latestExplainNode); + } + private java.util.Optional loadPersistedAnswerExplain(String sessionId, String runtimeRequestId, Long agentId) { if (!StringUtils.hasText(sessionId) || !StringUtils.hasText(runtimeRequestId)) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java index 377c75caa..be3399e25 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java @@ -15,9 +15,9 @@ */ package com.alibaba.cloud.ai.dataagent.controller; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentService; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; @@ -44,20 +44,20 @@ public class DataAgentController { private final AgentService agentService; @GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux> streamSearch(@RequestParam("agentId") String agentId, - @RequestParam(value = "threadId", required = false) String threadId, - @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, - @RequestParam("query") String query, - @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, - @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, - @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, - @RequestParam(value = "nl2sqlOnly", required = false) boolean nl2sqlOnly, ServerHttpResponse response) { + public Flux> streamSearch(@RequestParam("agentId") String agentId, + @RequestParam(value = "threadId", required = false) String threadId, + @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, + @RequestParam("query") String query, + @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, + @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, + @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, + @RequestParam(value = "nl2sqlOnly", required = false) boolean nl2sqlOnly, ServerHttpResponse response) { response.getHeaders().add("Cache-Control", "no-cache"); response.getHeaders().add("Connection", "keep-alive"); response.getHeaders().add("Access-Control-Allow-Origin", "*"); - Sinks.Many> sink = Sinks.many().unicast().onBackpressureBuffer(); - GraphRequest request = GraphRequest.builder() + Sinks.Many> sink = Sinks.many().unicast().onBackpressureBuffer(); + AgentRequest request = AgentRequest.builder() .agentId(agentId) .threadId(threadId) .runtimeRequestId(runtimeRequestId) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java index 5b087b85b..48ecb0f07 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.observability; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService.QueryClarifyAssessment; import com.alibaba.cloud.ai.dataagent.agentscope.tool.datasource.DatasourceExplorerResult; import com.alibaba.cloud.ai.dataagent.agentscope.tool.semantic.SemanticModelSearchHit; @@ -49,7 +49,7 @@ public class AnswerTraceExplainStore { private final LinkedHashMap> explainsBySession = new LinkedHashMap<>( 32, 0.75f, true); - public void openScope(GraphRequest request) { + public void openScope(AgentRequest request) { if (request == null || !StringUtils.hasText(request.getThreadId()) || !StringUtils.hasText(request.getRuntimeRequestId())) { return; @@ -84,7 +84,7 @@ public void recordWarning(String warning) { }); } - public void recordClarifyAssessment(GraphRequest request, QueryClarifyAssessment assessment) { + public void recordClarifyAssessment(AgentRequest request, QueryClarifyAssessment assessment) { if (assessment == null) { return; } @@ -95,7 +95,7 @@ public void recordSemanticSearch(String query, String summary, List applySemanticSearch(assembly, query, summary, hits)); } - public void recordSemanticSearch(GraphRequest request, String query, String summary, + public void recordSemanticSearch(AgentRequest request, String query, String summary, List hits) { withAssembly(request, assembly -> applySemanticSearch(assembly, query, summary, hits)); } @@ -107,7 +107,7 @@ public void recordKnowledgeSearch(DomainKnowledgeSearchResult result) { withCurrentAssembly(assembly -> applyKnowledgeSearch(assembly, result)); } - public void recordKnowledgeSearch(GraphRequest request, DomainKnowledgeSearchResult result) { + public void recordKnowledgeSearch(AgentRequest request, DomainKnowledgeSearchResult result) { if (result == null) { return; } @@ -121,7 +121,7 @@ public void recordDatasourceResult(DatasourceExplorerResult result) { withCurrentAssembly(assembly -> applyDatasourceResult(assembly, result)); } - public void recordDatasourceResult(GraphRequest request, DatasourceExplorerResult result) { + public void recordDatasourceResult(AgentRequest request, DatasourceExplorerResult result) { if (result == null) { return; } @@ -145,6 +145,23 @@ public Optional getExplain(String sessionId, String runt } } + public Optional getLatestExplain(String sessionId) { + if (!StringUtils.hasText(sessionId)) { + return Optional.empty(); + } + synchronized (monitor) { + LinkedHashMap explainsByRequest = explainsBySession.get(sessionId); + if (explainsByRequest == null || explainsByRequest.isEmpty()) { + return Optional.empty(); + } + ExplainAssembly latestAssembly = explainsByRequest.values() + .stream() + .max(java.util.Comparator.comparingLong(assembly -> assembly.updatedAt)) + .orElse(null); + return latestAssembly == null ? Optional.empty() : Optional.of(latestAssembly.toView()); + } + } + public Optional getMirrorSummary(String sessionId, String runtimeRequestId) { return getExplain(sessionId, runtimeRequestId).map(explain -> ExplainMirrorSummary.builder() .datasource(explain.getDatasource()) @@ -165,7 +182,7 @@ private void withCurrentAssembly(java.util.function.Consumer co } } - private void withAssembly(GraphRequest request, java.util.function.Consumer consumer) { + private void withAssembly(AgentRequest request, java.util.function.Consumer consumer) { if (request == null || !StringUtils.hasText(request.getThreadId()) || !StringUtils.hasText(request.getRuntimeRequestId())) { return; @@ -184,7 +201,7 @@ private ExplainAssembly resolveAssemblyLocked(String sessionId, String runtimeRe return explainsByRequest.computeIfAbsent(runtimeRequestId, ignored -> new ExplainAssembly()); } - private void applyRequestContext(ExplainAssembly assembly, GraphRequest request) { + private void applyRequestContext(ExplainAssembly assembly, AgentRequest request) { if (assembly == null || request == null) { return; } @@ -268,6 +285,35 @@ private void applyDatasourceResult(ExplainAssembly assembly, DatasourceExplorerR if (StringUtils.hasText(result.getSql())) { assembly.sql = result.getSql(); } + if (result.getUsedTables() != null && !result.getUsedTables().isEmpty()) { + assembly.usedTables.clear(); + assembly.usedTables.addAll(result.getUsedTables().stream().filter(StringUtils::hasText).map(String::trim).toList()); + } + if (result.getUsedColumns() != null && !result.getUsedColumns().isEmpty()) { + assembly.usedColumns.clear(); + assembly.usedColumns + .addAll(result.getUsedColumns().stream().filter(StringUtils::hasText).map(String::trim).toList()); + } + if (result.getRelationEvidence() != null && !result.getRelationEvidence().isEmpty()) { + assembly.relationEvidence.clear(); + assembly.relationEvidence.addAll(result.getRelationEvidence()); + } + if (StringUtils.hasText(result.getResultScope())) { + assembly.resultScope = result.getResultScope(); + } + if (StringUtils.hasText(result.getDecisionReason())) { + assembly.decisionReason = result.getDecisionReason(); + } + if (result.getToolDecisionReasons() != null && !result.getToolDecisionReasons().isEmpty()) { + assembly.toolDecisionReasons.clear(); + assembly.toolDecisionReasons + .addAll(result.getToolDecisionReasons().stream().filter(StringUtils::hasText).map(String::trim).toList()); + } + if (result.getResultScopeDetails() != null && !result.getResultScopeDetails().isEmpty()) { + assembly.resultScopeDetails.clear(); + assembly.resultScopeDetails + .addAll(result.getResultScopeDetails().stream().filter(StringUtils::hasText).map(String::trim).toList()); + } assembly.toolSteps.add(ToolStepView.builder() .toolName("datasource.explorer") .title(result.getAction()) @@ -285,6 +331,9 @@ private void applyClarifyAssessment(ExplainAssembly assembly, QueryClarifyAssess assembly.clarify.put("missingDimensions", assessment.missingDimensions()); assembly.clarify.put("followUpQuestions", assessment.followUpQuestions()); assembly.clarify.put("suggestedAssumptions", assessment.suggestedAssumptions()); + assembly.clarify.put("summary", assessment.summary()); + assembly.clarify.put("userMessage", assessment.userMessage()); + assembly.clarify.put("shouldBlockExecution", assessment.shouldBlockExecution()); if (StringUtils.hasText(assessment.feedbackContent())) { assembly.clarify.put("humanFeedbackContent", assessment.feedbackContent()); } @@ -335,6 +384,20 @@ private static final class ExplainAssembly { private String sql; + private String decisionReason; + + private String resultScope; + + private final List usedTables = new ArrayList<>(); + + private final List usedColumns = new ArrayList<>(); + + private final List> relationEvidence = new ArrayList<>(); + + private final List toolDecisionReasons = new ArrayList<>(); + + private final List resultScopeDetails = new ArrayList<>(); + private final List semanticHits = new ArrayList<>(); private final List knowledgeHits = new ArrayList<>(); @@ -356,6 +419,13 @@ private AnswerTraceExplainView toView() { .answer(answer) .datasource(datasource) .sql(sql) + .decisionReason(decisionReason) + .resultScope(resultScope) + .usedTables(List.copyOf(usedTables)) + .usedColumns(List.copyOf(usedColumns)) + .relationEvidence(List.copyOf(relationEvidence)) + .toolDecisionReasons(List.copyOf(toolDecisionReasons)) + .resultScopeDetails(List.copyOf(resultScopeDetails)) .semanticHits(List.copyOf(semanticHits)) .knowledgeHits(List.copyOf(knowledgeHits)) .toolSteps(List.copyOf(toolSteps)) @@ -386,6 +456,25 @@ public static class AnswerTraceExplainView { private String sql; + private String decisionReason; + + private String resultScope; + + @Builder.Default + private List usedTables = List.of(); + + @Builder.Default + private List usedColumns = List.of(); + + @Builder.Default + private List> relationEvidence = List.of(); + + @Builder.Default + private List toolDecisionReasons = List.of(); + + @Builder.Default + private List resultScopeDetails = List.of(); + @Builder.Default private List semanticHits = List.of(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java index 7d84e9508..181452cba 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchService.java @@ -15,7 +15,8 @@ */ package com.alibaba.cloud.ai.dataagent.service.knowledge; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; + import java.util.List; import org.springframework.lang.Nullable; @@ -24,7 +25,7 @@ public interface DomainKnowledgeSearchService { DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchRequest request); DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchRequest request, - @Nullable GraphRequest graphRequest); + @Nullable AgentRequest agentRequest); record DomainKnowledgeSearchRequest(String query, List knowledgeTypes, Integer topK, Double similarityThreshold) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java index 343bc2691..92b00ad09 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.service.knowledge; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.constant.DocumentMetadataConstant; import com.alibaba.cloud.ai.dataagent.entity.AgentKnowledge; import com.alibaba.cloud.ai.dataagent.entity.BusinessKnowledge; @@ -25,9 +25,7 @@ import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; import com.alibaba.cloud.ai.dataagent.service.vectorstore.AgentVectorStoreService; import com.alibaba.cloud.ai.dataagent.service.vectorstore.DynamicFilterService; -import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchRequest; -import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchResult; -import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.KnowledgeHit; + import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; @@ -77,7 +75,7 @@ public DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchR @Override public DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchRequest request, - @Nullable GraphRequest graphRequest) { + @Nullable AgentRequest agentRequest) { Assert.hasText(agentId, "AgentId cannot be empty"); Assert.notNull(request, "Search request cannot be null"); String query = requireText(request.query(), "Query cannot be blank"); @@ -129,8 +127,8 @@ public DomainKnowledgeSearchResult search(String agentId, DomainKnowledgeSearchR String resolution = hits.isEmpty() ? "no_match" : "matched"; DomainKnowledgeSearchResult result = new DomainKnowledgeSearchResult(List.copyOf(hits), List.copyOf(warnings), resolution); - if (graphRequest != null) { - answerTraceExplainStore.recordKnowledgeSearch(graphRequest, result); + if (agentRequest != null) { + answerTraceExplainStore.recordKnowledgeSearch(agentRequest, result); } else { answerTraceExplainStore.recordKnowledgeSearch(result); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java index 00f91cd34..1d71786b5 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.service.langfuse; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; @@ -78,7 +78,7 @@ public LangfuseService(@Qualifier("langfuseTracer") Tracer langfuseTracer, /** * 开始一个 Graph 流式处理的 Span,记录完整的请求上下文 */ - public Span startLLMSpan(String spanName, GraphRequest request) { + public Span startLLMSpan(String spanName, AgentRequest request) { if (!enabled) { return Span.getInvalid(); } From aa48d1b53c42c25cf08307287fe7c0b4f3bdfa76 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 29 Apr 2026 17:44:08 +0800 Subject: [PATCH 18/22] feat: remove something --- data-agent-frontend/src/services/graph.ts | 2 - data-agent-frontend/src/views/AgentRun.vue | 1 - .../agentscope/dto/AgentRequest.java | 2 - .../runtime/AgentRuntimeExtensionFactory.java | 7 +- .../runtime/AgentRuntimeRequestMetadata.java | 2 +- .../runtime/AgentScopeHookFactory.java | 14 +- .../runtime/AgentScopeMemoryFactory.java | 53 ++---- .../AgentScopeMemoryPersistenceHook.java | 87 --------- .../runtime/AgentScopeStreamingHook.java | 14 +- .../agentscope/runtime/HumanFeedbackHook.java | 3 - .../agentscope/runtime/PreparedMemory.java | 22 +++ .../runtime/ToolContextRequestResolver.java | 1 - .../agentscope/service/AgentService.java | 2 - .../impl/AiAgentRuntimeServiceImpl.java | 162 +++++++++++++---- ...egistry.java => AgentRuntimeRegistry.java} | 2 +- .../session/AgentScopeMysqlSession.java | 167 ++++++++++++++++++ .../AgentScopeNativeSessionService.java | 103 +++++++++++ .../DomainBusinessKnowledgeToolSupport.java | 6 +- .../config/DataAgentConfiguration.java | 2 +- .../controller/DataAgentController.java | 22 ++- .../dataagent/mapper/ChatMessageMapper.java | 61 +++++-- .../service/chat/ChatMessageService.java | 5 - .../service/chat/ChatMessageServiceImpl.java | 9 - .../service/chat/ChatSessionServiceImpl.java | 15 +- .../service/datasource/DatasourceService.java | 8 - .../impl/DatasourceServiceImpl.java | 17 -- .../DomainKnowledgeSearchServiceImpl.java | 2 +- .../service/langfuse/LangfuseService.java | 40 ++--- .../ai/dataagent/service/llm/LlmService.java | 7 - .../service/mcp/McpServerService.java | 23 +-- .../ai/dataagent/util/ChatResponseUtil.java | 7 - docs/ADVANCED_FEATURES-en.md | 47 +---- docs/ADVANCED_FEATURES.md | 45 +---- docs/ARCHITECTURE-en.md | 10 +- docs/ARCHITECTURE.md | 10 +- 35 files changed, 546 insertions(+), 434 deletions(-) delete mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java rename data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/{AgentSessionRegistry.java => AgentRuntimeRegistry.java} (99%) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java diff --git a/data-agent-frontend/src/services/graph.ts b/data-agent-frontend/src/services/graph.ts index f2a4fe00c..a0cb33acf 100644 --- a/data-agent-frontend/src/services/graph.ts +++ b/data-agent-frontend/src/services/graph.ts @@ -22,7 +22,6 @@ export interface AgentRequest { humanFeedback?: boolean; humanFeedbackContent?: string; rejectedPlan: boolean; - nl2sqlOnly: boolean; } export interface ClarifyMetadata { @@ -90,7 +89,6 @@ class GraphService { params.append('humanFeedbackContent', request.humanFeedbackContent); } params.append('rejectedPlan', request.rejectedPlan.toString()); - params.append('nl2sqlOnly', request.nl2sqlOnly.toString()); const url = `${API_BASE_URL}/stream/search?${params.toString()}`; diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index db35ca350..965b3621f 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -1296,7 +1296,6 @@ query: requestQuery, humanFeedback: Boolean(activeClarify), humanFeedbackContent: feedbackContent, - nl2sqlOnly: false, rejectedPlan: false, threadId: currentSession.value.id, runtimeRequestId: createRuntimeRequestId(), diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java index 0fbe4d7c1..d6f251a06 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java @@ -40,6 +40,4 @@ public class AgentRequest { private boolean rejectedPlan; - private boolean nl2sqlOnly; - } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java index 67956f934..0c1810178 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java @@ -43,12 +43,12 @@ public class AgentRuntimeExtensionFactory { private final AgentScopeSkillBoxFactory skillBoxFactory; public AgentRuntimeExtensions create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher, - Map toolCallbacks) { + Map toolCallbacks, PreparedMemory preparedMemory) { Toolkit toolkit = toolkitFactory.buildToolkit(toolCallbacks); SkillBox skillBox = skillBoxFactory.create(request.getAgentId(), toolkit); - Memory memory = memoryFactory.create(request.getThreadId()); + Memory memory = preparedMemory == null ? memoryFactory.create(request).memory() : preparedMemory.memory(); AgentRuntimeRequestMetadata requestMetadata = new AgentRuntimeRequestMetadata(request.getAgentId(), - request.getThreadId(), request.getRuntimeRequestId(), request.isNl2sqlOnly(), request.isHumanFeedback(), + request.getThreadId(), request.getRuntimeRequestId(), request.isHumanFeedback(), request.getHumanFeedbackContent()); ToolExecutionContext toolExecutionContext = ToolExecutionContext.builder() .register(requestMetadata) @@ -57,6 +57,7 @@ public AgentRuntimeExtensions create(AgentRequest request, @Nullable AgentRuntim List hooks = hookFactory.create(request, eventPublisher); Map attributes = new HashMap<>(); attributes.put("threadId", request.getThreadId()); + attributes.put("memoryLoadedFromNative", preparedMemory != null && preparedMemory.loadedFromNative()); return new AgentRuntimeExtensions(toolkit, memory, toolExecutionContext, hooks, attributes, skillBox, ""); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java index c9bed74bb..326300b35 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -public record AgentRuntimeRequestMetadata(String agentId, String threadId, String runtimeRequestId, boolean nl2sqlOnly, +public record AgentRuntimeRequestMetadata(String agentId, String threadId, String runtimeRequestId, boolean humanFeedback, String humanFeedbackContent) { } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java index fb5b0236d..15ea436db 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java @@ -16,9 +16,6 @@ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; import io.agentscope.core.hook.Hook; import java.util.ArrayList; import java.util.List; @@ -30,20 +27,11 @@ @RequiredArgsConstructor public class AgentScopeHookFactory { - private final ChatSessionService chatSessionService; - - private final ChatMessageService chatMessageService; - - private final AgentSessionRegistry sessionRegistry; - public List create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher) { List hooks = new ArrayList<>(); if (eventPublisher != null) { - hooks.add(new AgentScopeStreamingHook(request.getAgentId(), request.getThreadId(), request.isNl2sqlOnly(), - eventPublisher)); + hooks.add(new AgentScopeStreamingHook(request.getAgentId(), request.getThreadId(), eventPublisher)); } - hooks.add(new AgentScopeMemoryPersistenceHook(request.getThreadId(), request.getRuntimeRequestId(), - sessionRegistry, chatSessionService, chatMessageService)); HumanFeedbackHook humanFeedbackHook = HumanFeedbackHook.from(request); if (humanFeedbackHook != null) { hooks.add(humanFeedbackHook); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java index 746849914..40fd939ea 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java @@ -15,13 +15,9 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; import io.agentscope.core.memory.InMemoryMemory; -import io.agentscope.core.memory.Memory; -import io.agentscope.core.message.Msg; -import io.agentscope.core.message.MsgRole; -import java.util.List; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -32,46 +28,21 @@ @RequiredArgsConstructor public class AgentScopeMemoryFactory { - private static final int MEMORY_MESSAGE_LIMIT = 20; + private final AgentScopeNativeSessionService nativeSessionService; - private final ChatMessageService chatMessageService; - - public Memory create(String threadId) { + public PreparedMemory create(AgentRequest request) { InMemoryMemory memory = new InMemoryMemory(); + String threadId = request == null ? null : request.getThreadId(); if (!StringUtils.hasText(threadId)) { - return memory; - } - List history = chatMessageService.findRecentBySessionId(threadId, MEMORY_MESSAGE_LIMIT); - history.stream().map(this::toMessage).filter(msg -> msg != null).forEach(memory::addMessage); - log.debug("Loaded {} history messages into AgentScope memory, threadId={}", history.size(), threadId); - return memory; - } - - private Msg toMessage(ChatMessage message) { - if (message == null || !StringUtils.hasText(message.getContent())) { - return null; + return new PreparedMemory(memory, false); } - return Msg.builder() - .name(resolveName(message.getRole())) - .role(resolveRole(message.getRole())) - .textContent(message.getContent()) - .build(); - } - - private String resolveName(String role) { - return StringUtils.hasText(role) ? role : "user"; - } - - private MsgRole resolveRole(String role) { - if (!StringUtils.hasText(role)) { - return MsgRole.USER; + boolean loadedFromNative = nativeSessionService.loadMemoryIfExists(memory, threadId); + if (loadedFromNative) { + log.debug("Loaded AgentScope native session memory, threadId={}", threadId); + return new PreparedMemory(memory, true); } - return switch (role.trim().toLowerCase()) { - case "assistant" -> MsgRole.ASSISTANT; - case "system" -> MsgRole.SYSTEM; - case "tool" -> MsgRole.TOOL; - default -> MsgRole.USER; - }; + log.debug("No AgentScope native session memory found, threadId={}", threadId); + return new PreparedMemory(memory, false); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java deleted file mode 100644 index 4c238dc30..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.agentscope.runtime; - -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; -import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; -import io.agentscope.core.hook.Hook; -import io.agentscope.core.hook.HookEvent; -import io.agentscope.core.hook.PostCallEvent; -import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -public class AgentScopeMemoryPersistenceHook implements Hook { - - private static final String MEMORY_MESSAGE_TYPE = "memory-text"; - - private static final String MEMORY_METADATA = "{\"source\":\"agentscope\",\"visibility\":\"memory-only\"}"; - - private final String threadId; - - private final String runtimeRequestId; - - private final AgentSessionRegistry sessionRegistry; - - private final ChatSessionService chatSessionService; - - private final ChatMessageService chatMessageService; - - public AgentScopeMemoryPersistenceHook(String threadId, String runtimeRequestId, - AgentSessionRegistry sessionRegistry, ChatSessionService chatSessionService, - ChatMessageService chatMessageService) { - this.threadId = threadId; - this.runtimeRequestId = runtimeRequestId; - this.sessionRegistry = sessionRegistry; - this.chatSessionService = chatSessionService; - this.chatMessageService = chatMessageService; - } - - @Override - public int priority() { - return 200; - } - - @Override - public Mono onEvent(T event) { - if (event instanceof PostCallEvent postCallEvent) { - persistFinalMessage(postCallEvent); - } - return Mono.just(event); - } - - private void persistFinalMessage(PostCallEvent event) { - if (!StringUtils.hasText(threadId) || chatSessionService.findBySessionId(threadId) == null) { - return; - } - if (sessionRegistry.isCancelled(threadId, runtimeRequestId)) { - return; - } - String finalText = event.getFinalMessage() == null ? null : event.getFinalMessage().getTextContent(); - if (!StringUtils.hasText(finalText)) { - return; - } - chatMessageService.saveMessage(ChatMessage.builder() - .sessionId(threadId) - .role("assistant") - .content(finalText) - .messageType(MEMORY_MESSAGE_TYPE) - .metadata(MEMORY_METADATA) - .build()); - } - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java index 278ef2ea5..d28212af2 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java @@ -34,28 +34,22 @@ public class AgentScopeStreamingHook implements Hook { private static final String PLANNER_REASONING_NODE = "planner-reasoning"; - private static final String SQL_GENERATOR_REASONING_NODE = "sql-generator-reasoning"; - private final String agentId; private final String threadId; - private final boolean nl2sqlOnly; - private final AgentRuntimeEventPublisher eventPublisher; - public AgentScopeStreamingHook(String agentId, String threadId, boolean nl2sqlOnly, - AgentRuntimeEventPublisher eventPublisher) { + public AgentScopeStreamingHook(String agentId, String threadId, AgentRuntimeEventPublisher eventPublisher) { this.agentId = agentId; this.threadId = threadId; - this.nl2sqlOnly = nl2sqlOnly; this.eventPublisher = eventPublisher; } @Override public Mono onEvent(T event) { if (event instanceof ReasoningChunkEvent reasoningChunkEvent) { - emit(resolveReasoningNodeName(), reasoningChunkEvent.getIncrementalChunk().getTextContent()); + emit(PLANNER_REASONING_NODE, reasoningChunkEvent.getIncrementalChunk().getTextContent()); } else if (event instanceof PreActingEvent preActingEvent) { emit(resolveToolNodeName(preActingEvent.getToolUse().getName()), @@ -85,10 +79,6 @@ private void emit(String nodeName, String text) { .build()); } - private String resolveReasoningNodeName() { - return nl2sqlOnly ? SQL_GENERATOR_REASONING_NODE : PLANNER_REASONING_NODE; - } - private String resolveToolNodeName(String toolName) { return toolName == null || toolName.isBlank() ? "AgentScopeTool" : "tool:" + toolName; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java index 92ed531b7..c3d7eeb55 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java @@ -43,9 +43,6 @@ private HumanFeedbackHook(boolean pauseAfterPlanning, boolean replayRequested, S } public static HumanFeedbackHook from(AgentRequest request) { - if (request.isNl2sqlOnly()) { - return null; - } boolean hasFeedbackContent = StringUtils.hasText(request.getHumanFeedbackContent()); boolean requiresReplay = hasFeedbackContent || request.isRejectedPlan(); boolean requiresPause = request.isHumanFeedback() && !requiresReplay; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java new file mode 100644 index 000000000..ebf29b203 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.runtime; + +import io.agentscope.core.memory.InMemoryMemory; + +public record PreparedMemory(InMemoryMemory memory, boolean loadedFromNative) { + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java index 9f671994e..68901eaf6 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java @@ -65,7 +65,6 @@ private static AgentRequest fromMetadata(@Nullable AgentRuntimeRequestMetadata m .agentId(metadata.agentId()) .threadId(metadata.threadId()) .runtimeRequestId(metadata.runtimeRequestId()) - .nl2sqlOnly(metadata.nl2sqlOnly()) .humanFeedback(metadata.humanFeedback()) .humanFeedbackContent(metadata.humanFeedbackContent()) .build(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java index f05393e70..eb810e367 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java @@ -22,8 +22,6 @@ public interface AgentService { - String nl2sql(String naturalQuery, String agentId); - void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest); void stopStreamProcessing(String threadId, String runtimeRequestId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 1234337b7..a935dee4b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -16,14 +16,17 @@ package com.alibaba.cloud.ai.dataagent.agentscope.service.impl; import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentScopeMemoryFactory; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeEventPublisher; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeExtensionFactory; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.PreparedMemory; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService.QueryClarifyAssessment; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentScopeToolkitFactory; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentScopeModelFactory; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentService; -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentRuntimeRegistry; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRunContext; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRuntimeExtensions; import com.alibaba.cloud.ai.dataagent.agentscope.template.ManagedAgent; @@ -42,14 +45,19 @@ import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; import com.fasterxml.jackson.databind.ObjectMapper; +import io.agentscope.core.memory.InMemoryMemory; import io.agentscope.core.message.Msg; +import io.agentscope.core.message.MsgRole; import io.agentscope.core.model.Model; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Scope; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.concurrent.CancellationException; import lombok.RequiredArgsConstructor; @@ -84,7 +92,7 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private static final String AGENT_STATUS_OFFLINE = "offline"; - private final AgentSessionRegistry sessionRegistry; + private final AgentRuntimeRegistry runtimeRegistry; private final ModelConfigDataService modelConfigDataService; @@ -98,6 +106,8 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private final AgentRuntimeExtensionFactory agentRuntimeExtensionFactory; + private final AgentScopeMemoryFactory agentScopeMemoryFactory; + private final com.alibaba.cloud.ai.dataagent.service.agent.AgentService agentService; @Qualifier("agentScopeTracer") @@ -113,19 +123,7 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private final QueryClarifyService queryClarifyService; - @Override - public String nl2sql(String naturalQuery, String agentId) { - log.info("NL2SQL runtime invoked for agentId={}", agentId); - AgentRequest request = AgentRequest.builder().agentId(agentId).query(naturalQuery).nl2sqlOnly(true).build(); - initializeRuntimeRequest(request); - sessionRegistry.register(request.getThreadId(), request.getRuntimeRequestId()); - try { - return executeAgent(request); - } - finally { - sessionRegistry.finish(request.getThreadId(), request.getRuntimeRequestId()); - } - } + private final AgentScopeNativeSessionService nativeSessionService; @Override public void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest) { @@ -133,9 +131,9 @@ public void graphStreamProcess(Sinks.Many> sink, String threadId = agentRequest.getThreadId(); String runtimeRequestId = agentRequest.getRuntimeRequestId(); StreamTextTracker streamTextTracker = new StreamTextTracker(); - sessionRegistry.register(threadId, runtimeRequestId); + runtimeRegistry.register(threadId, runtimeRequestId); AgentRuntimeEventPublisher eventPublisher = response -> { - if (!sessionRegistry.isActive(threadId, runtimeRequestId)) { + if (!runtimeRegistry.isActive(threadId, runtimeRequestId)) { return; } if (response != null && response.getTextType() == TextType.TEXT @@ -146,7 +144,7 @@ public void graphStreamProcess(Sinks.Many> sink, }; Mono.fromCallable(() -> executeAgent(agentRequest, eventPublisher)) - .doFinally(signalType -> sessionRegistry.finish(threadId, runtimeRequestId)) + .doFinally(signalType -> runtimeRegistry.finish(threadId, runtimeRequestId)) .subscribeOn(Schedulers.boundedElastic()) .subscribe(result -> emitSuccess(sink, agentRequest, result, streamTextTracker), error -> emitError(sink, agentRequest, error)); @@ -154,14 +152,14 @@ public void graphStreamProcess(Sinks.Many> sink, @Override public void stopStreamProcessing(String threadId, String runtimeRequestId) { - sessionRegistry.markCancelled(threadId, runtimeRequestId); + runtimeRegistry.markCancelled(threadId, runtimeRequestId); } private void emitSuccess(Sinks.Many> sink, AgentRequest request, String result, StreamTextTracker streamTextTracker) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); - if (!sessionRegistry.isActive(threadId, runtimeRequestId)) { + if (!runtimeRegistry.isActive(threadId, runtimeRequestId)) { return; } if (shouldEmitFinalResponse(result, streamTextTracker)) { @@ -187,13 +185,13 @@ private boolean shouldEmitFinalResponse(String result, StreamTextTracker streamT private void emitError(Sinks.Many> sink, AgentRequest request, Throwable error) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); - if (sessionRegistry.isCancelled(threadId, runtimeRequestId)) { + if (runtimeRegistry.isCancelled(threadId, runtimeRequestId)) { log.info("AgentScope runtime cancelled, suppress error propagation. threadId={}, runtimeRequestId={}", threadId, runtimeRequestId); return; } log.error("AgentScope runtime failed, threadId={}", threadId, error); - if (sessionRegistry.isActive(threadId, runtimeRequestId)) { + if (runtimeRegistry.isActive(threadId, runtimeRequestId)) { String message = error.getMessage() == null ? "AgentScope 运行失败。" : error.getMessage(); sink.tryEmitNext(ServerSentEvent.builder(AgentResponse.error(request.getAgentId(), threadId, message)) .event(STREAM_EVENT_ERROR) @@ -208,7 +206,7 @@ private String executeAgent(AgentRequest request) { private void initializeRuntimeRequest(AgentRequest request) { if (!StringUtils.hasText(request.getThreadId())) { - request.setThreadId(UUID.randomUUID().toString()); + throw new IllegalArgumentException("threadId must not be empty"); } if (!StringUtils.hasText(request.getRuntimeRequestId())) { request.setRuntimeRequestId(UUID.randomUUID().toString()); @@ -216,12 +214,14 @@ private void initializeRuntimeRequest(AgentRequest request) { } private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eventPublisher) { - sessionRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); + runtimeRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); answerTraceExplainStore.openScope(request); Span rootSpan = startRuntimeSpan(request); + PreparedMemory preparedMemory = null; try { try (Scope ignored = rootSpan.makeCurrent()) { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + preparedMemory = agentScopeMemoryFactory.create(request); + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { rootSpan.setStatus(StatusCode.OK, "cancelled"); return ""; } @@ -231,7 +231,7 @@ private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eve rootSpan.setAttribute("dataagent.query_clarify.risk_level", clarifyAssessment.riskLevel().value()); rootSpan.setAttribute("dataagent.query_clarify.blocked", clarifyAssessment.shouldBlockExecution()); if (clarifyAssessment.shouldBlockExecution()) { - return blockForClarification(request, eventPublisher, rootSpan, clarifyAssessment); + return blockForClarification(request, eventPublisher, rootSpan, clarifyAssessment, preparedMemory); } Agent managedAgentConfig = resolveManagedAgent(request.getAgentId()); ModelConfigDTO modelConfig = modelConfigDataService.getActiveConfigByType(ModelType.CHAT); @@ -242,7 +242,7 @@ private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eve modelConfig.getModelName(), toolCallbacks); ManagedAgent managedAgent = managedAgentRegistry.getRequired(); AgentRuntimeExtensions runtimeExtensions = agentRuntimeExtensionFactory.create(request, eventPublisher, - toolCallbacks); + toolCallbacks, preparedMemory); Msg response; try { response = managedAgent.run(new AgentRunContext(request.getAgentId(), request.getThreadId(), model, @@ -250,7 +250,7 @@ private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eve buildUserPrompt(request), AgentRuntimeConstant.AGENT_CALL_TIMEOUT, runtimeExtensions)); } catch (RuntimeException ex) { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) && isInterruptedCancellation(ex)) { Thread.interrupted(); rootSpan.setStatus(StatusCode.OK, "cancelled"); @@ -260,10 +260,11 @@ && isInterruptedCancellation(ex)) { } throw ex; } - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { rootSpan.setStatus(StatusCode.OK, "cancelled"); return ""; } + persistNativeMemorySafely(request, runtimeExtensions.memory()); rootSpan.setStatus(StatusCode.OK); String answer = extractText(response); answerTraceExplainStore.recordFinalAnswer(answer); @@ -278,7 +279,7 @@ && isInterruptedCancellation(ex)) { } finally { rootSpan.end(); - sessionRegistry.clearRunning(request.getThreadId(), request.getRuntimeRequestId()); + runtimeRegistry.clearRunning(request.getThreadId(), request.getRuntimeRequestId()); answerTraceExplainStore.closeScope(); } } @@ -289,7 +290,6 @@ private Span startRuntimeSpan(AgentRequest request) { span.setAttribute(SessionTraceStore.ATTR_RUNTIME_REQUEST_ID, request.getRuntimeRequestId()); span.setAttribute(SessionTraceStore.ATTR_AGENT_ID, request.getAgentId() == null ? "" : request.getAgentId()); span.setAttribute("dataagent.runtime.human_feedback", request.isHumanFeedback()); - span.setAttribute("dataagent.runtime.nl2sql_only", request.isNl2sqlOnly()); return span; } @@ -303,8 +303,10 @@ private void recordRuntimeFailure(Span rootSpan, Throwable throwable) { } private String blockForClarification(AgentRequest request, AgentRuntimeEventPublisher eventPublisher, Span rootSpan, - QueryClarifyAssessment clarifyAssessment) { + QueryClarifyAssessment clarifyAssessment, PreparedMemory preparedMemory) { String clarifyText = clarifyAssessment.userMessage(); + appendClarifyTurn(preparedMemory, request, clarifyText); + persistNativeMemorySafely(request, preparedMemory == null ? null : preparedMemory.memory()); answerTraceExplainStore.recordFinalAnswer(clarifyText); persistAnswerExplainSnapshot(request); mirrorExplainSummary(rootSpan, request); @@ -408,6 +410,102 @@ private String buildUserPrompt(AgentRequest request) { return request.getQuery() == null ? "" : request.getQuery(); } + private void persistNativeMemorySafely(AgentRequest request, io.agentscope.core.memory.Memory memory) { + if (!(memory instanceof InMemoryMemory inMemoryMemory) || request == null + || !StringUtils.hasText(request.getThreadId())) { + return; + } + normalizeMemoryForPersistence(request, inMemoryMemory); + try { + nativeSessionService.saveMemory(inMemoryMemory, request.getThreadId()); + } + catch (RuntimeException ex) { + log.warn("Failed to persist AgentScope native memory. threadId={}, runtimeRequestId={}", + request.getThreadId(), request.getRuntimeRequestId(), ex); + } + } + + private void appendClarifyTurn(PreparedMemory preparedMemory, AgentRequest request, String clarifyText) { + if (preparedMemory == null || request == null) { + return; + } + InMemoryMemory memory = preparedMemory.memory(); + if (memory == null) { + return; + } + String userInput = resolvePersistedUserInput(request); + if (StringUtils.hasText(userInput)) { + memory.addMessage(Msg.builder().name("user").role(MsgRole.USER).textContent(userInput).build()); + } + if (StringUtils.hasText(clarifyText)) { + memory.addMessage(Msg.builder().name("assistant").role(MsgRole.ASSISTANT).textContent(clarifyText).build()); + } + } + + private void normalizeMemoryForPersistence(AgentRequest request, InMemoryMemory memory) { + if (request == null || memory == null || !StringUtils.hasText(request.getHumanFeedbackContent()) + || !StringUtils.hasText(request.getQuery())) { + return; + } + List originalMessages = memory.getMessages(); + if (originalMessages == null || originalMessages.isEmpty()) { + return; + } + List normalizedMessages = new ArrayList<>(originalMessages); + int replacementIndex = findReplayQueryIndex(normalizedMessages, request.getQuery()); + if (replacementIndex < 0) { + return; + } + Msg replayMessage = normalizedMessages.get(replacementIndex); + normalizedMessages.set(replacementIndex, + Msg.builder() + .name(resolveMsgName(replayMessage, "user")) + .role(MsgRole.USER) + .textContent(request.getHumanFeedbackContent()) + .build()); + if (Objects.equals(originalMessages, normalizedMessages)) { + return; + } + memory.clear(); + normalizedMessages.forEach(memory::addMessage); + } + + private int findReplayQueryIndex(List messages, String originalQuery) { + int firstMatchIndex = -1; + for (int i = 0; i < messages.size(); i++) { + Msg message = messages.get(i); + if (message == null || message.getRole() != MsgRole.USER) { + continue; + } + if (!Objects.equals(message.getTextContent(), originalQuery)) { + continue; + } + if (firstMatchIndex < 0) { + firstMatchIndex = i; + continue; + } + return i; + } + return -1; + } + + private String resolveMsgName(Msg message, String fallback) { + if (message != null && StringUtils.hasText(message.getName())) { + return message.getName(); + } + return fallback; + } + + private String resolvePersistedUserInput(AgentRequest request) { + if (request == null) { + return null; + } + if (StringUtils.hasText(request.getHumanFeedbackContent())) { + return request.getHumanFeedbackContent(); + } + return request.getQuery(); + } + private void validateAgentStatus(Agent agent, String requestAgentId) { if (agent == null) { return; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java similarity index 99% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java index 076687c3a..3b472bed4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java @@ -21,7 +21,7 @@ import org.springframework.stereotype.Component; @Component -public class AgentSessionRegistry { +public class AgentRuntimeRegistry { private final ConcurrentHashMap> requestStatesByThreadId = new ConcurrentHashMap<>(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java new file mode 100644 index 000000000..3ec32090c --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java @@ -0,0 +1,167 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.session; + +import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; +import com.alibaba.cloud.ai.dataagent.mapper.ChatMessageMapper; +import com.alibaba.cloud.ai.dataagent.mapper.ChatSessionMapper; +import io.agentscope.core.session.Session; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.State; +import io.agentscope.core.state.SessionKey; +import io.agentscope.core.util.JsonCodec; +import io.agentscope.core.util.JsonUtils; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +@Slf4j +@Component +@RequiredArgsConstructor +public class AgentScopeMysqlSession implements Session { + + static final String STATE_MESSAGE_TYPE_PREFIX = "agentscope-state:"; + + private static final String STATE_KIND = "agentscope-session-state"; + + private static final String STATE_VISIBILITY = "system-hidden"; + + private static final String STATE_ROLE = "system"; + + private final ChatSessionMapper chatSessionMapper; + + private final ChatMessageMapper chatMessageMapper; + + private final JsonCodec jsonCodec = JsonUtils.getJsonCodec(); + + @Override + public void save(SessionKey sessionKey, String moduleName, State state) { + String sessionId = sessionId(sessionKey); + assertManagedChatSession(sessionId); + String messageType = messageType(moduleName); + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId, messageType); + chatMessageMapper.insert(buildStateMessage(sessionId, messageType, moduleName, state.getClass().getName(), 0, + jsonCodec.toPrettyJson(state))); + } + + @Override + public void save(SessionKey sessionKey, String moduleName, List states) { + String sessionId = sessionId(sessionKey); + assertManagedChatSession(sessionId); + String messageType = messageType(moduleName); + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId, messageType); + if (states == null || states.isEmpty()) { + return; + } + for (int index = 0; index < states.size(); index++) { + State state = states.get(index); + if (state == null) { + continue; + } + chatMessageMapper.insert(buildStateMessage(sessionId, messageType, moduleName, state.getClass().getName(), + index, jsonCodec.toJson(state))); + } + } + + @Override + public Optional get(SessionKey sessionKey, String moduleName, Class stateClass) { + List rows = chatMessageMapper.selectStateBySessionIdAndMessageType(sessionId(sessionKey), + messageType(moduleName)); + if (rows.isEmpty()) { + return Optional.empty(); + } + ChatMessage latestRow = rows.get(rows.size() - 1); + return Optional.ofNullable(jsonCodec.fromJson(latestRow.getContent(), stateClass)); + } + + @Override + public List getList(SessionKey sessionKey, String moduleName, Class stateClass) { + return chatMessageMapper.selectStateBySessionIdAndMessageType(sessionId(sessionKey), messageType(moduleName)) + .stream() + .map(ChatMessage::getContent) + .map(content -> jsonCodec.fromJson(content, stateClass)) + .collect(Collectors.toList()); + } + + @Override + public boolean exists(SessionKey sessionKey) { + return chatMessageMapper.countAgentScopeStateBySessionId(sessionId(sessionKey)) > 0; + } + + @Override + public void delete(SessionKey sessionKey) { + chatMessageMapper.deleteAgentScopeStateBySessionId(sessionId(sessionKey)); + } + + @Override + public void delete(SessionKey sessionKey, String moduleName) { + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId(sessionKey), messageType(moduleName)); + } + + @Override + public Set listSessionKeys() { + return chatMessageMapper.selectSessionIdsWithAgentScopeState() + .stream() + .map(SimpleSessionKey::of) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + public Mono clearAllSessions() { + return Mono.fromCallable(chatMessageMapper::deleteAllAgentScopeStateMessages) + .subscribeOn(Schedulers.boundedElastic()); + } + + private void assertManagedChatSession(String sessionId) { + if (chatSessionMapper.selectBySessionId(sessionId) == null) { + throw new IllegalStateException( + "AgentScope session persistence requires an existing chat_session. sessionId=" + sessionId); + } + } + + private String sessionId(SessionKey sessionKey) { + return sessionKey.toIdentifier(); + } + + private String messageType(String moduleName) { + String encodedModuleName = java.util.Base64.getUrlEncoder() + .withoutPadding() + .encodeToString(moduleName.getBytes(StandardCharsets.UTF_8)); + return STATE_MESSAGE_TYPE_PREFIX + encodedModuleName; + } + + private ChatMessage buildStateMessage(String sessionId, String messageType, String moduleName, String stateClass, + int sequence, String content) { + Map metadata = Map.of("kind", STATE_KIND, "visibility", STATE_VISIBILITY, "moduleName", + moduleName, "stateClass", stateClass, "sequence", sequence); + return ChatMessage.builder() + .sessionId(sessionId) + .role(STATE_ROLE) + .content(content) + .messageType(messageType) + .metadata(jsonCodec.toJson(metadata)) + .build(); + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java new file mode 100644 index 000000000..64ad96225 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.session; + +import io.agentscope.core.memory.Memory; +import io.agentscope.core.session.Session; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.SessionKey; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.Objects; +import java.util.regex.Pattern; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Slf4j +@Component +public class AgentScopeNativeSessionService { + + private static final Pattern SESSION_ID_PATTERN = Pattern.compile("[A-Za-z0-9_-]+"); + + private final Session session; + + public AgentScopeNativeSessionService(AgentScopeMysqlSession session) { + this.session = session; + } + + public boolean loadMemoryIfExists(Memory memory, String sessionId) { + Objects.requireNonNull(memory, "memory must not be null"); + if (!StringUtils.hasText(sessionId)) { + return false; + } + try { + return memory.loadIfExists(session, sessionKey(sessionId)); + } + catch (RuntimeException ex) { + log.warn("Failed to load AgentScope native session state from MySQL. sessionId={}", sessionId, ex); + return false; + } + } + + public void saveMemory(Memory memory, String sessionId) { + Objects.requireNonNull(memory, "memory must not be null"); + if (!StringUtils.hasText(sessionId)) { + return; + } + memory.saveTo(session, sessionKey(sessionId)); + } + + public void deleteSessionState(String sessionId) { + if (!StringUtils.hasText(sessionId)) { + return; + } + try { + SessionKey sessionKey = sessionKey(sessionId); + if (session.exists(sessionKey)) { + session.delete(sessionKey); + } + } + catch (RuntimeException ex) { + throw new IllegalStateException("Failed to delete AgentScope session state: " + sessionId, ex); + } + } + + public void deleteSessionStates(Collection sessionIds) { + if (sessionIds == null || sessionIds.isEmpty()) { + return; + } + for (String sessionId : new LinkedHashSet<>(sessionIds)) { + deleteSessionState(sessionId); + } + } + + private SessionKey sessionKey(String sessionId) { + return SimpleSessionKey.of(normalizeSessionId(sessionId)); + } + + private String normalizeSessionId(String sessionId) { + if (!StringUtils.hasText(sessionId)) { + throw new IllegalArgumentException("sessionId 不能为空"); + } + String normalized = sessionId.trim(); + if (!SESSION_ID_PATTERN.matcher(normalized).matches()) { + throw new IllegalArgumentException("sessionId 包含非法字符: " + sessionId); + } + return normalized; + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index 2a217054f..baa1d22ba 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -44,9 +44,9 @@ public class DomainBusinessKnowledgeToolSupport { "type": "string", "description": "必填。需要检索的业务问题、指标名、术语、SOP 主题或案例主题。" }, - "knowledgeTypes": { - "type": "array", - "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。兼容旧别名 businessTerm。", + "knowledgeTypes": { + "type": "array", + "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。", "items": { "type": "string" } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java index 06e9a2038..2cf480529 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java @@ -15,8 +15,8 @@ */ package com.alibaba.cloud.ai.dataagent.config; -import com.alibaba.cloud.ai.dataagent.properties.CodeExecutorProperties; import com.alibaba.cloud.ai.dataagent.properties.AgentSkillProperties; +import com.alibaba.cloud.ai.dataagent.properties.CodeExecutorProperties; import com.alibaba.cloud.ai.dataagent.properties.DataAgentProperties; import com.alibaba.cloud.ai.dataagent.properties.FileStorageProperties; import com.alibaba.cloud.ai.dataagent.service.vectorstore.SimpleVectorStoreInitialization; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java index be3399e25..7e959fd97 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java @@ -28,8 +28,11 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ResponseStatusException; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; +import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; +import org.springframework.http.HttpStatus; import static com.alibaba.cloud.ai.dataagent.constant.Constant.STREAM_EVENT_COMPLETE; import static com.alibaba.cloud.ai.dataagent.constant.Constant.STREAM_EVENT_ERROR; @@ -43,15 +46,18 @@ public class DataAgentController { private final AgentService agentService; + private final ChatSessionService chatSessionService; + @GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux> streamSearch(@RequestParam("agentId") String agentId, - @RequestParam(value = "threadId", required = false) String threadId, + @RequestParam("threadId") String threadId, @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, @RequestParam("query") String query, @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, - @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, - @RequestParam(value = "nl2sqlOnly", required = false) boolean nl2sqlOnly, ServerHttpResponse response) { + @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, ServerHttpResponse response) { + Long numericAgentId = parseAgentId(agentId); + chatSessionService.requireSessionForAgent(threadId, numericAgentId); response.getHeaders().add("Cache-Control", "no-cache"); response.getHeaders().add("Connection", "keep-alive"); response.getHeaders().add("Access-Control-Allow-Origin", "*"); @@ -65,7 +71,6 @@ public Flux> streamSearch(@RequestParam("agentId" .humanFeedback(humanFeedback) .humanFeedbackContent(humanFeedbackContent) .rejectedPlan(rejectedPlan) - .nl2sqlOnly(nl2sqlOnly) .build(); agentService.graphStreamProcess(sink, request); @@ -92,4 +97,13 @@ public Flux> streamSearch(@RequestParam("agentId" .doOnComplete(() -> log.info("Aiagent stream completed successfully, threadId: {}", request.getThreadId())); } + private Long parseAgentId(String agentId) { + try { + return Long.valueOf(agentId); + } + catch (NumberFormatException ex) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "agentId must be numeric", ex); + } + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java index 6f3924d0d..055ca85e3 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java @@ -37,6 +37,7 @@ public interface ChatMessageMapper { SELECT * FROM chat_message WHERE session_id = #{sessionId} AND LOWER(TRIM(COALESCE(message_type, ''))) NOT IN ('memory-text', 'answer-explain') + AND LOWER(TRIM(COALESCE(message_type, ''))) NOT LIKE 'agentscope-state:%' ORDER BY create_time ASC """) List selectVisibleBySessionId(@Param("sessionId") String sessionId); @@ -51,26 +52,27 @@ List selectBySessionIdAndMessageType(@Param("sessionId") String ses @Param("messageType") String messageType); @Select(""" - SELECT * FROM ( - SELECT * FROM chat_message - WHERE session_id = #{sessionId} - AND content IS NOT NULL - AND TRIM(content) <> '' - AND ( - LOWER(TRIM(COALESCE(message_type, ''))) = 'memory-text' - OR ( - LOWER(TRIM(COALESCE(role, ''))) IN ('assistant', 'system', 'tool', 'user') - AND LOWER(TRIM(COALESCE(message_type, ''))) NOT IN ('html', 'html-report', 'markdown-report', 'result-set') - AND LOWER(TRIM(COALESCE(message_type, ''))) IN ('', 'text') - ) - ) - ORDER BY create_time DESC - LIMIT #{limit} - ) recent_memory - ORDER BY create_time ASC + SELECT * FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) + ORDER BY create_time ASC, id ASC + """) + List selectStateBySessionIdAndMessageType(@Param("sessionId") String sessionId, + @Param("messageType") String messageType); + + @Select(""" + SELECT DISTINCT session_id FROM chat_message + WHERE LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + ORDER BY session_id ASC """) - List selectRecentMemoryEligibleBySessionId(@Param("sessionId") String sessionId, - @Param("limit") int limit); + List selectSessionIdsWithAgentScopeState(); + + @Select(""" + SELECT COUNT(*) FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + """) + int countAgentScopeStateBySessionId(@Param("sessionId") String sessionId); /** * Query by id @@ -114,4 +116,25 @@ INSERT INTO chat_message (session_id, role, content, message_type, metadata, cre """) int deleteById(@Param("id") Long id); + @Delete(""" + DELETE FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) + """) + int deleteBySessionIdAndMessageType(@Param("sessionId") String sessionId, + @Param("messageType") String messageType); + + @Delete(""" + DELETE FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + """) + int deleteAgentScopeStateBySessionId(@Param("sessionId") String sessionId); + + @Delete(""" + DELETE FROM chat_message + WHERE LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + """) + int deleteAllAgentScopeStateMessages(); + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageService.java index ad31b1de7..1d1f5f31c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageService.java @@ -33,11 +33,6 @@ public interface ChatMessageService { List findVisibleBySessionId(String sessionId, Long agentId); - /** - * Get recent messages by session ID for memory loading. - */ - List findRecentBySessionId(String sessionId, int limit); - /** * Get messages by session ID and message type. */ diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageServiceImpl.java index fa8b2e6d5..e2ab03aa8 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatMessageServiceImpl.java @@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; -import java.util.Collections; import java.util.List; /** @@ -52,14 +51,6 @@ public List findVisibleBySessionId(String sessionId, Long agentId) return findVisibleBySessionId(sessionId); } - @Override - public List findRecentBySessionId(String sessionId, int limit) { - if (limit <= 0) { - return Collections.emptyList(); - } - return chatMessageMapper.selectRecentMemoryEligibleBySessionId(sessionId, limit); - } - @Override public List findBySessionIdAndMessageType(String sessionId, String messageType) { return chatMessageMapper.selectBySessionIdAndMessageType(sessionId, messageType); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java index 3bd57e2eb..a795eec99 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java @@ -15,18 +15,18 @@ */ package com.alibaba.cloud.ai.dataagent.service.chat; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; import com.alibaba.cloud.ai.dataagent.entity.ChatSession; import com.alibaba.cloud.ai.dataagent.mapper.ChatSessionMapper; +import java.time.LocalDateTime; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.web.server.ResponseStatusException; - -import java.time.LocalDateTime; -import java.util.List; -import java.util.UUID; - @Service @Slf4j @AllArgsConstructor @@ -34,6 +34,8 @@ public class ChatSessionServiceImpl implements ChatSessionService { private final ChatSessionMapper chatSessionMapper; + private final AgentScopeNativeSessionService nativeSessionService; + /** * Get session list by agent ID */ @@ -76,8 +78,10 @@ public ChatSession createSession(Integer agentId, String title, Long userId) { */ @Override public void clearSessionsByAgentId(Integer agentId) { + List sessionIds = findByAgentId(agentId).stream().map(ChatSession::getId).collect(Collectors.toList()); LocalDateTime now = LocalDateTime.now(); int updated = chatSessionMapper.softDeleteByAgentId(agentId, now); + nativeSessionService.deleteSessionStates(sessionIds); log.info("Cleared {} sessions for agent: {}", updated, agentId); } @@ -135,6 +139,7 @@ public void renameSession(String sessionId, String newTitle, Long agentId) { public void deleteSession(String sessionId) { LocalDateTime now = LocalDateTime.now(); chatSessionMapper.softDeleteById(sessionId, now); + nativeSessionService.deleteSessionState(sessionId); log.info("Deleted session: {}", sessionId); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/DatasourceService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/DatasourceService.java index 6af2dd570..4c825f96a 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/DatasourceService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/DatasourceService.java @@ -16,7 +16,6 @@ package com.alibaba.cloud.ai.dataagent.service.datasource; import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; -import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.Datasource; import com.alibaba.cloud.ai.dataagent.entity.LogicalRelation; import java.util.List; @@ -68,13 +67,6 @@ public interface DatasourceService { */ boolean testConnection(Integer id); - /** - * Get data source list associated with agent - */ - // 应该使用 AgentDatasourceService 中的方法 - @Deprecated - List getAgentDatasource(Long agentId); - List getDatasourceTables(Integer datasourceId) throws Exception; /** diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/DatasourceServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/DatasourceServiceImpl.java index 538e2d35e..1fa39d59f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/DatasourceServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/datasource/impl/DatasourceServiceImpl.java @@ -238,23 +238,6 @@ private void evictDatasourcePool(Datasource datasource) { } } - @Override - @Deprecated - public List getAgentDatasource(Long agentId) { - List adentDatasources = agentDatasourceMapper.selectByAgentIdWithDatasource(agentId); - - // Manually fill in the data source information (since MyBatis Plus does not - // directly support complex join query result mapping) - for (AgentDatasource agentDatasource : adentDatasources) { - if (agentDatasource.getDatasourceId() != null) { - Datasource datasource = datasourceMapper.selectById(agentDatasource.getDatasourceId()); - agentDatasource.setDatasource(datasource); - } - } - - return adentDatasources; - } - @Override public List getDatasourceTables(Integer datasourceId) throws Exception { log.info("Getting tables for datasource: {}", datasourceId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java index 92b00ad09..5ddf68e68 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/knowledge/DomainKnowledgeSearchServiceImpl.java @@ -317,7 +317,7 @@ private NormalizedSearchOptions normalizeOptions(DomainKnowledgeSearchRequest re addAppliedType(appliedTypes, "agentKnowledge"); filterAgentKnowledgeByType = false; } - case "businessknowledge", "business_knowledge", "businessterm", "business_term", "business" -> { + case "businessknowledge", "business_knowledge" -> { categories.add(SearchCategory.BUSINESS_TERM); addAppliedType(appliedTypes, "businessKnowledge"); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java index 1d71786b5..7698069c0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java @@ -22,26 +22,19 @@ import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import java.util.concurrent.ConcurrentHashMap; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; -import java.util.concurrent.ConcurrentHashMap; - /** - * @author zihenzzz - * @date 2026/2/16 13:54 基于 OpenTelemetry 的 Langfuse Reporter,用于追踪 LLM 调用 + * 基于 OpenTelemetry 的 Langfuse Reporter,用于追踪 LLM 调用。 */ @Slf4j @Component public class LangfuseService { - private final Tracer tracer; - - private final boolean enabled; - - // --- Span Attribute Keys --- private static final AttributeKey INPUT_VALUE = AttributeKey.stringKey("input.value"); private static final AttributeKey OUTPUT_VALUE = AttributeKey.stringKey("output.value"); @@ -50,8 +43,6 @@ public class LangfuseService { private static final AttributeKey ATTR_THREAD_ID = AttributeKey.stringKey("data_agent.thread_id"); - private static final AttributeKey ATTR_NL2SQL_ONLY = AttributeKey.booleanKey("data_agent.nl2sql_only"); - private static final AttributeKey ATTR_HUMAN_FEEDBACK = AttributeKey .booleanKey("data_agent.human_feedback"); @@ -66,9 +57,12 @@ public class LangfuseService { private static final AttributeKey ERROR_MESSAGE = AttributeKey.stringKey("error.message"); - // --- Token 累计器,按 threadId 隔离 --- private static final ConcurrentHashMap TOKEN_ACCUMULATOR = new ConcurrentHashMap<>(); + private final Tracer tracer; + + private final boolean enabled; + public LangfuseService(@Qualifier("langfuseTracer") Tracer langfuseTracer, @Value("${langfuse.enabled:true}") boolean enabled) { this.tracer = langfuseTracer; @@ -76,7 +70,7 @@ public LangfuseService(@Qualifier("langfuseTracer") Tracer langfuseTracer, } /** - * 开始一个 Graph 流式处理的 Span,记录完整的请求上下文 + * 开始一个 Graph 流式处理 Span,记录请求上下文。 */ public Span startLLMSpan(String spanName, AgentRequest request) { if (!enabled) { @@ -84,24 +78,18 @@ public Span startLLMSpan(String spanName, AgentRequest request) { } try { - Span span = tracer.spanBuilder(spanName) - .setSpanKind(SpanKind.CLIENT) - .setParent(Context.current()) - .startSpan(); + Span span = tracer.spanBuilder(spanName).setSpanKind(SpanKind.CLIENT).setParent(Context.current()).startSpan(); String inputValue = String.format( - "{\"query\":\"%s\",\"agentId\":\"%s\",\"threadId\":\"%s\",\"nl2sqlOnly\":%s,\"humanFeedback\":%s}", + "{\\\"query\\\":\\\"%s\\\",\\\"agentId\\\":\\\"%s\\\",\\\"threadId\\\":\\\"%s\\\",\\\"humanFeedback\\\":%s}", request.getQuery() != null ? request.getQuery() : "", request.getAgentId() != null ? request.getAgentId() : "", - request.getThreadId() != null ? request.getThreadId() : "", request.isNl2sqlOnly(), - request.isHumanFeedback()); + request.getThreadId() != null ? request.getThreadId() : "", request.isHumanFeedback()); span.setAttribute(INPUT_VALUE, inputValue); span.setAttribute(ATTR_AGENT_ID, request.getAgentId() != null ? request.getAgentId() : ""); span.setAttribute(ATTR_THREAD_ID, request.getThreadId() != null ? request.getThreadId() : ""); - span.setAttribute(ATTR_NL2SQL_ONLY, request.isNl2sqlOnly()); span.setAttribute(ATTR_HUMAN_FEEDBACK, request.isHumanFeedback()); - // 初始化该 threadId 的 token 累计器 if (request.getThreadId() != null) { TOKEN_ACCUMULATOR.put(request.getThreadId(), new long[] { 0, 0 }); } @@ -115,7 +103,7 @@ public Span startLLMSpan(String spanName, AgentRequest request) { } /** - * 累计 token 用量(由 FluxUtil 在处理 ChatResponse 时调用) + * 累计 token 用量。 */ public static void accumulateTokens(Object threadId, long promptTokens, long completionTokens) { if (threadId == null) { @@ -131,7 +119,7 @@ public static void accumulateTokens(Object threadId, long promptTokens, long com } /** - * 结束 Span(成功),附带累计的 token 用量 + * 成功结束 Span,并附带累计 token。 */ public void endSpanSuccess(Span span, String threadId, String output) { if (!enabled || span == null || !span.isRecording()) { @@ -152,7 +140,7 @@ public void endSpanSuccess(Span span, String threadId, String output) { } /** - * 结束 Span(失败) + * 失败结束 Span。 */ public void endSpanError(Span span, String threadId, Exception error) { if (!enabled || span == null || !span.isRecording()) { @@ -178,7 +166,7 @@ public void endSpanError(Span span, String threadId, Exception error) { } /** - * 读取并清除累计的 token,写入 span attributes + * 读取并清除累计 token,写入 span attributes。 */ private void applyAccumulatedTokens(Span span, String threadId) { if (threadId == null) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/llm/LlmService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/llm/LlmService.java index 4b74ec109..ff87f2ceb 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/llm/LlmService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/llm/LlmService.java @@ -27,13 +27,6 @@ public interface LlmService { Flux callUser(String user); - @Deprecated - default String blockToString(Flux responseFlux) { - return toStringFlux(responseFlux).collect(StringBuilder::new, StringBuilder::append) - .map(StringBuilder::toString) - .block(); - } - default Flux toStringFlux(Flux responseFlux) { return responseFlux.map(ChatResponseUtil::getText); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java index a5488d827..24b4c9c69 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java @@ -15,14 +15,11 @@ */ package com.alibaba.cloud.ai.dataagent.service.mcp; -import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentService; import com.alibaba.cloud.ai.dataagent.entity.Agent; import com.alibaba.cloud.ai.dataagent.mapper.AgentMapper; -import com.fasterxml.jackson.annotation.JsonPropertyDescription; import lombok.AllArgsConstructor; import org.springframework.ai.tool.annotation.Tool; import org.springframework.stereotype.Service; -import org.springframework.util.Assert; import java.util.List; @@ -32,13 +29,7 @@ public class McpServerService { private final AgentMapper agentMapper; - - private AgentService agentService; - - public record AgentListRequest( - @JsonPropertyDescription("按状态过滤,例如 '状态:draft-待发布,published-已发布,offline-已下线") String status, - - @JsonPropertyDescription("按关键词搜索智能体名称或描述") String keyword) { + public record AgentListRequest(String status, String keyword) { } @Tool(description = "查询智能体列表,支持按状态和关键词过滤。可以根据智能体的状态(如已发布PUBLISHED、草稿DRAFT等)进行过滤,也可以通过关键词搜索智能体的名称、描述或标签。返回按创建时间降序排列的智能体列表。") @@ -46,16 +37,4 @@ public List listAgentsToolCallback(AgentListRequest agentListRequest) { return agentMapper.findByConditions(agentListRequest.status(), agentListRequest.keyword()); } - // NL2SQL 请求参数 - public record Nl2SqlRequest(@JsonPropertyDescription("自然语言查询描述,例如:'查询销售额最高的10个产品'") String naturalQuery, - @JsonPropertyDescription("智能体ID,用于指定使用哪个智能体进行NL2SQL转换") String agentId) { - } - - @Tool(description = "将自然语言查询转换为SQL语句。使用指定的智能体将用户的自然语言查询描述转换为可执行的SQL语句,支持复杂的数据查询需求。") - public String nl2SqlToolCallback(Nl2SqlRequest nl2SqlRequest) { - Assert.hasText(nl2SqlRequest.agentId(), "AgentId cannot be empty"); - Assert.hasText(nl2SqlRequest.naturalQuery(), "Natural query cannot be empty"); - return agentService.nl2sql(nl2SqlRequest.naturalQuery(), nl2SqlRequest.agentId()); - } - } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/ChatResponseUtil.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/ChatResponseUtil.java index f283a3903..c07e19416 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/ChatResponseUtil.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/util/ChatResponseUtil.java @@ -15,7 +15,6 @@ */ package com.alibaba.cloud.ai.dataagent.util; -import com.alibaba.cloud.ai.dataagent.enums.TextType; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -37,12 +36,6 @@ public static ChatResponse createPureResponse(String message) { return new ChatResponse(List.of(generation)); } - // 这样无法达到效果,先弃用。如果不得不需要这个逻辑,再重新定义 - @Deprecated - public static ChatResponse createTrimResponse(String message, TextType textType) { - return createPureResponse(message.replace(textType.getStartSign(), "").replace(textType.getEndSign(), "")); - } - public static String getText(ChatResponse chatResponse) { Generation result = chatResponse.getResult(); if (result == null) { diff --git a/docs/ADVANCED_FEATURES-en.md b/docs/ADVANCED_FEATURES-en.md index 8980214de..ee5d40833 100644 --- a/docs/ADVANCED_FEATURES-en.md +++ b/docs/ADVANCED_FEATURES-en.md @@ -102,50 +102,7 @@ spring: ### Available Tools -#### 1. nl2SqlToolCallback - -Converts natural language queries to SQL statements. - -```json -{ - "name": "nl2SqlToolCallback", - "description": "Converts natural language queries to SQL statements. Uses the specified agent to convert user's natural language query descriptions into executable SQL statements, supporting complex data query requirements.", - "inputSchema": { - "type": "object", - "properties": { - "nl2SqlRequest": { - "type": "object", - "properties": { - "agentId": { - "type": "string", - "description": "Agent ID, specifies which agent to use for NL2SQL conversion" - }, - "naturalQuery": { - "type": "string", - "description": "Natural language query description, e.g.: 'Query the top 10 products with highest sales'" - } - }, - "required": ["agentId", "naturalQuery"] - } - }, - "required": ["nl2SqlRequest"], - "additionalProperties": false - } -} -``` - -**Usage Example**: - -```json -{ - "nl2SqlRequest": { - "agentId": "agent-123", - "naturalQuery": "Query the top 10 products with highest sales in the past 30 days" - } -} -``` - -#### 2. listAgentsToolCallback +#### 1. listAgentsToolCallback Queries the agent list, supports filtering by status and keywords. @@ -484,7 +441,7 @@ DataAgent integrates [Langfuse](https://langfuse.com/) as an LLM observability p - **Request Tracing**: Records the full lifecycle of each Graph stream processing (including new queries and human feedback) - **Token Usage Tracking**: Automatically accumulates prompt tokens and completion tokens per request - **Error Tracking**: Records exception types and error messages for troubleshooting -- **Rich Metadata**: Records context attributes such as agentId, threadId, nl2sqlOnly, humanFeedback +- **Rich Metadata**: Records context attributes such as agentId, threadId, humanFeedback ### Configuration diff --git a/docs/ADVANCED_FEATURES.md b/docs/ADVANCED_FEATURES.md index 8f729b518..ca84b8802 100644 --- a/docs/ADVANCED_FEATURES.md +++ b/docs/ADVANCED_FEATURES.md @@ -102,50 +102,7 @@ spring: ### 可用工具 -#### 1. nl2SqlToolCallback - -将自然语言查询转换为SQL语句。 - -```json -{ - "name": "nl2SqlToolCallback", - "description": "将自然语言查询转换为SQL语句。使用指定的智能体将用户的自然语言查询描述转换为可执行的SQL语句,支持复杂的数据查询需求。", - "inputSchema": { - "type": "object", - "properties": { - "nl2SqlRequest": { - "type": "object", - "properties": { - "agentId": { - "type": "string", - "description": "智能体ID,用于指定使用哪个智能体进行NL2SQL转换" - }, - "naturalQuery": { - "type": "string", - "description": "自然语言查询描述,例如:'查询销售额最高的10个产品'" - } - }, - "required": ["agentId", "naturalQuery"] - } - }, - "required": ["nl2SqlRequest"], - "additionalProperties": false - } -} -``` - -**使用示例**: - -```json -{ - "nl2SqlRequest": { - "agentId": "agent-123", - "naturalQuery": "查询过去30天销售额最高的10个产品" - } -} -``` - -#### 2. listAgentsToolCallback +#### 1. listAgentsToolCallback 查询智能体列表,支持按状态和关键词过滤。 diff --git a/docs/ARCHITECTURE-en.md b/docs/ARCHITECTURE-en.md index 1a14d063b..552e429d4 100644 --- a/docs/ARCHITECTURE-en.md +++ b/docs/ARCHITECTURE-en.md @@ -527,9 +527,9 @@ sequenceDiagram #### Key Points -- **MCP**: `McpServerService` provides NL2SQL and Agent list tools, using Mcp Server Boot Starter +- **MCP**: `McpServerService` provides the Agent list tool, using Mcp Server Boot Starter - **Multi-Model Scheduling**: `ModelConfig*` configures models, `AiModelRegistry` caches current Chat/Embedding models and supports hot-swapping (only one active model per type at a time) -- **Built-in Tools**: `nl2SqlToolCallback`, `listAgentsToolCallback` +- **Built-in Tools**: `listAgentsToolCallback` #### Architecture Diagram @@ -586,9 +586,9 @@ sequenceDiagram Factory->>OpenAI: build API client OpenAI-->>Reg: model ready - MCP->>McpSvc: call tool nl2SqlToolCallback - McpSvc->>GS: nl2sql - GS-->>McpSvc: SQL result + MCP->>McpSvc: call tool listAgentsToolCallback + McpSvc->>AgentMapper: findByConditions + AgentMapper-->>McpSvc: agent list McpSvc-->>MCP: tool response ``` diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 31926e860..b33921a59 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -536,9 +536,9 @@ sequenceDiagram #### 说明要点 -- **MCP**: `McpServerService` 提供 NL2SQL 与 Agent 列表工具,使用 Mcp Server Boot Starter +- **MCP**: `McpServerService` 提供 Agent 列表工具,使用 Mcp Server Boot Starter - **多模型调度**: `ModelConfig*` 配置模型,`AiModelRegistry` 缓存当前 Chat/Embedding 模型并支持热切换(同一时间每类仅一个激活模型) -- **已内置工具**: `nl2SqlToolCallback`、`listAgentsToolCallback` +- **已内置工具**: `listAgentsToolCallback` #### 架构图 @@ -595,9 +595,9 @@ sequenceDiagram Factory->>OpenAI: build API client OpenAI-->>Reg: model ready - MCP->>McpSvc: call tool nl2SqlToolCallback - McpSvc->>GS: nl2sql - GS-->>McpSvc: SQL result + MCP->>McpSvc: call tool listAgentsToolCallback + McpSvc->>AgentMapper: findByConditions + AgentMapper-->>McpSvc: agent list McpSvc-->>MCP: tool response ``` From aa657638608e2245677fa69476d3188736ed38e7 Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 29 Apr 2026 21:40:50 +0800 Subject: [PATCH 19/22] refactor: add constant --- .../agentscope/template/CommonAgent.java | 19 ++-------- .../constant/AgentRuntimeConstant.java | 21 +++++++++++ .../constant/AgentSessionConstant.java | 37 +++++++++++++++++++ .../service/chat/ChatSessionServiceImpl.java | 4 +- .../service/chat/SessionEventPublisher.java | 5 ++- .../service/chat/SessionTitleService.java | 23 ++++++------ 6 files changed, 78 insertions(+), 31 deletions(-) create mode 100644 data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java index 7f5faec29..ae39536d7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java @@ -21,7 +21,6 @@ import io.agentscope.core.hook.Hook; import io.agentscope.core.message.Msg; import io.agentscope.core.message.MsgRole; -import io.agentscope.core.model.ExecutionConfig; import java.time.Duration; import java.util.List; import java.util.Objects; @@ -35,18 +34,6 @@ public class CommonAgent implements ManagedAgent { private static final String SYSTEM_PROMPT = PromptLoader.loadPrompt(AGENT_TYPE, "md"); - private static final int DEFAULT_MAX_ITERS = 10; - - private static final ExecutionConfig DEFAULT_MODEL_EXECUTION_CONFIG = ExecutionConfig.builder() - .timeout(Duration.ofMinutes(2)) - .maxAttempts(2) - .build(); - - private static final ExecutionConfig DEFAULT_TOOL_EXECUTION_CONFIG = ExecutionConfig.builder() - .timeout(Duration.ofSeconds(30)) - .maxAttempts(1) - .build(); - @Override public String getAgentType() { return AGENT_TYPE; @@ -61,9 +48,9 @@ public Msg run(AgentRunContext context) { .name(AGENT_TYPE) .sysPrompt(defaultSystemPrompt(context.systemPrompt(), extensions.skillInstructions())) .model(context.model()) - .maxIters(DEFAULT_MAX_ITERS) - .modelExecutionConfig(DEFAULT_MODEL_EXECUTION_CONFIG) - .toolExecutionConfig(DEFAULT_TOOL_EXECUTION_CONFIG); + .maxIters(AgentRuntimeConstant.REACT_AGENT_DEFAULT_MAX_ITERS) + .modelExecutionConfig(AgentRuntimeConstant.DEFAULT_MODEL_EXECUTION_CONFIG) + .toolExecutionConfig(AgentRuntimeConstant.DEFAULT_TOOL_EXECUTION_CONFIG); if (extensions.toolkit() != null) { builder.toolkit(extensions.toolkit()); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java index 7d8afd7b7..84fe82250 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.constant; +import io.agentscope.core.model.ExecutionConfig; import java.time.Duration; /** @@ -22,6 +23,26 @@ */ public final class AgentRuntimeConstant { + public static final int REACT_AGENT_DEFAULT_MAX_ITERS = 10; + + public static final Duration MODEL_EXECUTION_TIMEOUT = Duration.ofMinutes(2); + + public static final int MODEL_MAX_ATTEMPTS = 2; + + public static final Duration TOOL_EXECUTION_TIMEOUT = Duration.ofSeconds(30); + + public static final int TOOL_MAX_ATTEMPTS = 1; + + public static final ExecutionConfig DEFAULT_MODEL_EXECUTION_CONFIG = ExecutionConfig.builder() + .timeout(MODEL_EXECUTION_TIMEOUT) + .maxAttempts(MODEL_MAX_ATTEMPTS) + .build(); + + public static final ExecutionConfig DEFAULT_TOOL_EXECUTION_CONFIG = ExecutionConfig.builder() + .timeout(TOOL_EXECUTION_TIMEOUT) + .maxAttempts(TOOL_MAX_ATTEMPTS) + .build(); + private AgentRuntimeConstant() { } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java new file mode 100644 index 000000000..1ec9049b9 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.constant; + +import java.time.Duration; + +/** + * Agent 会话相关的共享默认配置常量。 + */ +public final class AgentSessionConstant { + + public static final String DEFAULT_SESSION_TITLE = "新会话"; + + public static final int SESSION_TITLE_MAX_LENGTH = 20; + + public static final Duration SESSION_TITLE_GENERATION_TIMEOUT = Duration.ofSeconds(15); + + public static final Duration SESSION_EVENT_HEARTBEAT_INTERVAL = Duration.ofSeconds(2); + + private AgentSessionConstant() { + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java index a795eec99..6c746e918 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.dataagent.service.chat; import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; +import com.alibaba.cloud.ai.dataagent.constant.AgentSessionConstant; import com.alibaba.cloud.ai.dataagent.entity.ChatSession; import com.alibaba.cloud.ai.dataagent.mapper.ChatSessionMapper; import java.time.LocalDateTime; @@ -66,7 +67,8 @@ public ChatSession requireSessionForAgent(String sessionId, Long agentId) { public ChatSession createSession(Integer agentId, String title, Long userId) { String sessionId = UUID.randomUUID().toString(); - ChatSession session = new ChatSession(sessionId, agentId, title != null ? title : "新会话", "active", userId); + ChatSession session = new ChatSession(sessionId, agentId, + title != null ? title : AgentSessionConstant.DEFAULT_SESSION_TITLE, "active", userId); chatSessionMapper.insert(session); log.info("Created new chat session: {} for agent: {}", sessionId, agentId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionEventPublisher.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionEventPublisher.java index d8c97e958..8487aaad8 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionEventPublisher.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionEventPublisher.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.service.chat; +import com.alibaba.cloud.ai.dataagent.constant.AgentSessionConstant; import com.alibaba.cloud.ai.dataagent.vo.SessionUpdateEvent; import lombok.extern.slf4j.Slf4j; import org.springframework.http.codec.ServerSentEvent; @@ -23,7 +24,6 @@ import reactor.core.publisher.Sinks; import reactor.core.publisher.SignalType; -import java.time.Duration; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -40,7 +40,8 @@ public class SessionEventPublisher { public Flux> register(Integer agentId) { AgentSessionSink sink = sinks.computeIfAbsent(agentId, id -> new AgentSessionSink()); - Flux> heartbeat = Flux.interval(Duration.ofSeconds(2)) + Flux> heartbeat = Flux + .interval(AgentSessionConstant.SESSION_EVENT_HEARTBEAT_INTERVAL) .map(i -> ServerSentEvent.builder().comment("heartbeat").build()); sink.increment(); log.debug("Registered subscriber for agent {}, current count: {}", agentId, sink.subscribers.get()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionTitleService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionTitleService.java index 3c46c7698..c48daf0b3 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionTitleService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/SessionTitleService.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.service.chat; +import com.alibaba.cloud.ai.dataagent.constant.AgentSessionConstant; import com.alibaba.cloud.ai.dataagent.entity.ChatSession; import com.alibaba.cloud.ai.dataagent.service.llm.LlmService; import lombok.RequiredArgsConstructor; @@ -24,7 +25,6 @@ import org.springframework.util.StringUtils; import reactor.core.publisher.Flux; -import java.time.Duration; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -38,8 +38,6 @@ @RequiredArgsConstructor public class SessionTitleService { - private static final String DEFAULT_TITLE = "新会话"; - private final ChatSessionService chatSessionService; private final SessionEventPublisher sessionEventPublisher; @@ -94,20 +92,21 @@ private void generateAndPersist(String sessionId, String userMessage) { } private boolean hasCustomTitle(ChatSession session) { - return StringUtils.hasText(session.getTitle()) && !DEFAULT_TITLE.equals(session.getTitle()); + return StringUtils.hasText(session.getTitle()) + && !AgentSessionConstant.DEFAULT_SESSION_TITLE.equals(session.getTitle()); } private String requestSummary(String userMessage) { try { String systemPrompt = """ - 你是一名对话助手,请根据用户的第一条输入生成不超过20个字的会话标题。 + 你是一名对话助手,请根据用户的第一条输入生成不超过%d个字的会话标题。 使用中文输出,避免使用标点或引号,仅保留核心主题。 - """; + """.formatted(AgentSessionConstant.SESSION_TITLE_MAX_LENGTH); String userPrompt = "用户输入:" + userMessage; Flux responseFlux = llmService.toStringFlux(llmService.call(systemPrompt, userPrompt)); return responseFlux.collect(StringBuilder::new, StringBuilder::append) .map(StringBuilder::toString) - .block(Duration.ofSeconds(15)); + .block(AgentSessionConstant.SESSION_TITLE_GENERATION_TIMEOUT); } catch (Exception ex) { log.warn("LLM title generation failed: {}", ex.getMessage()); @@ -120,18 +119,18 @@ private String normalizeTitle(String raw) { return null; } String sanitized = raw.replaceAll("[\\r\\n]+", " ").replaceAll("[\"“”]+", "").trim(); - if (sanitized.length() > 20) { - sanitized = sanitized.substring(0, 20); + if (sanitized.length() > AgentSessionConstant.SESSION_TITLE_MAX_LENGTH) { + sanitized = sanitized.substring(0, AgentSessionConstant.SESSION_TITLE_MAX_LENGTH); } return sanitized; } private String fallbackTitle(String userMessage) { String text = userMessage.replaceAll("\\s+", " ").trim(); - if (text.length() > 20) { - text = text.substring(0, 20); + if (text.length() > AgentSessionConstant.SESSION_TITLE_MAX_LENGTH) { + text = text.substring(0, AgentSessionConstant.SESSION_TITLE_MAX_LENGTH); } - return StringUtils.hasText(text) ? text : DEFAULT_TITLE; + return StringUtils.hasText(text) ? text : AgentSessionConstant.DEFAULT_SESSION_TITLE; } } From ba261dceba64b1fd23024bec27b8b4f1691c994b Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Wed, 29 Apr 2026 21:55:41 +0800 Subject: [PATCH 20/22] feat: add clarify controller --- data-agent-frontend/src/services/graph.ts | 2 ++ data-agent-frontend/src/views/AgentRun.vue | 9 +++++++++ .../cloud/ai/dataagent/agentscope/dto/AgentRequest.java | 2 ++ .../agentscope/runtime/QueryClarifyService.java | 5 +++-- .../service/impl/AiAgentRuntimeServiceImpl.java | 5 ++++- .../ai/dataagent/controller/DataAgentController.java | 2 ++ 6 files changed, 22 insertions(+), 3 deletions(-) diff --git a/data-agent-frontend/src/services/graph.ts b/data-agent-frontend/src/services/graph.ts index a0cb33acf..377903342 100644 --- a/data-agent-frontend/src/services/graph.ts +++ b/data-agent-frontend/src/services/graph.ts @@ -19,6 +19,7 @@ export interface AgentRequest { threadId?: string; runtimeRequestId?: string; query: string; + clarifyCheckEnabled?: boolean; humanFeedback?: boolean; humanFeedbackContent?: string; rejectedPlan: boolean; @@ -82,6 +83,7 @@ class GraphService { params.append('runtimeRequestId', request.runtimeRequestId); } params.append('query', request.query); + params.append('clarifyCheckEnabled', String(Boolean(request.clarifyCheckEnabled))); if (request.humanFeedback) { params.append('humanFeedback', request.humanFeedback.toString()); } diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 965b3621f..905671c23 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -214,6 +214,13 @@ :onQuestionClick="handlePresetQuestionClick" />
+
+ 开始澄清校验 + +
每页数量 > streamSearch(@RequestParam("agentId" @RequestParam("threadId") String threadId, @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, @RequestParam("query") String query, + @RequestParam(value = "clarifyCheckEnabled", required = false) boolean clarifyCheckEnabled, @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, ServerHttpResponse response) { @@ -68,6 +69,7 @@ public Flux> streamSearch(@RequestParam("agentId" .threadId(threadId) .runtimeRequestId(runtimeRequestId) .query(query) + .clarifyCheckEnabled(clarifyCheckEnabled) .humanFeedback(humanFeedback) .humanFeedbackContent(humanFeedbackContent) .rejectedPlan(rejectedPlan) From 7bc26ddff785facf8408f023b4756dbf8d49bc6a Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Thu, 30 Apr 2026 19:17:00 +0800 Subject: [PATCH 21/22] fix: ci --- .gitignore | 2 +- data-agent-frontend/src/views/AgentRun.vue | 32 ++++++++++++++----- .../impl/AiAgentRuntimeServiceImpl.java | 2 +- .../datasource/DatasourceExplorerService.java | 20 ++++++------ .../DatasourceExplorerToolProvider.java | 8 ++--- .../DomainBusinessKnowledgeToolSupport.java | 18 +++++------ .../semantic/SemanticModelSearchService.java | 25 ++++++--------- .../semantic/SemanticModelToolSupport.java | 7 ++-- .../tool/sqlguard/SqlGuardToolProvider.java | 5 ++- .../sqlguard/SqlVerifyExplainService.java | 32 ++++++++----------- .../controller/DataAgentController.java | 14 ++++---- .../dataagent/mapper/ChatMessageMapper.java | 3 +- .../AnswerTraceExplainStore.java | 11 ++++--- .../service/chat/ChatSessionServiceImpl.java | 1 + .../service/langfuse/LangfuseService.java | 5 ++- .../service/mcp/McpServerService.java | 1 + 16 files changed, 97 insertions(+), 89 deletions(-) diff --git a/.gitignore b/.gitignore index 6c7520620..40c65b79c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,4 +65,4 @@ data-agent-management/vectorstore/* # spring-ai-alibaba source spring-ai-alibaba-1.1.0.0/ .spec-workflow -.claude \ No newline at end of file +.claude diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 905671c23..97016684c 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -776,14 +776,18 @@
使用字段 - {{ answerExplain.usedColumns.join('、') }} + + {{ answerExplain.usedColumns.join('、') }} +
工具决策来源 {{ answerExplain.decisionReason }}
工具决策细节 @@ -820,11 +824,15 @@
语义模型命中 - {{ answerExplain.semanticHits.length }} 条 + + {{ answerExplain.semanticHits.length }} 条 +
RAG / 知识命中 - {{ answerExplain.knowledgeHits.length }} 条 + + {{ answerExplain.knowledgeHits.length }} 条 +
@@ -844,7 +852,8 @@ class="answer-explain-card" >
- {{ relation.sourceTable }}.{{ relation.sourceColumn }} → {{ relation.targetTable }}.{{ relation.targetColumn }} + {{ relation.sourceTable }}.{{ relation.sourceColumn }} → + {{ relation.targetTable }}.{{ relation.targetColumn }}
{{ relation.sourceType || '-' }} @@ -1954,7 +1963,9 @@ return null; }); - const latestExplainRuntimeRequestId = computed(() => sessionTrace.value?.runtimeRequestId ?? null); + const latestExplainRuntimeRequestId = computed( + () => sessionTrace.value?.runtimeRequestId ?? null, + ); const loadLatestAnswerExplain = async (options?: { visible?: boolean }) => { if (!currentSession.value) { @@ -2022,6 +2033,8 @@ } }; + void loadAnswerExplainByRuntimeRequestId; + const openLatestAnswerExplain = async () => { if (!currentSession.value) { ElMessage.warning('当前会话还没有可查看的数据来源'); @@ -2052,7 +2065,7 @@ } return value .filter((item): item is string => typeof item === 'string' && item.trim().length > 0) - .map((item) => item.trim()); + .map(item => item.trim()); }; const summarizeExplainExecution = (explain: AnswerTraceExplain | null) => { @@ -2086,7 +2099,10 @@ if (hasSemanticEvidence) { return '本轮回答没有直接查库,但命中了语义模型,用来帮助系统理解你的问题和业务字段。'; } - if (explain.toolSteps.length > 0 || (explain.clarify && Object.keys(explain.clarify).length > 0)) { + if ( + explain.toolSteps.length > 0 || + (explain.clarify && Object.keys(explain.clarify).length > 0) + ) { return '本轮回答没有形成可展示的查库明细,但系统执行过澄清或其他工具步骤,详细过程可在下方查看。'; } return '本轮回答没有访问数据库、知识库或其他可回放工具,当前结果主要来自模型直接生成。'; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 1d910b605..7ab510c12 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -156,7 +156,7 @@ public void stopStreamProcessing(String threadId, String runtimeRequestId) { } private void emitSuccess(Sinks.Many> sink, AgentRequest request, String result, - StreamTextTracker streamTextTracker) { + StreamTextTracker streamTextTracker) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); if (!runtimeRegistry.isActive(threadId, runtimeRequestId)) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index 9f4239986..370344f9c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -678,8 +678,8 @@ private List> collectRelationEvidence(ExplorerContext contex .stream() .filter(relation -> normalizedReferencedTables.contains(normalizeTableName(relation.sourceTable())) && normalizedReferencedTables.contains(normalizeTableName(relation.targetTable()))) - .map(this::toRelationEntry) - .toList(); + .map(this::toRelationEntry) + .toList(); } private String buildResultScopeSummary(ResultSetBO resultSet, int limit) { @@ -710,8 +710,8 @@ private List buildPreviewDecisionReasons(String tableName, int limit) { "预览查询会自动附带 limit=%d,避免一次返回过多行。".formatted(limit)); } - private List buildSearchDecisionReasons(List usedTables, - List> relationEvidence, int limit) { + private List buildSearchDecisionReasons(List usedTables, List> relationEvidence, + int limit) { List reasons = new ArrayList<>(); reasons.add("本轮选择执行 SQL,是因为回答需要结构化结果来支撑结论。"); if (!usedTables.isEmpty()) { @@ -917,8 +917,8 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(exactMatches.get(0)); } if (exactMatches.size() > 1) { - throw new IllegalArgumentException("表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, - String.join(", ", exactMatches))); + throw new IllegalArgumentException( + "表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); } if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { return Optional.empty(); @@ -928,15 +928,15 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(leafMatches.get(0)); } if (leafMatches.size() > 1) { - throw new IllegalArgumentException("表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, - String.join(", ", leafMatches))); + throw new IllegalArgumentException( + "表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); } return Optional.empty(); } private IllegalArgumentException buildInvisibleTableException(ExplorerContext context, String tableName) { - return new IllegalArgumentException("表 '%s' 对当前 Agent 不可见。当前可见表:%s" - .formatted(tableName, String.join(", ", context.visibleTables()))); + return new IllegalArgumentException( + "表 '%s' 对当前 Agent 不可见。当前可见表:%s".formatted(tableName, String.join(", ", context.visibleTables()))); } private boolean isSelectedTable(ExplorerContext context, String tableName) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index a157b2f9c..9b2dac44f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -96,7 +96,7 @@ public Map getToolCallbacks(String agentId) { .inputSchema(INPUT_SCHEMA) .build(); return Map.of(toolName, new AgentBoundDatasourceExplorerToolCallback(agentId, toolDefinition, - datasourceExplorerService, objectMapper)); + datasourceExplorerService, objectMapper)); } private AgentDatasource resolveActiveDatasource(String agentId) { @@ -136,8 +136,7 @@ private String buildDescription(Datasource datasource, AgentDatasource agentData 4. 如果需要写 SQL,先获取表结构和关系,再决定是否执行 SEARCH。 5. PREVIEW_ROWS 不是默认前置动作,只有样例值会实质影响 SQL 写法时才使用。 6. %s - """ - .formatted(datasource.getName(), datasource.getType(), visibleTables); + """.formatted(datasource.getName(), datasource.getType(), visibleTables); } private static final class AgentBoundDatasourceExplorerToolCallback implements ToolCallback { @@ -174,7 +173,8 @@ public String call(String toolInput, ToolContext toolContext) { DatasourceExplorerRequest request = objectMapper.readValue(toolInput, DatasourceExplorerRequest.class); validateRequest(request); AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(datasourceExplorerService.execute(agentId, request, agentRequest)); + return objectMapper + .writeValueAsString(datasourceExplorerService.execute(agentId, request, agentRequest)); } catch (Exception ex) { throw new IllegalStateException( diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index baa1d22ba..dd0dd2cf2 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -44,9 +44,9 @@ public class DomainBusinessKnowledgeToolSupport { "type": "string", "description": "必填。需要检索的业务问题、指标名、术语、SOP 主题或案例主题。" }, - "knowledgeTypes": { - "type": "array", - "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。", + "knowledgeTypes": { + "type": "array", + "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。", "items": { "type": "string" } @@ -125,18 +125,18 @@ public String call(String toolInput, ToolContext toolContext) { Integer topK = jsonNode.has("topK") && jsonNode.get("topK").canConvertToInt() ? jsonNode.get("topK").asInt() : null; Double similarityThreshold = jsonNode.has("similarityThreshold") - && jsonNode.get("similarityThreshold").isNumber() ? jsonNode.get("similarityThreshold").asDouble() - : null; + && jsonNode.get("similarityThreshold").isNumber() + ? jsonNode.get("similarityThreshold").asDouble() : null; DomainKnowledgeSearchRequest request = new DomainKnowledgeSearchRequest(query, knowledgeTypes.isEmpty() ? null : List.copyOf(knowledgeTypes), topK, similarityThreshold); AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(domainKnowledgeSearchService.search(agentId, request, agentRequest)); + return objectMapper + .writeValueAsString(domainKnowledgeSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { - throw new IllegalStateException(objectToJson( - ToolError.of(ToolErrorCode.EXECUTION_FAILED, "domain_business_knowledge.search 执行失败:" + ex.getMessage())), - ex); + throw new IllegalStateException(objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, + "domain_business_knowledge.search 执行失败:" + ex.getMessage())), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java index 023fd20bf..34dcef71e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java @@ -53,16 +53,14 @@ public SemanticModelSearchResult search(String agentId, SemanticModelSearchReque public SemanticModelSearchResult search(String agentId, SemanticModelSearchRequest request, @Nullable AgentRequest agentRequest) { if (!StringUtils.hasText(agentId)) { - return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search 需要数值型 agentId 参数。"); + return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); } Long parsedAgentId; try { parsedAgentId = Long.valueOf(agentId); } catch (NumberFormatException ex) { - return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search 需要数值型 agentId 参数。"); + return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); } return search(parsedAgentId, request, agentRequest); } @@ -84,16 +82,14 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest TableSearchScope scope = resolveTableSearchScope(activeDatasource, request == null ? null : request.getTableNames()); if (scope.isScoped() && CollectionUtils.isEmpty(scope.getTableNames())) { - return emptyResult(query, - "请求中指定的表超出了当前活动数据源对 semantic_model.search 的可见范围。"); + return emptyResult(query, "请求中指定的表超出了当前活动数据源对 semantic_model.search 的可见范围。"); } List candidates = scope.isUnbounded() ? semanticModelService.getEnabledByAgentIdAndDatasourceId(agentId, activeDatasource.getDatasourceId()) : semanticModelService.getEnabledByAgentIdAndDatasourceIdAndTableNames(agentId, activeDatasource.getDatasourceId(), scope.getTableNames()); if (CollectionUtils.isEmpty(candidates)) { - return emptyResult(query, - "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理表结构请改用数据源探索工具查看。"); + return emptyResult(query, "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理表结构请改用数据源探索工具查看。"); } List scoredHits = candidates.stream() @@ -111,20 +107,17 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest .toList(); if (scoredHits.isEmpty()) { - return emptyResult(query, - "没有匹配到补充语义提示;如果数据源探索工具已能回答物理表结构问题,就不要额外调用 semantic_model.search。"); + return emptyResult(query, "没有匹配到补充语义提示;如果数据源探索工具已能回答物理表结构问题,就不要额外调用 semantic_model.search。"); } List hits = scoredHits.stream().map(this::toHit).toList(); - String summary = "共匹配到 %d 条补充语义提示。这些结果只用于补充理解表和字段语义,不能替代数据源探索工具的物理结构探索。" - .formatted(hits.size()); + String summary = "共匹配到 %d 条补充语义提示。这些结果只用于补充理解表和字段语义,不能替代数据源探索工具的物理结构探索。".formatted(hits.size()); if (agentRequest != null) { - answerTraceExplainStore.recordSemanticSearch(agentRequest, query, - "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); + answerTraceExplainStore.recordSemanticSearch(agentRequest, query, "共匹配到 %d 条补充语义提示".formatted(hits.size()), + hits); } else { - answerTraceExplainStore.recordSemanticSearch(query, - "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); + answerTraceExplainStore.recordSemanticSearch(query, "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); } return SemanticModelSearchResult.builder().summary(summary).hits(hits).resolution("matched").build(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java index 417b66490..f9fd09cef 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java @@ -100,11 +100,12 @@ public String call(String toolInput, ToolContext toolContext) { : new SemanticModelSearchRequest(); validateRequest(request); AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); - return objectMapper.writeValueAsString(semanticModelSearchService.search(agentId, request, agentRequest)); + return objectMapper + .writeValueAsString(semanticModelSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { - throw new IllegalStateException( - objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "semantic_model.search 执行失败:" + ex.getMessage())), + throw new IllegalStateException(objectToJson( + ToolError.of(ToolErrorCode.EXECUTION_FAILED, "semantic_model.search 执行失败:" + ex.getMessage())), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java index 0c2f51e58..f1c802eef 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -148,9 +148,8 @@ private String execute(String toolInput, ToolContext toolContext) { return objectMapper.writeValueAsString(result); } catch (Exception ex) { - throw new IllegalStateException( - objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "sql_guard.check 执行失败:" + ex.getMessage())), - ex); + throw new IllegalStateException(objectToJson( + ToolError.of(ToolErrorCode.EXECUTION_FAILED, "sql_guard.check 执行失败:" + ex.getMessage())), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java index 66f0c28c9..11614b616 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -204,8 +204,7 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r List availableColumns = loadTableColumns(context, actualTableName); List visibleColumns = applyVisibleColumnRestrictions(context, actualTableName, availableColumns); if (visibleColumns.isEmpty()) { - throw new IllegalArgumentException( - "表 '%s' 在当前 Agent 下没有可见字段".formatted(actualTableName)); + throw new IllegalArgumentException("表 '%s' 在当前 Agent 下没有可见字段".formatted(actualTableName)); } List columnsToInspect = resolveColumnsToInspect(request, actualTableName, visibleColumns); int sampleLimit = normalizeProfileLimit(request == null ? null : request.getLimit()); @@ -214,8 +213,7 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r List> columnProfiles = columnsToInspect.stream() .map(column -> buildColumnProfile(context, actualTableName, column, totalRows, sampleLimit)) .toList(); - String summary = "仅基于可见字段对表 '%s' 的 %d 个字段完成 profile 分析。".formatted(columnProfiles.size(), - actualTableName); + String summary = "仅基于可见字段对表 '%s' 的 %d 个字段完成 profile 分析。".formatted(columnProfiles.size(), actualTableName); return SqlGuardCheckResult.builder() .decision("inspect_columns") .tableName(actualTableName) @@ -223,8 +221,7 @@ public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest r .totalRows(totalRows) .columnProfiles(columnProfiles) .fixSuggestions( - List.of("可优先把高频值集中的分类字段用作过滤条件或 GROUP BY 候选字段。", - "可优先把具备 min/max 范围的数值或时间字段用作指标、趋势或时间窗口候选字段。")) + List.of("可优先把高频值集中的分类字段用作过滤条件或 GROUP BY 候选字段。", "可优先把具备 min/max 范围的数值或时间字段用作指标、趋势或时间窗口候选字段。")) .build(); } @@ -248,8 +245,7 @@ private ProfileContext resolveProfileContext(String agentId) { : explicitSelectedTables; } catch (Exception ex) { - throw new IllegalStateException("加载数据源 %s 的可见表失败:%s" - .formatted(datasource.getId(), ex.getMessage()), ex); + throw new IllegalStateException("加载数据源 %s 的可见表失败:%s".formatted(datasource.getId(), ex.getMessage()), ex); } Map> visibleTablesByName = indexTables(visibleTables, false); Map> visibleTablesByLeafName = indexTables(visibleTables, true); @@ -276,8 +272,7 @@ private List loadTableColumns(ProfileContext context, String table .orElse(List.of()); } catch (Exception ex) { - throw new IllegalStateException( - "加载表 '%s' 的字段失败:%s".formatted(tableName, ex.getMessage()), ex); + throw new IllegalStateException("加载表 '%s' 的字段失败:%s".formatted(tableName, ex.getMessage()), ex); } } @@ -309,8 +304,8 @@ private List resolveColumnsToInspect(SqlGuardCheckRequest request, for (String requestedColumn : requestedColumns) { ColumnInfoBO column = columnsByName.get(normalizeColumnName(requestedColumn)); if (column == null) { - throw new IllegalArgumentException("字段 '%s' 在表 '%s' 中对当前 Agent 不可见" - .formatted(requestedColumn, tableName)); + throw new IllegalArgumentException( + "字段 '%s' 在表 '%s' 中对当前 Agent 不可见".formatted(requestedColumn, tableName)); } resolvedColumns.add(column); } @@ -542,9 +537,8 @@ private long parseLong(String value) { private String resolveVisibleTableName(ProfileContext context, String tableName) { return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) - .orElseThrow( - () -> new IllegalArgumentException("表 '%s' 对当前 Agent 不可见。当前可见表:%s" - .formatted(tableName, String.join(", ", context.visibleTables())))); + .orElseThrow(() -> new IllegalArgumentException( + "表 '%s' 对当前 Agent 不可见。当前可见表:%s".formatted(tableName, String.join(", ", context.visibleTables())))); } private Optional findVisibleTableName(Map> visibleTablesByName, @@ -555,8 +549,8 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(exactMatches.get(0)); } if (exactMatches.size() > 1) { - throw new IllegalArgumentException("表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, - String.join(", ", exactMatches))); + throw new IllegalArgumentException( + "表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); } if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { return Optional.empty(); @@ -566,8 +560,8 @@ private Optional findVisibleTableName(Map> visibleT return Optional.of(leafMatches.get(0)); } if (leafMatches.size() > 1) { - throw new IllegalArgumentException("表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, - String.join(", ", leafMatches))); + throw new IllegalArgumentException( + "表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); } return Optional.empty(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java index 8e742e29f..8658dc49c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java @@ -50,13 +50,13 @@ public class DataAgentController { @GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux> streamSearch(@RequestParam("agentId") String agentId, - @RequestParam("threadId") String threadId, - @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, - @RequestParam("query") String query, - @RequestParam(value = "clarifyCheckEnabled", required = false) boolean clarifyCheckEnabled, - @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, - @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, - @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, ServerHttpResponse response) { + @RequestParam("threadId") String threadId, + @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, + @RequestParam("query") String query, + @RequestParam(value = "clarifyCheckEnabled", required = false) boolean clarifyCheckEnabled, + @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, + @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, + @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, ServerHttpResponse response) { Long numericAgentId = parseAgentId(agentId); chatSessionService.requireSessionForAgent(threadId, numericAgentId); response.getHeaders().add("Cache-Control", "no-cache"); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java index 055ca85e3..e28d68fbf 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java @@ -121,8 +121,7 @@ INSERT INTO chat_message (session_id, role, content, message_type, metadata, cre WHERE session_id = #{sessionId} AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) """) - int deleteBySessionIdAndMessageType(@Param("sessionId") String sessionId, - @Param("messageType") String messageType); + int deleteBySessionIdAndMessageType(@Param("sessionId") String sessionId, @Param("messageType") String messageType); @Delete(""" DELETE FROM chat_message diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java index 48ecb0f07..2028d3d98 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/observability/AnswerTraceExplainStore.java @@ -287,7 +287,8 @@ private void applyDatasourceResult(ExplainAssembly assembly, DatasourceExplorerR } if (result.getUsedTables() != null && !result.getUsedTables().isEmpty()) { assembly.usedTables.clear(); - assembly.usedTables.addAll(result.getUsedTables().stream().filter(StringUtils::hasText).map(String::trim).toList()); + assembly.usedTables + .addAll(result.getUsedTables().stream().filter(StringUtils::hasText).map(String::trim).toList()); } if (result.getUsedColumns() != null && !result.getUsedColumns().isEmpty()) { assembly.usedColumns.clear(); @@ -306,13 +307,13 @@ private void applyDatasourceResult(ExplainAssembly assembly, DatasourceExplorerR } if (result.getToolDecisionReasons() != null && !result.getToolDecisionReasons().isEmpty()) { assembly.toolDecisionReasons.clear(); - assembly.toolDecisionReasons - .addAll(result.getToolDecisionReasons().stream().filter(StringUtils::hasText).map(String::trim).toList()); + assembly.toolDecisionReasons.addAll( + result.getToolDecisionReasons().stream().filter(StringUtils::hasText).map(String::trim).toList()); } if (result.getResultScopeDetails() != null && !result.getResultScopeDetails().isEmpty()) { assembly.resultScopeDetails.clear(); - assembly.resultScopeDetails - .addAll(result.getResultScopeDetails().stream().filter(StringUtils::hasText).map(String::trim).toList()); + assembly.resultScopeDetails.addAll( + result.getResultScopeDetails().stream().filter(StringUtils::hasText).map(String::trim).toList()); } assembly.toolSteps.add(ToolStepView.builder() .toolName("datasource.explorer") diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java index 6c746e918..d5b71acb2 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/chat/ChatSessionServiceImpl.java @@ -28,6 +28,7 @@ import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.web.server.ResponseStatusException; + @Service @Slf4j @AllArgsConstructor diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java index 7698069c0..eb12e41b4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/langfuse/LangfuseService.java @@ -78,7 +78,10 @@ public Span startLLMSpan(String spanName, AgentRequest request) { } try { - Span span = tracer.spanBuilder(spanName).setSpanKind(SpanKind.CLIENT).setParent(Context.current()).startSpan(); + Span span = tracer.spanBuilder(spanName) + .setSpanKind(SpanKind.CLIENT) + .setParent(Context.current()) + .startSpan(); String inputValue = String.format( "{\\\"query\\\":\\\"%s\\\",\\\"agentId\\\":\\\"%s\\\",\\\"threadId\\\":\\\"%s\\\",\\\"humanFeedback\\\":%s}", diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java index 24b4c9c69..530afc110 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/mcp/McpServerService.java @@ -29,6 +29,7 @@ public class McpServerService { private final AgentMapper agentMapper; + public record AgentListRequest(String status, String keyword) { } From b238ec3d1e5f441d86fb860fdcadfadd9ecbfd3e Mon Sep 17 00:00:00 2001 From: mengnankkkk Date: Thu, 30 Apr 2026 19:30:26 +0800 Subject: [PATCH 22/22] fix: remove button --- data-agent-frontend/src/views/AgentRun.vue | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 97016684c..04d7864d2 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -214,7 +214,7 @@ :onQuestionClick="handlePresetQuestionClick" />
-
+
开始澄清校验
-
+
需要先澄清后再查库 riskLevel={{ pendingClarify.riskLevel }} @@ -1217,6 +1217,8 @@ answerExplainVisible, pendingClarify, }); + pendingClarify.value = null; + getSessionState(session.id).pendingClarify = null; currentMessages.value = await ChatService.getSessionMessages( session.id, requireResolvedAgentId(), @@ -1287,10 +1289,8 @@ isSubmittingMessage.value = true; const needsTitle = !currentSession.value?.title || currentSession.value.title === '新会话'; - const activeClarify = pendingClarify.value; + const activeClarify: PendingClarifyState | null = null; const requestQuery = activeClarify?.originalQuery ?? userInput.value.trim(); - const feedbackContent = activeClarify ? userInput.value.trim() : undefined; - const userMessage: ChatMessage = { sessionId: currentSession.value.id, role: 'user', @@ -1311,9 +1311,9 @@ const request: AgentRequest = { agentId: String(requireResolvedAgentId()), query: requestQuery, - clarifyCheckEnabled: requestOptions.value.clarifyCheckEnabled, - humanFeedback: Boolean(activeClarify), - humanFeedbackContent: feedbackContent, + clarifyCheckEnabled: false, + humanFeedback: false, + humanFeedbackContent: undefined, rejectedPlan: false, threadId: currentSession.value.id, runtimeRequestId: createRuntimeRequestId(), @@ -1442,7 +1442,7 @@ if (sessionState.lastRequest) { sessionState.lastRequest.threadId = response.threadId; } - if (isClarifyMetadata(response.metadata)) { + if (isClarifyMetadata(response.metadata) && request.clarifyCheckEnabled) { const nextPendingClarify = buildPendingClarifyState(response.metadata); sessionState.pendingClarify = nextPendingClarify; if (currentSession.value?.id === sessionId) {