diff --git "a/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" "b/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" new file mode 100644 index 000000000..0d915af57 --- /dev/null +++ "b/.claude/context-summary-\345\267\245\345\205\267\346\217\217\350\277\260\344\270\255\346\226\207\345\214\226.md" @@ -0,0 +1,52 @@ +## 项目上下文摘要(工具描述中文化) +生成时间:2026-04-28 + +### 1. 相似实现分析 +- **实现1**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java:131-152 + - 模式:由 provider 动态构造 `ToolDefinition.description` + - 可复用:中文分点式说明风格,明确适用范围、顺序和禁用场景 + - 需注意:描述会直接影响 Agent 工具选择与调用顺序 + +- **实现2**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java:32-50 + - 模式:`INPUT_SCHEMA` 采用中文字段说明,强调“必填/可选/何时不必传” + - 可复用:schema 文案风格可直接作为其他工具输入说明参考 + - 需注意:和 provider 的 description 要保持边界一致 + +- **实现3**: data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java:37-63 + - 模式:support 负责统一 `ToolDefinition` 的 input schema + - 可复用:中文化时优先改 provider 的 DESCRIPTION,support 仅修正残留英文术语 + - 需注意:不要改动工具行为,只改面向 Agent 的说明文本 + +### 2. 项目约定 +- **命名约定**: Java 类、常量、方法名保持现有英文命名;说明文本统一简体中文 +- **文件组织**: provider 负责工具注册与 description;support 负责 input schema 和 callback 封装 +- **导入顺序**: 保持现有 import 顺序,不做无关调整 +- **代码风格**: 使用 Java 文本块 `"""` 保存多行描述,延续现有缩进和换行风格 + +### 3. 可复用组件清单 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java`:中文工具描述样板 +- `data-agent-management/src/main/resources/prompts/commonagent.md`:工具边界和调用顺序的权威规则 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java`:中文 schema 样板 +- `data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java`:中文 schema 样板 + +### 4. 测试策略 +- **测试框架**: 本次先做静态验证,不新增测试 +- **验证方式**: 搜索所有注册工具 description / input schema,确认英文描述已清理且工具边界与 `commonagent.md` 一致 +- **参考文件**: `commonagent.md`、各 ToolProvider / ToolSupport +- **覆盖要求**: 至少覆盖 `sql_guard.check`、`semantic_model.search`、`domain_business_knowledge.search` 以及已存在中文样板的对齐检查 + +### 5. 依赖和集成点 +- **外部依赖**: Spring AI `ToolDefinition` +- **内部依赖**: `AgentScopedToolProvider`、各 `ToolSupport` +- **集成方式**: provider 将 description 和 inputSchema 注入 `ToolDefinition.builder()` +- **配置来源**: `data-agent-management/src/main/resources/prompts/commonagent.md` 定义工具路由规则 + +### 6. 技术选型理由 +- **为什么用这个方案**: 直接修改 provider/support 文案即可完成目标,影响面准确且不改变运行逻辑 +- **优势**: 风险低、改动集中、与现有工具注册结构一致 +- **劣势和风险**: 只改文案不跑行为测试;需确保中文说明与现有路由规则完全一致,避免误导模型 + +### 7. 关键风险点 +- **边界条件**: `sql_guard.check` 同时支持 `SQL_VERIFY` 与 `DATA_PROFILE`,描述必须避免把 profile 说成默认步骤 +- **一致性风险**: provider description、schema description、`commonagent.md` 三处边界不能冲突 +- **维护风险**: 若遗漏某个注册工具,前后端展示和模型行为提示会出现中英混杂 diff --git a/.gitignore b/.gitignore index f52b4e24d..40c65b79c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ data-agent-management/vectorstore/* # spring-ai-alibaba source spring-ai-alibaba-1.1.0.0/ .spec-workflow +.claude diff --git a/agent-skills/builtin-domain-business-knowledge/SKILL.md b/agent-skills/builtin-domain-business-knowledge/SKILL.md index e69ef4ad2..b7ea93a5c 100644 --- a/agent-skills/builtin-domain-business-knowledge/SKILL.md +++ b/agent-skills/builtin-domain-business-knowledge/SKILL.md @@ -8,7 +8,7 @@ description: 当用户问题依赖公司内部业务规则、指标口径、SOP 工作原则: 1. 不要在每一轮都调用工具。只有当回答明显依赖业务知识,或者你不确定业务口径时,才调用 `domain_business_knowledge.search`。 -2. 默认优先检索 `businessTerm` 与 `agentKnowledge`。如果你明确知道要查 FAQ、QA 或文档,可以通过 `knowledgeTypes` 缩小范围。 +2. 默认优先检索 `businessKnowledge` 与 `agentKnowledge`。如果你明确知道要查 FAQ、QA 或文档,可以通过 `knowledgeTypes` 缩小范围。 3. 将工具返回结果作为“证据”来辅助推理,而不是机械复述;回答时要提炼规则、说明口径,并在必要时指出来源。 4. 如果检索结果为空或证据冲突,必须明确告诉用户当前知识不足,不能编造业务规则。 5. 当问题会影响 SQL 过滤条件、指标定义、分析步骤时,优先调用该工具再继续生成回答或计划。 diff --git a/data-agent-frontend/knip.json b/data-agent-frontend/knip.json index afbe789ed..d06c3bc8e 100644 --- a/data-agent-frontend/knip.json +++ b/data-agent-frontend/knip.json @@ -2,6 +2,5 @@ "$schema": "https://unpkg.com/knip@4/schema.json", "entry": ["src/main.js", "src/**/*.vue", "vite.config.js"], "project": ["src/**/*.{js,ts,vue}"], - "ignore": ["**/*.d.ts"], - "ignoreDependencies": ["@vitejs/plugin-vue", "vite"] + "ignore": ["**/*.d.ts"] } diff --git a/data-agent-frontend/src/components/agent/DataSourceConfig.vue b/data-agent-frontend/src/components/agent/DataSourceConfig.vue index da448cde8..c82d5a4dc 100644 --- a/data-agent-frontend/src/components/agent/DataSourceConfig.vue +++ b/data-agent-frontend/src/components/agent/DataSourceConfig.vue @@ -1,4 +1,4 @@ - > = ref({}); const updateLoadingStates: Ref> = ref({}); const agentDatasourceList: Ref = ref([]); + const selectedColumns: Ref>> = ref({}); + const columnOptionsByDatasource: Ref>> = ref({}); + const columnRestrictionEnabled: Ref>> = ref({}); + const columnLoadingStates: Ref> = ref({}); + const columnDialogVisible: Ref = ref(false); + const currentColumnDatasource: Ref = ref(null); + const currentColumnTables: Ref = ref([]); + const savingColumnVisibility: Ref = ref(false); // 逻辑外键管理相关状态 const foreignKeyDialogVisible: Ref = ref(false); @@ -892,9 +1039,20 @@ const datasourceItem = { ...item.datasource }; datasourceItem.status = item.isActive === 1 ? 'active' : 'inactive'; - // 初始化已选择的表 - if (item.selectTables && item.datasource?.id) { - selectedTables.value[item.datasource.id] = [...item.selectTables]; + if (item.datasource?.id) { + if (item.selectTables) { + selectedTables.value[item.datasource.id] = [...item.selectTables]; + } + selectedColumns.value[item.datasource.id] = Object.entries( + item.selectColumns || {}, + ).reduce>((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}); + columnRestrictionEnabled.value[item.datasource.id] = {}; + Object.keys(selectedColumns.value[item.datasource.id]).forEach(tableName => { + columnRestrictionEnabled.value[item.datasource.id][tableName] = true; + }); } return datasourceItem; @@ -905,6 +1063,140 @@ } }; + const getAgentDatasourceByDatasourceId = ( + datasourceId: number, + ): AgentDatasource | undefined => { + return agentDatasourceList.value.find(item => item.datasource?.id === datasourceId); + }; + + const getErrorMessage = (error: unknown, fallback: string): string => { + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; + }; + + const applyAgentDatasourceSnapshot = (snapshot: AgentDatasource): void => { + const datasourceId = snapshot.datasource?.id; + if (!datasourceId || !snapshot.datasource) { + return; + } + + const nextSnapshot: AgentDatasource = { + ...snapshot, + selectTables: [...(snapshot.selectTables || [])], + selectColumns: Object.entries(snapshot.selectColumns || {}).reduce< + Record + >((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}), + }; + + const agentDatasourceIndex = agentDatasourceList.value.findIndex( + item => item.datasource?.id === datasourceId, + ); + if (agentDatasourceIndex >= 0) { + agentDatasourceList.value[agentDatasourceIndex] = nextSnapshot; + } else { + agentDatasourceList.value.push(nextSnapshot); + } + + const datasourceSnapshot: Datasource = { + ...nextSnapshot.datasource, + status: nextSnapshot.isActive === 1 ? 'active' : 'inactive', + }; + const datasourceIndex = datasource.value.findIndex(item => item.id === datasourceId); + if (datasourceIndex >= 0) { + datasource.value[datasourceIndex] = datasourceSnapshot; + } else { + datasource.value.push(datasourceSnapshot); + } + + selectedTables.value[datasourceId] = [...(nextSnapshot.selectTables || [])]; + selectedColumns.value[datasourceId] = Object.entries( + nextSnapshot.selectColumns || {}, + ).reduce>((result, [tableName, columns]) => { + result[tableName] = [...columns]; + return result; + }, {}); + columnRestrictionEnabled.value[datasourceId] = {}; + (nextSnapshot.selectTables || []).forEach(tableName => { + columnRestrictionEnabled.value[datasourceId][tableName] = + (nextSnapshot.selectColumns?.[tableName] || []).length > 0; + }); + }; + + const getSelectedTablesForDatasource = (datasourceId: number): string[] => { + const currentTables = selectedTables.value[datasourceId]; + if (currentTables && currentTables.length > 0) { + return [...currentTables]; + } + return [...(getAgentDatasourceByDatasourceId(datasourceId)?.selectTables || [])]; + }; + + const normalizeNameList = (values: string[] = []): string[] => { + return [...values] + .map(value => value.trim()) + .filter(Boolean) + .sort((left, right) => left.localeCompare(right)); + }; + + const hasPendingTableChanges = (datasourceRow: Datasource): boolean => { + if (!datasourceRow.id) { + return false; + } + const savedTables = normalizeNameList( + getAgentDatasourceByDatasourceId(datasourceRow.id)?.selectTables || [], + ); + const currentTables = normalizeNameList(selectedTables.value[datasourceRow.id] || []); + return savedTables.join('|') !== currentTables.join('|'); + }; + + const resolveConfiguredColumns = ( + selectColumns: Record | undefined, + tableName: string, + ): string[] => { + if (!selectColumns) { + return []; + } + if (selectColumns[tableName]) { + return [...selectColumns[tableName]]; + } + const matchedKey = Object.keys(selectColumns).find( + key => key.toLowerCase() === tableName.toLowerCase(), + ); + return matchedKey ? [...(selectColumns[matchedKey] || [])] : []; + }; + + const getColumnLoadingKey = (datasourceId: number, tableName: string): string => { + return `${datasourceId}:${tableName}`; + }; + + const loadColumnsForTable = async ( + datasourceId: number, + tableName: string, + ): Promise => { + const loadingKey = getColumnLoadingKey(datasourceId, tableName); + columnLoadingStates.value[loadingKey] = true; + try { + const columns = await agentDatasourceService.getVisibleTableColumns( + String(props.agentId), + datasourceId, + tableName, + ); + if (!columnOptionsByDatasource.value[datasourceId]) { + columnOptionsByDatasource.value[datasourceId] = {}; + } + columnOptionsByDatasource.value[datasourceId][tableName] = columns; + } catch (error) { + ElMessage.error(getErrorMessage(error, `加载表 ${tableName} 的字段失败`)); + console.error('Failed to load datasource columns:', error); + } finally { + columnLoadingStates.value[loadingKey] = false; + } + }; + const handleSelectDatasourceChange = (value: Datasource) => { if (value === null || value === undefined) { selectedDatasourceId.value = null; @@ -986,20 +1278,46 @@ // 更改数据源状态 const changeDatasource = async (row: Datasource, active: boolean) => { const datasourceId = row.id; + if (!datasourceId) { + ElMessage.error('数据源ID不存在,无法切换状态'); + return; + } try { - const response: ApiResponse = await agentDatasourceService.toggleDatasourceForAgent( - props.agentId, - { datasourceId, isActive: active }, - ); - if (response.success) { - ElMessage.success('操作成功!'); - row.status = active ? 'active' : 'inactive'; + if (active) { + const response: ApiResponse = await agentDatasourceService.addDatasourceToAgent( + String(props.agentId), + datasourceId, + ); + if (response.success) { + ElMessage.success('已切换到对应数据源'); + await loadAgentDatasource(); + } else { + ElMessage.error(response.message || '切换数据源失败!'); + console.error('Failed to switch datasource:', response); + } } else { - ElMessage.error('操作失败!'); - console.error('Failed to change datasource:', response); + const activeDatasourceCount = datasource.value.filter( + item => item.status === 'active', + ).length; + if (row.status === 'active' && activeDatasourceCount <= 1) { + ElMessage.warning('当前智能体必须至少保留一个启用中的数据源'); + return; + } + + const response: ApiResponse = await agentDatasourceService.toggleDatasourceForAgent( + String(props.agentId), + { datasourceId, isActive: false }, + ); + if (response.success) { + ElMessage.success('操作成功!'); + await loadAgentDatasource(); + } else { + ElMessage.error(response.message || '操作失败!'); + console.error('Failed to disable datasource:', response); + } } } catch (error) { - ElMessage.error('操作失败!'); + ElMessage.error(getErrorMessage(error, active ? '切换数据源失败!' : '操作失败!')); console.error('Failed to change datasource:', error); } }; @@ -1007,6 +1325,10 @@ // 测试数据源连接 const testConnection = async (row: Datasource) => { const datasourceId = row.id; + if (row.status !== 'active') { + ElMessage.warning('禁用的数据源无需测试连接,请先启用'); + return; + } try { const response: ApiResponse = await datasourceService.testConnection(datasourceId); if (response.success) { @@ -1245,6 +1567,10 @@ // 加载数据源的表列表 const loadDatasourceTables = async (datasource: Datasource) => { if (!datasource.id) return; + if (datasource.status !== 'active') { + ElMessage.warning('禁用的数据源无需加载表结构,请先启用'); + return; + } tableLoadingStates.value[datasource.id] = true; try { @@ -1282,26 +1608,146 @@ }, ); - if (response.success) { + if (response.success && response.data) { + applyAgentDatasourceSnapshot(response.data); ElMessage.success('数据表更新成功'); - // 更新本地存储的已选择表 - const agentDatasource = agentDatasourceList.value.find( - item => item.datasource?.id === datasource.id, - ); - if (agentDatasource) { - agentDatasource.selectTables = [...(selectedTables.value[datasource.id] || [])]; - } } else { - ElMessage.error('数据表更新失败'); + ElMessage.error(response.message || '数据表更新失败'); } } catch (error) { - ElMessage.error('数据表更新失败'); + ElMessage.error(getErrorMessage(error, '数据表更新失败')); console.error('Failed to update datasource tables:', error); } finally { updateLoadingStates.value[datasource.id] = false; } }; + const toggleColumnRestriction = (tableName: string, enabled: boolean | string | number) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + if (!columnRestrictionEnabled.value[datasourceId]) { + columnRestrictionEnabled.value[datasourceId] = {}; + } + columnRestrictionEnabled.value[datasourceId][tableName] = Boolean(enabled); + if (!enabled) { + selectedColumns.value[datasourceId][tableName] = []; + } + }; + + const openColumnVisibilityDialog = async (datasourceRow: Datasource) => { + if (!datasourceRow.id) { + return; + } + if (hasPendingTableChanges(datasourceRow)) { + ElMessage.warning('请先点击“更新数据表”保存当前表配置,再设置字段可见性'); + return; + } + + const tables = getSelectedTablesForDatasource(datasourceRow.id); + if (tables.length === 0) { + ElMessage.warning('请先选择并保存数据表,再配置字段可见性'); + return; + } + + currentColumnDatasource.value = datasourceRow; + currentColumnTables.value = [...tables]; + + if (!selectedColumns.value[datasourceRow.id]) { + selectedColumns.value[datasourceRow.id] = {}; + } + if (!columnRestrictionEnabled.value[datasourceRow.id]) { + columnRestrictionEnabled.value[datasourceRow.id] = {}; + } + if (!columnOptionsByDatasource.value[datasourceRow.id]) { + columnOptionsByDatasource.value[datasourceRow.id] = {}; + } + + const agentDatasource = getAgentDatasourceByDatasourceId(datasourceRow.id); + tables.forEach(tableName => { + const configuredColumns = resolveConfiguredColumns( + agentDatasource?.selectColumns, + tableName, + ); + selectedColumns.value[datasourceRow.id][tableName] = configuredColumns; + columnRestrictionEnabled.value[datasourceRow.id][tableName] = + configuredColumns.length > 0; + }); + + await Promise.all( + tables.map(tableName => loadColumnsForTable(datasourceRow.id!, tableName)), + ); + columnDialogVisible.value = true; + }; + + const selectAllColumnsForTable = (tableName: string) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + selectedColumns.value[datasourceId][tableName] = [ + ...(columnOptionsByDatasource.value[datasourceId]?.[tableName] || []), + ]; + }; + + const clearColumnsForTable = (tableName: string) => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + selectedColumns.value[datasourceId][tableName] = []; + }; + + const saveDatasourceColumns = async () => { + const datasourceId = currentColumnDatasource.value?.id; + if (!datasourceId) { + return; + } + + const invalidTables = currentColumnTables.value.filter(tableName => { + return ( + columnRestrictionEnabled.value[datasourceId]?.[tableName] && + !(selectedColumns.value[datasourceId]?.[tableName] || []).length + ); + }); + if (invalidTables.length > 0) { + ElMessage.warning(`请至少为以下数据表选择一个字段:${invalidTables.join('、')}`); + return; + } + + savingColumnVisibility.value = true; + try { + const tables = currentColumnTables.value + .filter(tableName => columnRestrictionEnabled.value[datasourceId]?.[tableName]) + .map(tableName => ({ + tableName, + columns: [...(selectedColumns.value[datasourceId]?.[tableName] || [])], + })); + + const response = await agentDatasourceService.updateDatasourceColumns( + String(props.agentId), + { + datasourceId, + tables, + }, + ); + + if (response.success && response.data) { + applyAgentDatasourceSnapshot(response.data); + ElMessage.success('字段可见性更新成功'); + columnDialogVisible.value = false; + } else { + ElMessage.error(response.message || '字段可见性更新失败'); + } + } catch (error) { + ElMessage.error(getErrorMessage(error, '字段可见性更新失败')); + console.error('Failed to update datasource columns:', error); + } finally { + savingColumnVisibility.value = false; + } + }; + // 全选表 const selectAllTables = (datasource: Datasource) => { if (!datasource.id || !tableLists.value[datasource.id]) return; @@ -1343,6 +1789,10 @@ return; } + if (datasourceRow.status !== 'active') { + ElMessage.warning('禁用的数据源无需配置逻辑关系,请先启用'); + return; + } currentForeignKeyDatasource.value = datasourceRow; foreignKeyDialogVisible.value = true; @@ -1590,6 +2040,14 @@ editingDatasource, tableLists, selectedTables, + selectedColumns, + columnOptionsByDatasource, + columnRestrictionEnabled, + columnLoadingStates, + columnDialogVisible, + currentColumnDatasource, + currentColumnTables, + savingColumnVisibility, tableLoadingStates, updateLoadingStates, initAgentDatasource, @@ -1605,6 +2063,12 @@ deleteDatasource, loadDatasourceTables, updateDatasourceTables, + openColumnVisibilityDialog, + saveDatasourceColumns, + selectAllColumnsForTable, + clearColumnsForTable, + toggleColumnRestriction, + getColumnLoadingKey, selectAllTables, clearAllTables, truncateText, diff --git a/data-agent-frontend/src/components/agent/SemanticsConfig.vue b/data-agent-frontend/src/components/agent/SemanticsConfig.vue index 0cd57f32e..8a859eda8 100644 --- a/data-agent-frontend/src/components/agent/SemanticsConfig.vue +++ b/data-agent-frontend/src/components/agent/SemanticsConfig.vue @@ -15,7 +15,6 @@ --> - - diff --git a/data-agent-frontend/src/components/run/ResultSetDisplay.vue b/data-agent-frontend/src/components/run/ResultSetDisplay.vue index 9b82a9275..c5ebc1653 100644 --- a/data-agent-frontend/src/components/run/ResultSetDisplay.vue +++ b/data-agent-frontend/src/components/run/ResultSetDisplay.vue @@ -230,7 +230,7 @@ display: flex; justify-content: space-between; align-items: center; - //margin-left: 15px; + /* margin-left: 15px; */ margin-bottom: 12px; padding: 8px 0; border-bottom: 1px solid #ebeef5; diff --git a/data-agent-frontend/src/services/agentDatasource.ts b/data-agent-frontend/src/services/agentDatasource.ts index bf49b3624..2137c9fed 100644 --- a/data-agent-frontend/src/services/agentDatasource.ts +++ b/data-agent-frontend/src/services/agentDatasource.ts @@ -28,26 +28,44 @@ interface UpdateDatasourceTablesDto { tables?: string[]; } +interface TableColumnsSelectionDto { + tableName: string; + columns?: string[]; +} + +interface UpdateDatasourceColumnsDto { + datasourceId?: number; + tables?: TableColumnsSelectionDto[]; +} + const BASE_URL_FUNC = (agentId: string) => `/api/agent/${agentId}/datasources`; +const extractApiErrorMessage = (error: unknown, fallback: string): string => { + if (axios.isAxiosError(error)) { + const responseMessage = error.response?.data?.message; + if (typeof responseMessage === 'string' && responseMessage.trim()) { + return responseMessage; + } + if (typeof error.message === 'string' && error.message.trim()) { + return error.message; + } + } + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; +}; + class AgentDatasourceService { - /** - * 初始化数据源Schema - * @param agentId 智能体ID - */ async initSchema(agentId: string): Promise> { try { const response = await axios.post>(`${BASE_URL_FUNC(agentId)}/init`); return response.data; } catch (error) { - throw new Error(`初始化Schema失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '初始化 Schema 失败')); } } - /** - * 获取智能体的数据源列表 - * @param agentId 智能体ID - */ async getAgentDatasource(agentId: number): Promise { try { const response = await axios.get>( @@ -58,36 +76,24 @@ class AgentDatasourceService { } throw new Error(response.data.message); } catch (error) { - throw new Error(`获取数据源列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '获取数据源列表失败')); } } - /** - * 获取当前激活的智能体 - * @param agentId 智能体ID - */ async getActiveAgentDatasource(agentId: number): Promise { try { const response = await axios.get>( - BASE_URL_FUNC(String(agentId)) + '/active', + `${BASE_URL_FUNC(String(agentId))}/active`, ); - if (response.data.success) { - if (response.data.data === undefined) { - throw new Error('后端错误'); - } + if (response.data.success && response.data.data) { return response.data.data; } - throw new Error(response.data.message); + throw new Error(response.data.message || '后端返回了空数据'); } catch (error) { - throw new Error(`获取数据源列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '获取当前启用数据源失败')); } } - /** - * 为智能体添加数据源 - * @param agentId 智能体ID - * @param datasourceId 数据源ID - */ async addDatasourceToAgent( agentId: string, datasourceId: number, @@ -98,15 +104,10 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`添加数据源失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '添加数据源失败')); } } - /** - * 从智能体移除数据源 - * @param agentId 智能体ID - * @param datasourceId 数据源ID - */ async removeDatasourceFromAgent( agentId: string, datasourceId: number, @@ -117,15 +118,10 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`移除数据源失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '移除数据源失败')); } } - /** - * 启用/禁用智能体的数据源 - * @param agentId 智能体ID - * @param dto 切换参数 - */ async toggleDatasourceForAgent( agentId: string, dto: ToggleDatasourceDto, @@ -137,24 +133,55 @@ class AgentDatasourceService { ); return response.data; } catch (error) { - throw new Error(`切换数据源状态失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '切换数据源状态失败')); } } - /** - * 更新数据源的表列表 - * @param agentId 智能体ID - * @param dto 更新参数 - */ async updateDatasourceTables( agentId: string, dto: UpdateDatasourceTablesDto, - ): Promise> { + ): Promise> { + try { + const response = await axios.post>( + `${BASE_URL_FUNC(agentId)}/tables`, + dto, + ); + return response.data; + } catch (error) { + throw new Error(extractApiErrorMessage(error, '更新数据表列表失败')); + } + } + + async updateDatasourceColumns( + agentId: string, + dto: UpdateDatasourceColumnsDto, + ): Promise> { try { - const response = await axios.post>(`${BASE_URL_FUNC(agentId)}/tables`, dto); + const response = await axios.post>( + `${BASE_URL_FUNC(agentId)}/columns`, + dto, + ); return response.data; } catch (error) { - throw new Error(`更新数据源表列表失败: ${error}`); + throw new Error(extractApiErrorMessage(error, '更新字段可见性失败')); + } + } + + async getVisibleTableColumns( + agentId: string, + datasourceId: number, + tableName: string, + ): Promise { + try { + const response = await axios.get>( + `${BASE_URL_FUNC(agentId)}/${datasourceId}/tables/${encodeURIComponent(tableName)}/columns`, + ); + if (response.data.success) { + return response.data.data || []; + } + throw new Error(response.data.message); + } catch (error) { + throw new Error(extractApiErrorMessage(error, `加载表 ${tableName} 的字段失败`)); } } } diff --git a/data-agent-frontend/src/services/chat.ts b/data-agent-frontend/src/services/chat.ts index 66317b88a..b360a1845 100644 --- a/data-agent-frontend/src/services/chat.ts +++ b/data-agent-frontend/src/services/chat.ts @@ -39,15 +39,104 @@ export interface ChatMessage { titleNeeded?: boolean; } +export interface AnswerTraceSemanticHit { + tableName?: string; + columnName?: string; + businessName?: string; + businessDescription?: string; + matchedBy?: string; + score?: number; + relationHint?: string; +} + +export interface AnswerTraceKnowledgeHit { + vectorType?: string; + knowledgeId?: string; + title?: string; + summary?: string; + snippet?: string; + source?: string; + concreteType?: string; +} + +export interface AnswerTraceToolStep { + toolName?: string; + title?: string; + summary?: string; + detail?: string; + datasource?: string; + timestampEpochMs?: number; +} + +export interface AnswerTraceExplain { + sessionId: string; + runtimeRequestId: string; + agentId?: string; + question?: string; + answer?: string; + datasource?: string; + sql?: string; + decisionReason?: string; + resultScope?: string; + usedTables: string[]; + usedColumns: string[]; + relationEvidence: Record[]; + toolDecisionReasons: string[]; + resultScopeDetails: string[]; + semanticHits: AnswerTraceSemanticHit[]; + knowledgeHits: AnswerTraceKnowledgeHit[]; + toolSteps: AnswerTraceToolStep[]; + clarify?: Record; + warnings: string[]; + updatedAt: number; +} + +export interface TraceSpan { + name: string; + spanId: string; + parentSpanId: string; + kind: string; + status: string; + startEpochMs: number; + endEpochMs: number; + durationMs: number; + attributes: Record; + children: TraceSpan[]; +} + +export interface SessionTrace { + sessionId: string; + traceId: string; + runtimeRequestId: string; + agentId: string; + startEpochMs: number; + endEpochMs: number; + durationMs: number; + spanCount: number; + rootSpan: TraceSpan | null; + rootSpans: TraceSpan[]; +} + const API_BASE_URL = '/api'; +const resolveAgentId = (agentId: number | string): number => { + const resolvedAgentId = typeof agentId === 'number' ? agentId : Number(agentId); + if (!Number.isFinite(resolvedAgentId)) { + throw new Error('智能体ID无效,请刷新后重试'); + } + return resolvedAgentId; +}; + class ChatService { /** * 获取Agent的会话列表 * @param agentId Agent ID */ async getAgentSessions(agentId: number): Promise { - const response = await axios.get(`${API_BASE_URL}/agent/${agentId}/sessions`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get( + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, + ); return response.data; } @@ -58,13 +147,14 @@ class ChatService { * @param userId 用户ID */ async createSession(agentId: number, title?: string, userId?: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); const request = { title, userId, }; const response = await axios.post( - `${API_BASE_URL}/agent/${agentId}/sessions`, + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, request, ); return response.data; @@ -75,7 +165,10 @@ class ChatService { * @param agentId Agent ID */ async clearAgentSessions(agentId: number): Promise { - const response = await axios.delete(`${API_BASE_URL}/agent/${agentId}/sessions`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.delete( + `${API_BASE_URL}/agent/${resolvedAgentId}/sessions`, + ); return response.data; } @@ -83,9 +176,41 @@ class ChatService { * 获取会话的消息列表 * @param sessionId 会话ID */ - async getSessionMessages(sessionId: string): Promise { + async getSessionMessages(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.get( `${API_BASE_URL}/sessions/${sessionId}/messages`, + { params: { agentId: resolvedAgentId } }, + ); + return response.data; + } + + async getSessionTrace(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get(`${API_BASE_URL}/sessions/${sessionId}/trace`, { + params: { agentId: resolvedAgentId }, + }); + return response.data; + } + + async getLatestAnswerExplain(sessionId: string, agentId: number): Promise { + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get( + `${API_BASE_URL}/sessions/${sessionId}/answers/latest/explain`, + { params: { agentId: resolvedAgentId } }, + ); + return response.data; + } + + async getAnswerExplain( + sessionId: string, + runtimeRequestId: string, + agentId: number, + ): Promise { + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.get( + `${API_BASE_URL}/sessions/${sessionId}/answers/${runtimeRequestId}/explain`, + { params: { agentId: resolvedAgentId } }, ); return response.data; } @@ -95,8 +220,13 @@ class ChatService { * @param sessionId 会话ID * @param message 消息对象 */ - async saveMessage(sessionId: string, message: ChatMessage): Promise { + async saveMessage( + sessionId: string, + agentId: number, + message: ChatMessage, + ): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); // 设置会话ID const messageData = { ...message, @@ -106,6 +236,7 @@ class ChatService { const response = await axios.post( `${API_BASE_URL}/sessions/${sessionId}/messages`, messageData, + { params: { agentId: resolvedAgentId } }, ); return response.data; } catch (error) { @@ -121,13 +252,14 @@ class ChatService { * @param sessionId 会话ID * @param isPinned 是否置顶 */ - async pinSession(sessionId: string, isPinned: boolean): Promise { + async pinSession(sessionId: string, agentId: number, isPinned: boolean): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.put( `${API_BASE_URL}/sessions/${sessionId}/pin`, null, { - params: { isPinned }, + params: { agentId: resolvedAgentId, isPinned }, }, ); return response.data; @@ -147,8 +279,9 @@ class ChatService { * @param sessionId 会话ID * @param title 新标题 */ - async renameSession(sessionId: string, title: string): Promise { + async renameSession(sessionId: string, agentId: number, title: string): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); if (!title || title.trim().length === 0) { throw new Error('标题不能为空'); } @@ -157,7 +290,7 @@ class ChatService { `${API_BASE_URL}/sessions/${sessionId}/rename`, null, { - params: { title: title.trim() }, + params: { agentId: resolvedAgentId, title: title.trim() }, }, ); return response.data; @@ -176,9 +309,12 @@ class ChatService { * 删除单个会话 * @param sessionId 会话ID */ - async deleteSession(sessionId: string): Promise { + async deleteSession(sessionId: string, agentId: number): Promise { try { - const response = await axios.delete(`${API_BASE_URL}/sessions/${sessionId}`); + const resolvedAgentId = resolveAgentId(agentId); + const response = await axios.delete(`${API_BASE_URL}/sessions/${sessionId}`, { + params: { agentId: resolvedAgentId }, + }); return response.data; } catch (error) { if (axios.isAxiosError(error) && error.response?.status === 500) { @@ -193,12 +329,14 @@ class ChatService { * @param sessionId 会话ID * @param content 报告内容 */ - async downloadHtmlReport(sessionId: string, content: string): Promise { + async downloadHtmlReport(sessionId: string, agentId: number, content: string): Promise { try { + const resolvedAgentId = resolveAgentId(agentId); const response = await axios.post( `${API_BASE_URL}/sessions/${sessionId}/reports/html`, content, { + params: { agentId: resolvedAgentId }, responseType: 'blob', // 重要:设置响应类型为blob headers: { 'Content-Type': 'text/plain;charset=utf-8', // 明确设置内容类型和编码 diff --git a/data-agent-frontend/src/services/datasource.ts b/data-agent-frontend/src/services/datasource.ts index 184858fd2..d0b9f0d6c 100644 --- a/data-agent-frontend/src/services/datasource.ts +++ b/data-agent-frontend/src/services/datasource.ts @@ -45,6 +45,7 @@ export interface AgentDatasource { updateTime?: string; datasource?: Datasource; selectTables?: string[]; + selectColumns?: Record; } // 定义数据源类型接口 @@ -97,6 +98,23 @@ class DatasourceService { } } + async getTableColumns(id: number, tableName: string): Promise { + try { + const response = await axios.get>( + `${API_BASE_URL}/${id}/tables/${encodeURIComponent(tableName)}/columns`, + ); + if (response.data.success) { + return response.data.data || []; + } + throw new Error(response.data.message); + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 400) { + return []; + } + throw error; + } + } + // 4. 创建数据源 async createDatasource(datasource: Datasource): Promise { const response = await axios.post(API_BASE_URL, datasource); diff --git a/data-agent-frontend/src/services/graph.ts b/data-agent-frontend/src/services/graph.ts index 42b529b40..377903342 100644 --- a/data-agent-frontend/src/services/graph.ts +++ b/data-agent-frontend/src/services/graph.ts @@ -14,22 +14,34 @@ * limitations under the License. */ -export interface GraphRequest { +export interface AgentRequest { agentId: string; threadId?: string; + runtimeRequestId?: string; query: string; - humanFeedback: boolean; + clarifyCheckEnabled?: boolean; + humanFeedback?: boolean; humanFeedbackContent?: string; rejectedPlan: boolean; - nl2sqlOnly: boolean; } -export interface GraphNodeResponse { +export interface ClarifyMetadata { + clarifyRequired?: boolean; + riskLevel?: string; + originalQuery?: string; + missingDimensions?: string[]; + followUpQuestions?: string[]; + suggestedAssumptions?: string[]; + summary?: string; +} + +export interface AgentResponse { agentId: string; threadId: string; nodeName: string; textType: TextType; text: string; + metadata?: ClarifyMetadata & Record; error: boolean; complete: boolean; } @@ -56,8 +68,8 @@ class GraphService { * @returns 关闭连接的函数 */ async streamSearch( - request: GraphRequest, - onMessage: (response: GraphNodeResponse) => Promise, + request: AgentRequest, + onMessage: (response: AgentResponse) => Promise, onError?: (error: Error) => Promise, onComplete?: () => Promise, ): Promise<() => void> { @@ -67,14 +79,18 @@ class GraphService { if (request.threadId) { params.append('threadId', request.threadId); } + if (request.runtimeRequestId) { + params.append('runtimeRequestId', request.runtimeRequestId); + } params.append('query', request.query); - params.append('humanFeedback', request.humanFeedback.toString()); - params.append('rejectedPlan', request.rejectedPlan.toString()); - params.append('nl2sqlOnly', request.nl2sqlOnly.toString()); - + params.append('clarifyCheckEnabled', String(Boolean(request.clarifyCheckEnabled))); + if (request.humanFeedback) { + params.append('humanFeedback', request.humanFeedback.toString()); + } if (request.humanFeedbackContent) { params.append('humanFeedbackContent', request.humanFeedbackContent); } + params.append('rejectedPlan', request.rejectedPlan.toString()); const url = `${API_BASE_URL}/stream/search?${params.toString()}`; @@ -95,7 +111,7 @@ class GraphService { return; } try { - const nodeResponse: GraphNodeResponse = JSON.parse(event.data); + const nodeResponse: AgentResponse = JSON.parse(event.data); console.log( `Node: ${nodeResponse.nodeName}, message: ${nodeResponse.text}, type: ${nodeResponse.textType}`, ); diff --git a/data-agent-frontend/src/services/resultSet.ts b/data-agent-frontend/src/services/resultSet.ts index 461833ecd..3e13d6102 100644 --- a/data-agent-frontend/src/services/resultSet.ts +++ b/data-agent-frontend/src/services/resultSet.ts @@ -41,11 +41,3 @@ export interface PaginationConfig { pageSize: number; total: number; } - -/** - * 结果集显示配置 - */ -export interface ResultSetDisplayConfig { - showSqlResults: boolean; - pageSize: number; -} diff --git a/data-agent-frontend/src/services/semanticModel.ts b/data-agent-frontend/src/services/semanticModel.ts index a19699d4a..4776daaf3 100644 --- a/data-agent-frontend/src/services/semanticModel.ts +++ b/data-agent-frontend/src/services/semanticModel.ts @@ -35,6 +35,7 @@ interface SemanticModel { interface SemanticModelAddDto { agentId: number; + datasourceId: number; tableName: string; columnName: string; businessName: string; @@ -44,6 +45,14 @@ interface SemanticModelAddDto { dataType: string; } +interface SemanticModelUpdateDto { + businessName: string; + synonyms: string; + businessDescription: string; + columnComment: string; + dataType: string; +} + interface SemanticModelImportItem { tableName: string; columnName: string; @@ -56,6 +65,7 @@ interface SemanticModelImportItem { interface SemanticModelBatchImportDTO { agentId: number; + datasourceId: number; items: SemanticModelImportItem[]; } @@ -68,6 +78,22 @@ interface BatchImportResult { const API_BASE_URL = '/api/semantic-model'; +const extractApiErrorMessage = (error: unknown, fallback: string): string => { + if (axios.isAxiosError(error)) { + const responseMessage = error.response?.data?.message; + if (typeof responseMessage === 'string' && responseMessage.trim()) { + return responseMessage; + } + if (typeof error.message === 'string' && error.message.trim()) { + return error.message; + } + } + if (error instanceof Error && error.message.trim()) { + return error.message; + } + return fallback; +}; + class SemanticModelService { /** * 获取语义模型列表 @@ -104,8 +130,15 @@ class SemanticModelService { * @param model 语义模型 DTO 对象 */ async create(model: SemanticModelAddDto): Promise { - const response = await axios.post(API_BASE_URL, model); - return response.data.success; + try { + const response = await axios.post(API_BASE_URL, model); + if (response.data.success) { + return true; + } + throw new Error(response.data.message || '创建失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, '创建失败')); + } } /** @@ -113,15 +146,18 @@ class SemanticModelService { * @param id 语义模型 ID * @param model 语义模型对象 */ - async update(id: number, model: SemanticModel): Promise { + async update(id: number, model: SemanticModelUpdateDto): Promise { try { const response = await axios.put(`${API_BASE_URL}/${id}`, model); - return response.data.success; + if (response.data.success) { + return true; + } + throw new Error(response.data.message || '更新失败'); } catch (error) { if (axios.isAxiosError(error) && error.response?.status === 404) { return false; } - throw error; + throw new Error(extractApiErrorMessage(error, '更新失败')); } } @@ -175,11 +211,18 @@ class SemanticModelService { * @param dto 批量导入DTO */ async batchImport(dto: SemanticModelBatchImportDTO): Promise { - const response = await axios.post>( - `${API_BASE_URL}/batch-import`, - dto, - ); - return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + try { + const response = await axios.post>( + `${API_BASE_URL}/batch-import`, + dto, + ); + if (response.data.success) { + return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + } + throw new Error(response.data.message || '批量导入失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, '批量导入失败')); + } } /** @@ -187,21 +230,29 @@ class SemanticModelService { * @param file Excel文件 * @param agentId 智能体ID */ - async importExcel(file: File, agentId: number): Promise { + async importExcel(file: File, agentId: number, datasourceId: number): Promise { const formData = new FormData(); formData.append('file', file); formData.append('agentId', agentId.toString()); + formData.append('datasourceId', datasourceId.toString()); - const response = await axios.post>( - `${API_BASE_URL}/import/excel`, - formData, - { - headers: { - 'Content-Type': 'multipart/form-data', + try { + const response = await axios.post>( + `${API_BASE_URL}/import/excel`, + formData, + { + headers: { + 'Content-Type': 'multipart/form-data', + }, }, - }, - ); - return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + ); + if (response.data.success) { + return response.data.data || { total: 0, successCount: 0, failCount: 0, errors: [] }; + } + throw new Error(response.data.message || 'Excel导入失败'); + } catch (error) { + throw new Error(extractApiErrorMessage(error, 'Excel导入失败')); + } } /** @@ -228,6 +279,7 @@ export default new SemanticModelService(); export type { SemanticModel, SemanticModelAddDto, + SemanticModelUpdateDto, SemanticModelImportItem, SemanticModelBatchImportDTO, BatchImportResult, diff --git a/data-agent-frontend/src/services/sessionStateManager.ts b/data-agent-frontend/src/services/sessionStateManager.ts index 6accac334..9584a0d9d 100644 --- a/data-agent-frontend/src/services/sessionStateManager.ts +++ b/data-agent-frontend/src/services/sessionStateManager.ts @@ -15,17 +15,110 @@ */ import { ref, Ref } from 'vue'; -import { GraphNodeResponse, GraphRequest } from '@/services/graph.ts'; +import { AgentResponse, AgentRequest } from '@/services/graph.ts'; +import { AnswerTraceExplain } from '@/services/chat.ts'; + +export interface PendingClarifyState { + originalQuery: string; + riskLevel: string; + summary?: string; + missingDimensions: string[]; + followUpQuestions: string[]; + suggestedAssumptions: string[]; +} export interface SessionRuntimeState { isStreaming: boolean; - nodeBlocks: GraphNodeResponse[][]; + nodeBlocks: AgentResponse[][]; persistedBlockCount: number; closeStream: (() => void) | null; - lastRequest: GraphRequest | null; + lastRequest: AgentRequest | null; + pendingClarify: PendingClarifyState | null; htmlReportContent: string; htmlReportSize: number; markdownReportContent: string; + answerExplain: AnswerTraceExplain | null; + answerExplainVisible: boolean; +} + +// 可持久化的状态字段(不包括函数和临时状态) +interface PersistableState { + nodeBlocks: AgentResponse[][]; + persistedBlockCount: number; + lastRequest: AgentRequest | null; + pendingClarify: PendingClarifyState | null; + htmlReportContent: string; + htmlReportSize: number; + markdownReportContent: string; + answerExplain: AnswerTraceExplain | null; + answerExplainVisible: boolean; +} + +const STORAGE_KEY_PREFIX = 'session_state_'; +const MAX_STORAGE_SIZE_MB = 4; // 最大存储 4MB +const MAX_NODE_BLOCKS = 10; // 最多保存 10 个 nodeBlocks + +/** + * 从 sessionStorage 加载状态 + */ +function loadStateFromStorage(sessionId: string): Partial | null { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + const stored = sessionStorage.getItem(key); + if (stored) { + return JSON.parse(stored); + } + } catch (error) { + console.error('加载会话状态失败:', error); + } + return null; +} + +/** + * 保存状态到 sessionStorage(带大小限制) + */ +function saveStateToStorage(sessionId: string, state: SessionRuntimeState) { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + + // 只保存最近的 nodeBlocks,避免数据过大 + const persistable: PersistableState = { + nodeBlocks: state.nodeBlocks.slice(-MAX_NODE_BLOCKS), + persistedBlockCount: state.persistedBlockCount, + lastRequest: state.lastRequest, + pendingClarify: state.pendingClarify, + htmlReportContent: state.htmlReportContent, + htmlReportSize: state.htmlReportSize, + markdownReportContent: state.markdownReportContent, + answerExplain: state.answerExplain, + answerExplainVisible: state.answerExplainVisible, + }; + + const json = JSON.stringify(persistable); + const sizeInMB = new Blob([json]).size / (1024 * 1024); + + // 检查大小限制 + if (sizeInMB > MAX_STORAGE_SIZE_MB) { + console.warn(`会话状态过大 (${sizeInMB.toFixed(2)}MB),跳过保存`); + return; + } + + sessionStorage.setItem(key, json); + } catch (error) { + console.error('保存会话状态失败:', error); + } +} + +/** + * 从 sessionStorage 删除状态 + */ +function removeStateFromStorage(sessionId: string) { + try { + const key = STORAGE_KEY_PREFIX + sessionId; + sessionStorage.removeItem(key); + } catch (error) { + console.error('删除会话状态失败:', error); + } } /** @@ -40,15 +133,21 @@ export function useSessionStateManager() { */ const getSessionState = (sessionId: string): SessionRuntimeState => { if (!sessionStates.value.has(sessionId)) { + // 尝试从 sessionStorage 加载 + const stored = loadStateFromStorage(sessionId); + sessionStates.value.set(sessionId, { isStreaming: false, - nodeBlocks: [], - persistedBlockCount: 0, + nodeBlocks: stored?.nodeBlocks ?? [], + persistedBlockCount: stored?.persistedBlockCount ?? 0, closeStream: null, - lastRequest: null, - htmlReportContent: '', - htmlReportSize: 0, - markdownReportContent: '', + lastRequest: stored?.lastRequest ?? null, + pendingClarify: stored?.pendingClarify ?? null, + htmlReportContent: stored?.htmlReportContent ?? '', + htmlReportSize: stored?.htmlReportSize ?? 0, + markdownReportContent: stored?.markdownReportContent ?? '', + answerExplain: stored?.answerExplain ?? null, + answerExplainVisible: stored?.answerExplainVisible ?? false, }); } return sessionStates.value.get(sessionId)!; @@ -61,12 +160,24 @@ export function useSessionStateManager() { sessionId: string, viewState: { isStreaming: Ref; - nodeBlocks: Ref; + nodeBlocks: Ref; + answerExplain?: Ref; + answerExplainVisible?: Ref; + pendingClarify?: Ref; }, ) => { const state = getSessionState(sessionId); viewState.isStreaming.value = state.isStreaming; viewState.nodeBlocks.value = state.nodeBlocks; + if (viewState.answerExplain) { + viewState.answerExplain.value = state.answerExplain; + } + if (viewState.answerExplainVisible) { + viewState.answerExplainVisible.value = state.answerExplainVisible; + } + if (viewState.pendingClarify) { + viewState.pendingClarify.value = state.pendingClarify; + } }; /** @@ -76,12 +187,27 @@ export function useSessionStateManager() { sessionId: string, viewState: { isStreaming: Ref; - nodeBlocks: Ref; + nodeBlocks: Ref; + answerExplain?: Ref; + answerExplainVisible?: Ref; + pendingClarify?: Ref; }, ) => { const state = getSessionState(sessionId); state.isStreaming = viewState.isStreaming.value; state.nodeBlocks = viewState.nodeBlocks.value; + if (viewState.answerExplain) { + state.answerExplain = viewState.answerExplain.value; + } + if (viewState.answerExplainVisible) { + state.answerExplainVisible = viewState.answerExplainVisible.value; + } + if (viewState.pendingClarify) { + state.pendingClarify = viewState.pendingClarify.value; + } + + // 保存到 sessionStorage(带大小限制) + saveStateToStorage(sessionId, state); }; /** @@ -93,6 +219,8 @@ export function useSessionStateManager() { state.closeStream(); } sessionStates.value.delete(sessionId); + // 同时删除 sessionStorage 中的数据 + removeStateFromStorage(sessionId); }; /** diff --git a/data-agent-frontend/src/views/AgentRun.vue b/data-agent-frontend/src/views/AgentRun.vue index 02d66893d..04d7864d2 100644 --- a/data-agent-frontend/src/views/AgentRun.vue +++ b/data-agent-frontend/src/views/AgentRun.vue @@ -55,13 +55,15 @@ :class="message.messageType === 'text' ? ['message-container', message.role] : ''" > -
+
+
+
@@ -184,13 +186,6 @@ - - -
@@ -219,48 +214,18 @@ :onQuestionClick="handlePresetQuestionClick" />
-
- 人工反馈 - - - -
-
- 仅NL2SQL +
+ 开始澄清校验
-
- 自动Scroll - -
-
- 显示SQL结果 - - - -
每页数量 @@ -270,6 +235,16 @@
+
+ + 查看数据来源 + +
- - org.springframework.ai - spring-ai-transformers - true - - - - io.micrometer - micrometer-observation-test - test - - org.junit.jupiter junit-jupiter-api @@ -179,21 +154,6 @@ ${springdoc-openapi.version} - - org.springframework.boot - spring-boot-testcontainers - test - - - org.testcontainers - testcontainers - test - - - org.testcontainers - junit-jupiter - test - io.opentelemetry @@ -208,30 +168,11 @@ opentelemetry-exporter-otlp - io.opentelemetry - opentelemetry-sdk-extension-autoconfigure - - - - - org.testcontainers - mysql - test - - - - - io.projectreactor - reactor-test - test + io.opentelemetry.instrumentation + opentelemetry-reactor-3.1 + ${opentelemetry-instrumentation.version} - - - org.awaitility - awaitility - test - io.netty @@ -275,18 +216,6 @@ spring-ai-alibaba-dashscope ${spring-ai-alibaba.version} - - com.github.victools - jsonschema-generator - compile - - - - com.github.victools - jsonschema-module-jackson - compile - - org.springframework.boot spring-boot-starter-webflux @@ -298,11 +227,6 @@ lombok - - com.atlassian.commonmark - commonmark - - org.springframework.boot spring-boot-starter-jdbc @@ -330,10 +254,6 @@ ${elasticsearch-client.version} - - jakarta.validation - jakarta.validation-api - org.springframework.ai spring-ai-tika-document-reader diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java similarity index 94% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java index 4027c9ce6..ad2b2048d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/GraphRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/dto/AgentRequest.java @@ -24,7 +24,7 @@ @AllArgsConstructor @NoArgsConstructor @Builder -public class GraphRequest { +public class AgentRequest { private String agentId; @@ -34,12 +34,12 @@ public class GraphRequest { private String query; + private boolean clarifyCheckEnabled; + private boolean humanFeedback; private String humanFeedbackContent; private boolean rejectedPlan; - private boolean nl2sqlOnly; - } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java index 5d9fa2332..2d185b905 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeEventPublisher.java @@ -15,11 +15,11 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; @FunctionalInterface public interface AgentRuntimeEventPublisher { - void publish(GraphNodeResponse response); + void publish(AgentResponse response); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java index 51315442a..0c1810178 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeExtensionFactory.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRuntimeExtensions; import io.agentscope.core.hook.Hook; import io.agentscope.core.memory.Memory; @@ -42,13 +42,14 @@ public class AgentRuntimeExtensionFactory { private final AgentScopeSkillBoxFactory skillBoxFactory; - public AgentRuntimeExtensions create(GraphRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher, - Map toolCallbacks) { + public AgentRuntimeExtensions create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher, + Map toolCallbacks, PreparedMemory preparedMemory) { Toolkit toolkit = toolkitFactory.buildToolkit(toolCallbacks); SkillBox skillBox = skillBoxFactory.create(request.getAgentId(), toolkit); - Memory memory = memoryFactory.create(request.getThreadId()); + Memory memory = preparedMemory == null ? memoryFactory.create(request).memory() : preparedMemory.memory(); AgentRuntimeRequestMetadata requestMetadata = new AgentRuntimeRequestMetadata(request.getAgentId(), - request.getThreadId(), request.isNl2sqlOnly()); + request.getThreadId(), request.getRuntimeRequestId(), request.isHumanFeedback(), + request.getHumanFeedbackContent()); ToolExecutionContext toolExecutionContext = ToolExecutionContext.builder() .register(requestMetadata) .register("graphRequest", request) @@ -56,6 +57,7 @@ public AgentRuntimeExtensions create(GraphRequest request, @Nullable AgentRuntim List hooks = hookFactory.create(request, eventPublisher); Map attributes = new HashMap<>(); attributes.put("threadId", request.getThreadId()); + attributes.put("memoryLoadedFromNative", preparedMemory != null && preparedMemory.loadedFromNative()); return new AgentRuntimeExtensions(toolkit, memory, toolExecutionContext, hooks, attributes, skillBox, ""); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java index a28c4c8c8..326300b35 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentRuntimeRequestMetadata.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -public record AgentRuntimeRequestMetadata(String agentId, String threadId, boolean nl2sqlOnly) { +public record AgentRuntimeRequestMetadata(String agentId, String threadId, String runtimeRequestId, + boolean humanFeedback, String humanFeedbackContent) { } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java index 2e733b251..15ea436db 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeHookFactory.java @@ -15,10 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import io.agentscope.core.hook.Hook; import java.util.ArrayList; import java.util.List; @@ -30,20 +27,11 @@ @RequiredArgsConstructor public class AgentScopeHookFactory { - private final ChatSessionService chatSessionService; - - private final ChatMessageService chatMessageService; - - private final AgentSessionRegistry sessionRegistry; - - public List create(GraphRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher) { + public List create(AgentRequest request, @Nullable AgentRuntimeEventPublisher eventPublisher) { List hooks = new ArrayList<>(); if (eventPublisher != null) { - hooks.add(new AgentScopeStreamingHook(request.getAgentId(), request.getThreadId(), request.isNl2sqlOnly(), - eventPublisher)); + hooks.add(new AgentScopeStreamingHook(request.getAgentId(), request.getThreadId(), eventPublisher)); } - hooks.add(new AgentScopeMemoryPersistenceHook(request.getThreadId(), request.getRuntimeRequestId(), - sessionRegistry, chatSessionService, chatMessageService)); HumanFeedbackHook humanFeedbackHook = HumanFeedbackHook.from(request); if (humanFeedbackHook != null) { hooks.add(humanFeedbackHook); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java index 746849914..40fd939ea 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryFactory.java @@ -15,13 +15,9 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; import io.agentscope.core.memory.InMemoryMemory; -import io.agentscope.core.memory.Memory; -import io.agentscope.core.message.Msg; -import io.agentscope.core.message.MsgRole; -import java.util.List; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -32,46 +28,21 @@ @RequiredArgsConstructor public class AgentScopeMemoryFactory { - private static final int MEMORY_MESSAGE_LIMIT = 20; + private final AgentScopeNativeSessionService nativeSessionService; - private final ChatMessageService chatMessageService; - - public Memory create(String threadId) { + public PreparedMemory create(AgentRequest request) { InMemoryMemory memory = new InMemoryMemory(); + String threadId = request == null ? null : request.getThreadId(); if (!StringUtils.hasText(threadId)) { - return memory; - } - List history = chatMessageService.findRecentBySessionId(threadId, MEMORY_MESSAGE_LIMIT); - history.stream().map(this::toMessage).filter(msg -> msg != null).forEach(memory::addMessage); - log.debug("Loaded {} history messages into AgentScope memory, threadId={}", history.size(), threadId); - return memory; - } - - private Msg toMessage(ChatMessage message) { - if (message == null || !StringUtils.hasText(message.getContent())) { - return null; + return new PreparedMemory(memory, false); } - return Msg.builder() - .name(resolveName(message.getRole())) - .role(resolveRole(message.getRole())) - .textContent(message.getContent()) - .build(); - } - - private String resolveName(String role) { - return StringUtils.hasText(role) ? role : "user"; - } - - private MsgRole resolveRole(String role) { - if (!StringUtils.hasText(role)) { - return MsgRole.USER; + boolean loadedFromNative = nativeSessionService.loadMemoryIfExists(memory, threadId); + if (loadedFromNative) { + log.debug("Loaded AgentScope native session memory, threadId={}", threadId); + return new PreparedMemory(memory, true); } - return switch (role.trim().toLowerCase()) { - case "assistant" -> MsgRole.ASSISTANT; - case "system" -> MsgRole.SYSTEM; - case "tool" -> MsgRole.TOOL; - default -> MsgRole.USER; - }; + log.debug("No AgentScope native session memory found, threadId={}", threadId); + return new PreparedMemory(memory, false); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java deleted file mode 100644 index 4c238dc30..000000000 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeMemoryPersistenceHook.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.dataagent.agentscope.runtime; - -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; -import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; -import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; -import io.agentscope.core.hook.Hook; -import io.agentscope.core.hook.HookEvent; -import io.agentscope.core.hook.PostCallEvent; -import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -public class AgentScopeMemoryPersistenceHook implements Hook { - - private static final String MEMORY_MESSAGE_TYPE = "memory-text"; - - private static final String MEMORY_METADATA = "{\"source\":\"agentscope\",\"visibility\":\"memory-only\"}"; - - private final String threadId; - - private final String runtimeRequestId; - - private final AgentSessionRegistry sessionRegistry; - - private final ChatSessionService chatSessionService; - - private final ChatMessageService chatMessageService; - - public AgentScopeMemoryPersistenceHook(String threadId, String runtimeRequestId, - AgentSessionRegistry sessionRegistry, ChatSessionService chatSessionService, - ChatMessageService chatMessageService) { - this.threadId = threadId; - this.runtimeRequestId = runtimeRequestId; - this.sessionRegistry = sessionRegistry; - this.chatSessionService = chatSessionService; - this.chatMessageService = chatMessageService; - } - - @Override - public int priority() { - return 200; - } - - @Override - public Mono onEvent(T event) { - if (event instanceof PostCallEvent postCallEvent) { - persistFinalMessage(postCallEvent); - } - return Mono.just(event); - } - - private void persistFinalMessage(PostCallEvent event) { - if (!StringUtils.hasText(threadId) || chatSessionService.findBySessionId(threadId) == null) { - return; - } - if (sessionRegistry.isCancelled(threadId, runtimeRequestId)) { - return; - } - String finalText = event.getFinalMessage() == null ? null : event.getFinalMessage().getTextContent(); - if (!StringUtils.hasText(finalText)) { - return; - } - chatMessageService.saveMessage(ChatMessage.builder() - .sessionId(threadId) - .role("assistant") - .content(finalText) - .messageType(MEMORY_MESSAGE_TYPE) - .metadata(MEMORY_METADATA) - .build()); - } - -} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java index 123ff1438..d28212af2 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/AgentScopeStreamingHook.java @@ -15,7 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import com.alibaba.cloud.ai.dataagent.enums.TextType; import io.agentscope.core.hook.ActingChunkEvent; import io.agentscope.core.hook.Hook; @@ -34,28 +34,22 @@ public class AgentScopeStreamingHook implements Hook { private static final String PLANNER_REASONING_NODE = "planner-reasoning"; - private static final String SQL_GENERATOR_REASONING_NODE = "sql-generator-reasoning"; - private final String agentId; private final String threadId; - private final boolean nl2sqlOnly; - private final AgentRuntimeEventPublisher eventPublisher; - public AgentScopeStreamingHook(String agentId, String threadId, boolean nl2sqlOnly, - AgentRuntimeEventPublisher eventPublisher) { + public AgentScopeStreamingHook(String agentId, String threadId, AgentRuntimeEventPublisher eventPublisher) { this.agentId = agentId; this.threadId = threadId; - this.nl2sqlOnly = nl2sqlOnly; this.eventPublisher = eventPublisher; } @Override public Mono onEvent(T event) { if (event instanceof ReasoningChunkEvent reasoningChunkEvent) { - emit(resolveReasoningNodeName(), reasoningChunkEvent.getIncrementalChunk().getTextContent()); + emit(PLANNER_REASONING_NODE, reasoningChunkEvent.getIncrementalChunk().getTextContent()); } else if (event instanceof PreActingEvent preActingEvent) { emit(resolveToolNodeName(preActingEvent.getToolUse().getName()), @@ -76,7 +70,7 @@ private void emit(String nodeName, String text) { if (text == null || text.isBlank()) { return; } - eventPublisher.publish(GraphNodeResponse.builder() + eventPublisher.publish(AgentResponse.builder() .agentId(agentId) .threadId(threadId) .nodeName(nodeName) @@ -85,10 +79,6 @@ private void emit(String nodeName, String text) { .build()); } - private String resolveReasoningNodeName() { - return nl2sqlOnly ? SQL_GENERATOR_REASONING_NODE : PLANNER_REASONING_NODE; - } - private String resolveToolNodeName(String toolName) { return toolName == null || toolName.isBlank() ? "AgentScopeTool" : "tool:" + toolName; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java index c69396e15..c3d7eeb55 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/HumanFeedbackHook.java @@ -15,12 +15,14 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import io.agentscope.core.hook.Hook; import io.agentscope.core.hook.HookEvent; +import io.agentscope.core.hook.PreReasoningEvent; import io.agentscope.core.hook.PostReasoningEvent; import io.agentscope.core.message.Msg; import io.agentscope.core.message.MsgRole; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import org.springframework.util.StringUtils; @@ -40,10 +42,7 @@ private HumanFeedbackHook(boolean pauseAfterPlanning, boolean replayRequested, S this.feedbackDirective = feedbackDirective; } - public static HumanFeedbackHook from(GraphRequest request) { - if (request.isNl2sqlOnly()) { - return null; - } + public static HumanFeedbackHook from(AgentRequest request) { boolean hasFeedbackContent = StringUtils.hasText(request.getHumanFeedbackContent()); boolean requiresReplay = hasFeedbackContent || request.isRejectedPlan(); boolean requiresPause = request.isHumanFeedback() && !requiresReplay; @@ -55,27 +54,32 @@ public static HumanFeedbackHook from(GraphRequest request) { @Override public Mono onEvent(T event) { + if (event instanceof PreReasoningEvent preReasoningEvent && replayRequested.compareAndSet(true, false) + && StringUtils.hasText(feedbackDirective)) { + List messages = new ArrayList<>(preReasoningEvent.getInputMessages()); + messages.add(0, + Msg.builder().name("human-review").role(MsgRole.SYSTEM).textContent(feedbackDirective).build()); + preReasoningEvent.setInputMessages(messages); + return Mono.just(event); + } if (!(event instanceof PostReasoningEvent postReasoningEvent)) { return Mono.just(event); } if (pauseAfterPlanning) { postReasoningEvent.stopAgent(); - return Mono.just(event); - } - if (replayRequested.compareAndSet(true, false)) { - postReasoningEvent.gotoReasoning(List - .of(Msg.builder().name("human-review").role(MsgRole.SYSTEM).textContent(feedbackDirective).build())); } return Mono.just(event); } - private static String buildDirective(GraphRequest request) { + private static String buildDirective(AgentRequest request) { StringBuilder builder = new StringBuilder("Human review directive:"); if (request.isRejectedPlan()) { builder.append("\n- The previous plan was rejected. Re-plan before continuing."); } if (StringUtils.hasText(request.getHumanFeedbackContent())) { builder.append("\n- Feedback: ").append(request.getHumanFeedbackContent()); + builder.append( + "\n- Treat the feedback as authoritative clarification or explicit assumptions before any tool call."); } return builder.toString(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java new file mode 100644 index 000000000..ebf29b203 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/PreparedMemory.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.runtime; + +import io.agentscope.core.memory.InMemoryMemory; + +public record PreparedMemory(InMemoryMemory memory, boolean loadedFromNative) { + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/QueryClarifyService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/QueryClarifyService.java new file mode 100644 index 000000000..a01551fd1 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/QueryClarifyService.java @@ -0,0 +1,241 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.runtime; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.regex.Pattern; +import org.springframework.lang.Nullable; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Component +public class QueryClarifyService { + + private static final Pattern EXPLICIT_TIME_PATTERN = Pattern.compile( + "(今天|今日|昨天|昨日|本周|上周|本月|上月|本季度|上季度|今年|去年|最近\\s*\\d+\\s*(天|周|个月|月|年)|\\d{4}年(?:\\d{1,2}月(?:\\d{1,2}日)?)?|\\d{4}[-/]\\d{1,2}(?:[-/]\\d{1,2})?|Q[1-4]|第[一二三四1-4]季度|同期)"); + + private static final Pattern TIME_SENSITIVE_PATTERN = Pattern + .compile("(销量|销售额|成交额|收入|营收|GMV|gmv|订单|下单|支付|活跃|新增|留存|流失|复购|趋势|波动|变化|环比|同比|增长|下降)"); + + private static final Pattern METRIC_AMBIGUOUS_PATTERN = Pattern + .compile("(GMV|gmv|销售额|成交额|收入|营收|销量|订单量|转化率|留存|复购|活跃|新增|流失|客单价|毛利|利润|高价值用户)"); + + private static final Pattern METRIC_DEFINED_PATTERN = Pattern + .compile("(按.*口径|口径.*为|定义为|这里的.*指|仅统计|只统计|以.*为准|GMV=.*|销售额=.*|转化率=.*|留存=.*)"); + + private static final Pattern COMPARISON_PATTERN = Pattern.compile("(对比|比较|同比|环比|较|增长|下降|变化|波动)"); + + private static final Pattern COMPARISON_DEFINED_PATTERN = Pattern + .compile("(与[^,。;\\s]+(对比|比较)?|和[^,。;\\s]+(对比|比较)?|相比|较上周|较上月|较去年|较去年同期|较上期|同比|环比)"); + + private static final Pattern ORDERING_PATTERN = Pattern + .compile("(排序|排名|TOP\\s*\\d+|Top\\s*\\d+|top\\s*\\d+|前\\s*\\d+|后\\s*\\d+|最高|最低|最多|最少)"); + + private static final Pattern ORDERING_DEFINED_PATTERN = Pattern + .compile("(按[^,。;\\s]+(排序|排名)|基于[^,。;\\s]+(排序|排名)|按照[^,。;\\s]+(排序|排名)|按[^,。;\\s]+(升序|降序)|升序|降序|从高到低|从低到高)"); + + private static final Pattern ORDERING_METRIC_PATTERN = Pattern + .compile("(库存|价格|单价|销量|销售额|成交额|收入|营收|GMV|gmv|订单量|数量|下单量|支付金额)"); + + private static final Pattern STATIC_ANALYSIS_PATTERN = Pattern.compile("(库存|价格|单价|成本|邮箱|用户名|分类|状态值|字段|列|表结构)"); + + public QueryClarifyAssessment assess(@Nullable String query, @Nullable String humanFeedbackContent, + boolean clarifyCheckEnabled) { + String normalizedQuery = normalize(query); + String normalizedFeedback = normalize(humanFeedbackContent); + if (!clarifyCheckEnabled || !StringUtils.hasText(normalizedQuery)) { + return QueryClarifyAssessment.low(normalizedQuery, normalizedFeedback); + } + String evidenceText = normalizedQuery + + (StringUtils.hasText(normalizedFeedback) ? "\n" + normalizedFeedback : ""); + + List missingDimensions = new ArrayList<>(); + List followUpQuestions = new ArrayList<>(); + List suggestedAssumptions = new ArrayList<>(); + int score = 0; + + if (requiresTimeContext(normalizedQuery) && !hasExplicitTime(evidenceText)) { + missingDimensions.add("时间范围"); + followUpQuestions.add("你希望统计哪个时间范围?例如最近30天、本月、2025年6月。"); + suggestedAssumptions.add("按最近30天统计"); + score += 1; + } + if (requiresMetricDefinition(normalizedQuery) && !hasMetricDefinition(evidenceText)) { + missingDimensions.add("指标口径"); + followUpQuestions.add("你说的指标按什么口径计算?例如 GMV 是否只统计已完成订单、是否含退款。"); + suggestedAssumptions.add("按已完成订单口径统计,且不含退款"); + score += 2; + } + if (requiresComparisonTarget(normalizedQuery) && !hasComparisonTarget(evidenceText)) { + missingDimensions.add("对比对象"); + followUpQuestions.add("你希望和谁比较?例如与上周、上月、去年同期或某个具体对象对比。"); + suggestedAssumptions.add("与上周同期对比"); + score += 2; + } + if (requiresOrdering(normalizedQuery) && !hasOrderingBasis(evidenceText)) { + missingDimensions.add("排序依据"); + followUpQuestions.add("你希望按什么指标排序,以及升序还是降序?例如按销售额降序。"); + suggestedAssumptions.add("按销售额降序"); + score += 2; + } + + QueryClarifyRiskLevel riskLevel = score >= 3 ? QueryClarifyRiskLevel.HIGH + : score >= 1 ? QueryClarifyRiskLevel.MEDIUM : QueryClarifyRiskLevel.LOW; + boolean clarifyRequired = riskLevel == QueryClarifyRiskLevel.HIGH; + String summary = buildSummary(riskLevel, missingDimensions, normalizedFeedback); + String userMessage = clarifyRequired + ? buildClarifyMessage(riskLevel, followUpQuestions, suggestedAssumptions, normalizedFeedback) + : buildPassThroughMessage(riskLevel, missingDimensions, normalizedFeedback); + return new QueryClarifyAssessment(normalizedQuery, normalizedFeedback, riskLevel, clarifyRequired, + List.copyOf(missingDimensions), List.copyOf(followUpQuestions), List.copyOf(suggestedAssumptions), + summary, userMessage); + } + + private boolean requiresTimeContext(String query) { + if (!StringUtils.hasText(query)) { + return false; + } + if (STATIC_ANALYSIS_PATTERN.matcher(query).find() && !TIME_SENSITIVE_PATTERN.matcher(query).find()) { + return false; + } + return TIME_SENSITIVE_PATTERN.matcher(query).find() || COMPARISON_PATTERN.matcher(query).find() + || ORDERING_PATTERN.matcher(query).find(); + } + + private boolean requiresMetricDefinition(String query) { + return StringUtils.hasText(query) && METRIC_AMBIGUOUS_PATTERN.matcher(query).find(); + } + + private boolean requiresComparisonTarget(String query) { + return StringUtils.hasText(query) && COMPARISON_PATTERN.matcher(query).find(); + } + + private boolean requiresOrdering(String query) { + return StringUtils.hasText(query) && ORDERING_PATTERN.matcher(query).find(); + } + + private boolean hasExplicitTime(String text) { + return StringUtils.hasText(text) && EXPLICIT_TIME_PATTERN.matcher(text).find(); + } + + private boolean hasMetricDefinition(String text) { + return StringUtils.hasText(text) && METRIC_DEFINED_PATTERN.matcher(text).find(); + } + + private boolean hasComparisonTarget(String text) { + return StringUtils.hasText(text) && COMPARISON_DEFINED_PATTERN.matcher(text).find(); + } + + private boolean hasOrderingBasis(String text) { + if (!StringUtils.hasText(text)) { + return false; + } + if (ORDERING_DEFINED_PATTERN.matcher(text).find()) { + return true; + } + return ORDERING_METRIC_PATTERN.matcher(text).find() && ORDERING_PATTERN.matcher(text).find(); + } + + private String buildSummary(QueryClarifyRiskLevel riskLevel, List missingDimensions, + String feedbackContent) { + if (riskLevel == QueryClarifyRiskLevel.LOW) { + return StringUtils.hasText(feedbackContent) ? "已收到补充信息,当前歧义等级较低,可继续执行。" : "当前问题歧义等级较低,可直接继续执行。"; + } + String prefix = StringUtils.hasText(feedbackContent) ? "已收到补充信息,但仍存在歧义:" : "检测到高歧义问题:"; + return prefix + String.join("、", missingDimensions); + } + + private String buildClarifyMessage(QueryClarifyRiskLevel riskLevel, List followUpQuestions, + List suggestedAssumptions, String feedbackContent) { + StringBuilder builder = new StringBuilder(); + builder.append("为避免在口径不清时直接查库,当前问题需要先澄清。") + .append(System.lineSeparator()) + .append("riskLevel=") + .append(riskLevel.value()) + .append(System.lineSeparator()); + if (StringUtils.hasText(feedbackContent)) { + builder.append("已收到你的补充,但还不足以消除关键歧义。").append(System.lineSeparator()); + } + builder.append("请先补充以下信息:").append(System.lineSeparator()); + for (int i = 0; i < followUpQuestions.size(); i++) { + builder.append(i + 1).append(". ").append(followUpQuestions.get(i)).append(System.lineSeparator()); + } + if (!suggestedAssumptions.isEmpty()) { + builder.append("如果你接受默认假设,也可以直接回复:").append(System.lineSeparator()); + builder.append("按以下假设继续:").append(String.join(";", suggestedAssumptions)); + } + return builder.toString().trim(); + } + + private String buildPassThroughMessage(QueryClarifyRiskLevel riskLevel, List missingDimensions, + String feedbackContent) { + if (riskLevel == QueryClarifyRiskLevel.LOW) { + return StringUtils.hasText(feedbackContent) ? "已收到补充说明,继续按更新后的上下文执行。" : "当前问题歧义较低,可继续执行。"; + } + return "当前问题存在轻度歧义(" + String.join("、", missingDimensions) + "),继续执行前会优先保留你的显式口径。"; + } + + private String normalize(@Nullable String text) { + if (!StringUtils.hasText(text)) { + return ""; + } + return text.trim().replace('\u3000', ' '); + } + + public record QueryClarifyAssessment(String query, String feedbackContent, QueryClarifyRiskLevel riskLevel, + boolean clarifyRequired, List missingDimensions, List followUpQuestions, + List suggestedAssumptions, String summary, String userMessage) { + + private static QueryClarifyAssessment low(String query, String feedbackContent) { + return new QueryClarifyAssessment(query, feedbackContent, QueryClarifyRiskLevel.LOW, false, List.of(), + List.of(), List.of(), "当前问题歧义等级较低,可直接继续执行。", "当前问题歧义较低,可继续执行。"); + } + + public boolean shouldBlockExecution() { + return clarifyRequired && riskLevel == QueryClarifyRiskLevel.HIGH; + } + + public Map toMetadata() { + Map metadata = new LinkedHashMap<>(); + metadata.put("clarifyRequired", clarifyRequired); + metadata.put("riskLevel", riskLevel.value()); + metadata.put("missingDimensions", missingDimensions); + metadata.put("followUpQuestions", followUpQuestions); + metadata.put("suggestedAssumptions", suggestedAssumptions); + metadata.put("summary", summary); + return metadata; + } + } + + public enum QueryClarifyRiskLevel { + + LOW, + + MEDIUM, + + HIGH; + + public String value() { + return name().toLowerCase(Locale.ROOT); + } + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java index e8573b431..0cd03c32c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/SpringToolCallbackAgentAdapter.java @@ -15,12 +15,14 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.runtime; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.agentscope.core.message.TextBlock; import io.agentscope.core.message.ToolResultBlock; import io.agentscope.core.tool.AgentTool; import io.agentscope.core.tool.ToolCallParam; +import io.agentscope.core.tool.ToolExecutionContext; import java.util.LinkedHashMap; import java.util.Map; import lombok.RequiredArgsConstructor; @@ -76,7 +78,7 @@ private ToolResultBlock invoke(ToolCallParam toolCallParam) throws Exception { } catch (Exception ex) { log.error("Spring AI tool callback execution failed. tool={}", getName(), ex); - return ToolResultBlock.error(ex.getMessage() == null ? "Tool execution failed." : ex.getMessage()) + return ToolResultBlock.error(ex.getMessage() == null ? "工具执行失败。" : ex.getMessage()) .withIdAndName(toolCallParam.getToolUseBlock().getId(), getName()); } } @@ -85,6 +87,15 @@ private ToolContext toToolContext(ToolCallParam toolCallParam) { Map contextMap = new LinkedHashMap<>(); if (toolCallParam.getContext() != null) { contextMap.put("agentScopeContext", toolCallParam.getContext()); + ToolExecutionContext toolExecutionContext = toolCallParam.getContext(); + AgentRequest agentRequest = toolExecutionContext.get("agentRequest", AgentRequest.class); + if (agentRequest != null) { + contextMap.put("graphRequest", agentRequest); + } + AgentRuntimeRequestMetadata requestMetadata = toolExecutionContext.get(AgentRuntimeRequestMetadata.class); + if (requestMetadata != null) { + contextMap.put("runtimeRequestMetadata", requestMetadata); + } } if (toolCallParam.getAgent() != null) { contextMap.put("agentScopeAgent", toolCallParam.getAgent()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java new file mode 100644 index 000000000..68901eaf6 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/runtime/ToolContextRequestResolver.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.runtime; + +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import io.agentscope.core.tool.ToolExecutionContext; +import java.util.Map; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +public final class ToolContextRequestResolver { + + private ToolContextRequestResolver() { + } + + @Nullable + public static AgentRequest resolveGraphRequest(@Nullable ToolContext toolContext) { + if (toolContext == null || toolContext.getContext() == null) { + return null; + } + Map context = toolContext.getContext(); + Object graphRequest = context.get("graphRequest"); + if (graphRequest instanceof AgentRequest request) { + return request; + } + Object agentScopeContext = context.get("agentScopeContext"); + if (agentScopeContext instanceof ToolExecutionContext toolExecutionContext) { + AgentRequest request = toolExecutionContext.get("graphRequest", AgentRequest.class); + if (request != null) { + return request; + } + AgentRequest metadataRequest = fromMetadata(toolExecutionContext.get(AgentRuntimeRequestMetadata.class)); + if (metadataRequest != null) { + return metadataRequest; + } + } + Object runtimeRequestMetadata = context.get("runtimeRequestMetadata"); + if (runtimeRequestMetadata instanceof AgentRuntimeRequestMetadata metadata) { + return fromMetadata(metadata); + } + return null; + } + + @Nullable + private static AgentRequest fromMetadata(@Nullable AgentRuntimeRequestMetadata metadata) { + if (metadata == null || !StringUtils.hasText(metadata.threadId()) + || !StringUtils.hasText(metadata.runtimeRequestId())) { + return null; + } + return AgentRequest.builder() + .agentId(metadata.agentId()) + .threadId(metadata.threadId()) + .runtimeRequestId(metadata.runtimeRequestId()) + .humanFeedback(metadata.humanFeedback()) + .humanFeedbackContent(metadata.humanFeedbackContent()) + .build(); + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java index 764d3b3c8..eb810e367 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/AgentService.java @@ -15,16 +15,14 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.service; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import org.springframework.http.codec.ServerSentEvent; import reactor.core.publisher.Sinks; public interface AgentService { - String nl2sql(String naturalQuery, String agentId); - - void graphStreamProcess(Sinks.Many> sink, GraphRequest graphRequest); + void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest); void stopStreamProcessing(String threadId, String runtimeRequestId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java index 43f7e707c..7ab510c12 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/service/impl/AiAgentRuntimeServiceImpl.java @@ -15,34 +15,55 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.service.impl; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentScopeMemoryFactory; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeEventPublisher; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentRuntimeExtensionFactory; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.PreparedMemory; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.QueryClarifyService.QueryClarifyAssessment; import com.alibaba.cloud.ai.dataagent.agentscope.runtime.AgentScopeToolkitFactory; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentScopeModelFactory; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentService; -import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentSessionRegistry; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentRuntimeRegistry; +import com.alibaba.cloud.ai.dataagent.agentscope.session.AgentScopeNativeSessionService; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRunContext; import com.alibaba.cloud.ai.dataagent.agentscope.template.AgentRuntimeExtensions; import com.alibaba.cloud.ai.dataagent.agentscope.template.ManagedAgent; import com.alibaba.cloud.ai.dataagent.agentscope.template.ManagedAgentRegistry; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import com.alibaba.cloud.ai.dataagent.constant.AgentRuntimeConstant; import com.alibaba.cloud.ai.dataagent.enums.ModelType; import com.alibaba.cloud.ai.dataagent.enums.TextType; import com.alibaba.cloud.ai.dataagent.dto.ModelConfigDTO; import com.alibaba.cloud.ai.dataagent.entity.Agent; +import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; +import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import com.alibaba.cloud.ai.dataagent.service.aimodelconfig.DynamicModelFactory; import com.alibaba.cloud.ai.dataagent.service.aimodelconfig.ModelConfigDataService; +import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; +import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.agentscope.core.memory.InMemoryMemory; import io.agentscope.core.message.Msg; +import io.agentscope.core.message.MsgRole; import io.agentscope.core.model.Model; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.concurrent.CancellationException; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.tool.ToolCallback; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.http.codec.ServerSentEvent; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; @@ -61,13 +82,17 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private static final String RUNTIME_NODE_NAME = "AgentScopeRuntime"; + private static final String ANSWER_EXPLAIN_MESSAGE_TYPE = "answer-explain"; + private static final String STREAM_EVENT_MESSAGE = "message"; + private static final String ROOT_SPAN_NAME = "data-agent.agent.run"; + private static final String AGENT_STATUS_PUBLISHED = "published"; private static final String AGENT_STATUS_OFFLINE = "offline"; - private final AgentSessionRegistry sessionRegistry; + private final AgentRuntimeRegistry runtimeRegistry; private final ModelConfigDataService modelConfigDataService; @@ -81,31 +106,34 @@ public class AiAgentRuntimeServiceImpl implements AgentService { private final AgentRuntimeExtensionFactory agentRuntimeExtensionFactory; + private final AgentScopeMemoryFactory agentScopeMemoryFactory; + private final com.alibaba.cloud.ai.dataagent.service.agent.AgentService agentService; - @Override - public String nl2sql(String naturalQuery, String agentId) { - log.info("NL2SQL runtime invoked for agentId={}", agentId); - GraphRequest request = GraphRequest.builder().agentId(agentId).query(naturalQuery).nl2sqlOnly(true).build(); - initializeRuntimeRequest(request); - sessionRegistry.register(request.getThreadId(), request.getRuntimeRequestId()); - try { - return executeAgent(request); - } - finally { - sessionRegistry.finish(request.getThreadId(), request.getRuntimeRequestId()); - } - } + @Qualifier("agentScopeTracer") + private final Tracer tracer; + + private final AnswerTraceExplainStore answerTraceExplainStore; + + private final ChatSessionService chatSessionService; + + private final ChatMessageService chatMessageService; + + private final ObjectMapper objectMapper; + + private final QueryClarifyService queryClarifyService; + + private final AgentScopeNativeSessionService nativeSessionService; @Override - public void graphStreamProcess(Sinks.Many> sink, GraphRequest graphRequest) { - initializeRuntimeRequest(graphRequest); - String threadId = graphRequest.getThreadId(); - String runtimeRequestId = graphRequest.getRuntimeRequestId(); + public void graphStreamProcess(Sinks.Many> sink, AgentRequest agentRequest) { + initializeRuntimeRequest(agentRequest); + String threadId = agentRequest.getThreadId(); + String runtimeRequestId = agentRequest.getRuntimeRequestId(); StreamTextTracker streamTextTracker = new StreamTextTracker(); - sessionRegistry.register(threadId, runtimeRequestId); + runtimeRegistry.register(threadId, runtimeRequestId); AgentRuntimeEventPublisher eventPublisher = response -> { - if (!sessionRegistry.isActive(threadId, runtimeRequestId)) { + if (!runtimeRegistry.isActive(threadId, runtimeRequestId)) { return; } if (response != null && response.getTextType() == TextType.TEXT @@ -115,27 +143,27 @@ public void graphStreamProcess(Sinks.Many> si sink.tryEmitNext(ServerSentEvent.builder(response).event(STREAM_EVENT_MESSAGE).build()); }; - Mono.fromCallable(() -> executeAgent(graphRequest, eventPublisher)) - .doFinally(signalType -> sessionRegistry.finish(threadId, runtimeRequestId)) + Mono.fromCallable(() -> executeAgent(agentRequest, eventPublisher)) + .doFinally(signalType -> runtimeRegistry.finish(threadId, runtimeRequestId)) .subscribeOn(Schedulers.boundedElastic()) - .subscribe(result -> emitSuccess(sink, graphRequest, result, streamTextTracker), - error -> emitError(sink, graphRequest, error)); + .subscribe(result -> emitSuccess(sink, agentRequest, result, streamTextTracker), + error -> emitError(sink, agentRequest, error)); } @Override public void stopStreamProcessing(String threadId, String runtimeRequestId) { - sessionRegistry.markCancelled(threadId, runtimeRequestId); + runtimeRegistry.markCancelled(threadId, runtimeRequestId); } - private void emitSuccess(Sinks.Many> sink, GraphRequest request, String result, + private void emitSuccess(Sinks.Many> sink, AgentRequest request, String result, StreamTextTracker streamTextTracker) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); - if (!sessionRegistry.isActive(threadId, runtimeRequestId)) { + if (!runtimeRegistry.isActive(threadId, runtimeRequestId)) { return; } if (shouldEmitFinalResponse(result, streamTextTracker)) { - GraphNodeResponse response = GraphNodeResponse.builder() + AgentResponse response = AgentResponse.builder() .agentId(request.getAgentId()) .threadId(threadId) .nodeName(RUNTIME_NODE_NAME) @@ -144,97 +172,221 @@ private void emitSuccess(Sinks.Many> sink, Gr .build(); sink.tryEmitNext(ServerSentEvent.builder(response).event(STREAM_EVENT_MESSAGE).build()); } - sink.tryEmitNext(ServerSentEvent.builder(GraphNodeResponse.complete(request.getAgentId(), threadId)) + sink.tryEmitNext(ServerSentEvent.builder(AgentResponse.complete(request.getAgentId(), threadId)) .event(STREAM_EVENT_COMPLETE) .build()); sink.tryEmitComplete(); } private boolean shouldEmitFinalResponse(String result, StreamTextTracker streamTextTracker) { - return StringUtils.hasText(result) && !streamTextTracker.matchesAnyNodeAccumulation(result); + return StringUtils.hasText(result) && !streamTextTracker.containsFinalAnswer(result); } - private void emitError(Sinks.Many> sink, GraphRequest request, Throwable error) { + private void emitError(Sinks.Many> sink, AgentRequest request, Throwable error) { String threadId = request.getThreadId(); String runtimeRequestId = request.getRuntimeRequestId(); - if (sessionRegistry.isCancelled(threadId, runtimeRequestId)) { + if (runtimeRegistry.isCancelled(threadId, runtimeRequestId)) { log.info("AgentScope runtime cancelled, suppress error propagation. threadId={}, runtimeRequestId={}", threadId, runtimeRequestId); return; } log.error("AgentScope runtime failed, threadId={}", threadId, error); - if (sessionRegistry.isActive(threadId, runtimeRequestId)) { - String message = error.getMessage() == null ? "AgentScope runtime failed." : error.getMessage(); - sink.tryEmitNext(ServerSentEvent.builder(GraphNodeResponse.error(request.getAgentId(), threadId, message)) + if (runtimeRegistry.isActive(threadId, runtimeRequestId)) { + String message = error.getMessage() == null ? "AgentScope 运行失败。" : error.getMessage(); + sink.tryEmitNext(ServerSentEvent.builder(AgentResponse.error(request.getAgentId(), threadId, message)) .event(STREAM_EVENT_ERROR) .build()); sink.tryEmitComplete(); } } - private String executeAgent(GraphRequest request) { + private String executeAgent(AgentRequest request) { return executeAgent(request, null); } - private void initializeRuntimeRequest(GraphRequest request) { + private void initializeRuntimeRequest(AgentRequest request) { if (!StringUtils.hasText(request.getThreadId())) { - request.setThreadId(UUID.randomUUID().toString()); + throw new IllegalArgumentException("threadId must not be empty"); } if (!StringUtils.hasText(request.getRuntimeRequestId())) { request.setRuntimeRequestId(UUID.randomUUID().toString()); } } - private String executeAgent(GraphRequest request, AgentRuntimeEventPublisher eventPublisher) { - sessionRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); + private String executeAgent(AgentRequest request, AgentRuntimeEventPublisher eventPublisher) { + runtimeRegistry.markRunning(request.getThreadId(), request.getRuntimeRequestId(), Thread.currentThread()); + answerTraceExplainStore.openScope(request); + Span rootSpan = startRuntimeSpan(request); + PreparedMemory preparedMemory = null; try { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { - return ""; - } - Agent managedAgentConfig = resolveManagedAgent(request.getAgentId()); - ModelConfigDTO modelConfig = modelConfigDataService.getActiveConfigByType(ModelType.CHAT); - validateModelConfig(modelConfig); - Map toolCallbacks = agentScopeToolkitFactory.getToolCallbacks(request.getAgentId()); - Model model = agentScopeModelFactory.create(dynamicModelFactory.createChatModel(modelConfig), - modelConfig.getModelName(), toolCallbacks); - ManagedAgent managedAgent = managedAgentRegistry.getRequired(); - AgentRuntimeExtensions runtimeExtensions = agentRuntimeExtensionFactory.create(request, eventPublisher, - toolCallbacks); - Msg response; - try { - response = managedAgent.run(new AgentRunContext(request.getAgentId(), request.getThreadId(), model, - resolveManagedSystemPrompt(managedAgentConfig, request.getAgentId()), buildUserPrompt(request), - AgentRuntimeConstant.AGENT_CALL_TIMEOUT, runtimeExtensions)); - } - catch (RuntimeException ex) { - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) - && isInterruptedCancellation(ex)) { - Thread.interrupted(); - log.info("Agent execution interrupted by cancellation, threadId={}, runtimeRequestId={}", - request.getThreadId(), request.getRuntimeRequestId()); + try (Scope ignored = rootSpan.makeCurrent()) { + preparedMemory = agentScopeMemoryFactory.create(request); + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + rootSpan.setStatus(StatusCode.OK, "cancelled"); return ""; } - throw ex; - } - if (sessionRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { - return ""; + boolean clarifyCheckEnabled = request.isClarifyCheckEnabled() + || StringUtils.hasText(request.getHumanFeedbackContent()); + QueryClarifyAssessment clarifyAssessment = queryClarifyService.assess(request.getQuery(), + request.getHumanFeedbackContent(), clarifyCheckEnabled); + answerTraceExplainStore.recordClarifyAssessment(request, clarifyAssessment); + rootSpan.setAttribute("dataagent.query_clarify.enabled", clarifyCheckEnabled); + rootSpan.setAttribute("dataagent.query_clarify.risk_level", clarifyAssessment.riskLevel().value()); + rootSpan.setAttribute("dataagent.query_clarify.blocked", clarifyAssessment.shouldBlockExecution()); + if (clarifyAssessment.shouldBlockExecution()) { + return blockForClarification(request, eventPublisher, rootSpan, clarifyAssessment, preparedMemory); + } + Agent managedAgentConfig = resolveManagedAgent(request.getAgentId()); + ModelConfigDTO modelConfig = modelConfigDataService.getActiveConfigByType(ModelType.CHAT); + validateModelConfig(modelConfig); + Map toolCallbacks = agentScopeToolkitFactory + .getToolCallbacks(request.getAgentId()); + Model model = agentScopeModelFactory.create(dynamicModelFactory.createChatModel(modelConfig), + modelConfig.getModelName(), toolCallbacks); + ManagedAgent managedAgent = managedAgentRegistry.getRequired(); + AgentRuntimeExtensions runtimeExtensions = agentRuntimeExtensionFactory.create(request, eventPublisher, + toolCallbacks, preparedMemory); + Msg response; + try { + response = managedAgent.run(new AgentRunContext(request.getAgentId(), request.getThreadId(), model, + resolveManagedSystemPrompt(managedAgentConfig, request.getAgentId()), + buildUserPrompt(request), AgentRuntimeConstant.AGENT_CALL_TIMEOUT, runtimeExtensions)); + } + catch (RuntimeException ex) { + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId()) + && isInterruptedCancellation(ex)) { + Thread.interrupted(); + rootSpan.setStatus(StatusCode.OK, "cancelled"); + log.info("Agent execution interrupted by cancellation, threadId={}, runtimeRequestId={}", + request.getThreadId(), request.getRuntimeRequestId()); + return ""; + } + throw ex; + } + if (runtimeRegistry.isCancelled(request.getThreadId(), request.getRuntimeRequestId())) { + rootSpan.setStatus(StatusCode.OK, "cancelled"); + return ""; + } + persistNativeMemorySafely(request, runtimeExtensions.memory()); + rootSpan.setStatus(StatusCode.OK); + String answer = extractText(response); + answerTraceExplainStore.recordFinalAnswer(answer); + persistAnswerExplainSnapshot(request); + mirrorExplainSummary(rootSpan, request); + return answer; } - return extractText(response); + } + catch (RuntimeException ex) { + recordRuntimeFailure(rootSpan, ex); + throw ex; } finally { - sessionRegistry.clearRunning(request.getThreadId(), request.getRuntimeRequestId()); + rootSpan.end(); + runtimeRegistry.clearRunning(request.getThreadId(), request.getRuntimeRequestId()); + answerTraceExplainStore.closeScope(); } } + private Span startRuntimeSpan(AgentRequest request) { + Span span = tracer.spanBuilder(ROOT_SPAN_NAME).startSpan(); + span.setAttribute(SessionTraceStore.ATTR_THREAD_ID, request.getThreadId()); + span.setAttribute(SessionTraceStore.ATTR_RUNTIME_REQUEST_ID, request.getRuntimeRequestId()); + span.setAttribute(SessionTraceStore.ATTR_AGENT_ID, request.getAgentId() == null ? "" : request.getAgentId()); + span.setAttribute("dataagent.runtime.human_feedback", request.isHumanFeedback()); + return span; + } + + private void recordRuntimeFailure(Span rootSpan, Throwable throwable) { + if (rootSpan == null) { + return; + } + rootSpan.setStatus(StatusCode.ERROR, + throwable.getMessage() == null ? "runtime failed" : throwable.getMessage()); + rootSpan.recordException(throwable); + } + + private String blockForClarification(AgentRequest request, AgentRuntimeEventPublisher eventPublisher, Span rootSpan, + QueryClarifyAssessment clarifyAssessment, PreparedMemory preparedMemory) { + String clarifyText = clarifyAssessment.userMessage(); + appendClarifyTurn(preparedMemory, request, clarifyText); + persistNativeMemorySafely(request, preparedMemory == null ? null : preparedMemory.memory()); + answerTraceExplainStore.recordFinalAnswer(clarifyText); + persistAnswerExplainSnapshot(request); + mirrorExplainSummary(rootSpan, request); + rootSpan.setStatus(StatusCode.OK, "clarify required"); + if (eventPublisher != null) { + Map metadata = new LinkedHashMap<>(clarifyAssessment.toMetadata()); + metadata.put("originalQuery", request.getQuery()); + eventPublisher.publish(AgentResponse.builder() + .agentId(request.getAgentId()) + .threadId(request.getThreadId()) + .nodeName(RUNTIME_NODE_NAME) + .textType(TextType.TEXT) + .text(clarifyText) + .metadata(metadata) + .build()); + } + return clarifyText; + } + + private void mirrorExplainSummary(Span rootSpan, AgentRequest request) { + if (rootSpan == null || request == null) { + return; + } + answerTraceExplainStore.getMirrorSummary(request.getThreadId(), request.getRuntimeRequestId()) + .ifPresent(summary -> { + rootSpan.setAttribute("dataagent.answer.explain.available", true); + rootSpan.setAttribute("dataagent.answer.explain.tool_step_count", summary.getToolStepCount()); + rootSpan.setAttribute("dataagent.answer.explain.semantic_hit_count", summary.getSemanticHitCount()); + rootSpan.setAttribute("dataagent.answer.explain.knowledge_hit_count", summary.getKnowledgeHitCount()); + if (StringUtils.hasText(summary.getDatasource())) { + rootSpan.setAttribute("dataagent.answer.explain.datasource", summary.getDatasource()); + } + }); + } + + private void persistAnswerExplainSnapshot(AgentRequest request) { + if (request == null || !StringUtils.hasText(request.getThreadId()) + || !StringUtils.hasText(request.getRuntimeRequestId())) { + return; + } + if (chatSessionService.findBySessionId(request.getThreadId()) == null) { + return; + } + answerTraceExplainStore.getExplain(request.getThreadId(), request.getRuntimeRequestId()).ifPresent(explain -> { + try { + chatMessageService.saveMessage(ChatMessage.builder() + .sessionId(request.getThreadId()) + .role("system") + .content(objectMapper.writeValueAsString(explain)) + .messageType(ANSWER_EXPLAIN_MESSAGE_TYPE) + .metadata(buildAnswerExplainMetadata(request)) + .build()); + } + catch (Exception ex) { + log.warn("Failed to persist answer explain snapshot. sessionId={}, runtimeRequestId={}", + request.getThreadId(), request.getRuntimeRequestId(), ex); + } + }); + } + + private String buildAnswerExplainMetadata(AgentRequest request) throws Exception { + Map metadata = new LinkedHashMap<>(); + metadata.put("kind", "answer-explain"); + metadata.put("runtimeRequestId", request.getRuntimeRequestId()); + metadata.put("visibility", "system-hidden"); + return objectMapper.writeValueAsString(metadata); + } + private void validateModelConfig(ModelConfigDTO modelConfig) { if (modelConfig == null) { - throw new IllegalStateException("No active CHAT model configured. Please configure it in the dashboard."); + throw new IllegalStateException("当前未配置可用的 CHAT 模型,请先在控制台完成配置。"); } if (!StringUtils.hasText(modelConfig.getApiKey())) { - throw new IllegalStateException("Active CHAT model apiKey is empty."); + throw new IllegalStateException("当前活动 CHAT 模型的 apiKey 为空。"); } if (!StringUtils.hasText(modelConfig.getModelName())) { - throw new IllegalStateException("Active CHAT model modelName is empty."); + throw new IllegalStateException("当前活动 CHAT 模型的 modelName 为空。"); } } @@ -257,10 +409,106 @@ private String resolveManagedSystemPrompt(Agent agent, String requestAgentId) { return agent.getPrompt(); } - private String buildUserPrompt(GraphRequest request) { + private String buildUserPrompt(AgentRequest request) { return request.getQuery() == null ? "" : request.getQuery(); } + private void persistNativeMemorySafely(AgentRequest request, io.agentscope.core.memory.Memory memory) { + if (!(memory instanceof InMemoryMemory inMemoryMemory) || request == null + || !StringUtils.hasText(request.getThreadId())) { + return; + } + normalizeMemoryForPersistence(request, inMemoryMemory); + try { + nativeSessionService.saveMemory(inMemoryMemory, request.getThreadId()); + } + catch (RuntimeException ex) { + log.warn("Failed to persist AgentScope native memory. threadId={}, runtimeRequestId={}", + request.getThreadId(), request.getRuntimeRequestId(), ex); + } + } + + private void appendClarifyTurn(PreparedMemory preparedMemory, AgentRequest request, String clarifyText) { + if (preparedMemory == null || request == null) { + return; + } + InMemoryMemory memory = preparedMemory.memory(); + if (memory == null) { + return; + } + String userInput = resolvePersistedUserInput(request); + if (StringUtils.hasText(userInput)) { + memory.addMessage(Msg.builder().name("user").role(MsgRole.USER).textContent(userInput).build()); + } + if (StringUtils.hasText(clarifyText)) { + memory.addMessage(Msg.builder().name("assistant").role(MsgRole.ASSISTANT).textContent(clarifyText).build()); + } + } + + private void normalizeMemoryForPersistence(AgentRequest request, InMemoryMemory memory) { + if (request == null || memory == null || !StringUtils.hasText(request.getHumanFeedbackContent()) + || !StringUtils.hasText(request.getQuery())) { + return; + } + List originalMessages = memory.getMessages(); + if (originalMessages == null || originalMessages.isEmpty()) { + return; + } + List normalizedMessages = new ArrayList<>(originalMessages); + int replacementIndex = findReplayQueryIndex(normalizedMessages, request.getQuery()); + if (replacementIndex < 0) { + return; + } + Msg replayMessage = normalizedMessages.get(replacementIndex); + normalizedMessages.set(replacementIndex, + Msg.builder() + .name(resolveMsgName(replayMessage, "user")) + .role(MsgRole.USER) + .textContent(request.getHumanFeedbackContent()) + .build()); + if (Objects.equals(originalMessages, normalizedMessages)) { + return; + } + memory.clear(); + normalizedMessages.forEach(memory::addMessage); + } + + private int findReplayQueryIndex(List messages, String originalQuery) { + int firstMatchIndex = -1; + for (int i = 0; i < messages.size(); i++) { + Msg message = messages.get(i); + if (message == null || message.getRole() != MsgRole.USER) { + continue; + } + if (!Objects.equals(message.getTextContent(), originalQuery)) { + continue; + } + if (firstMatchIndex < 0) { + firstMatchIndex = i; + continue; + } + return i; + } + return -1; + } + + private String resolveMsgName(Msg message, String fallback) { + if (message != null && StringUtils.hasText(message.getName())) { + return message.getName(); + } + return fallback; + } + + private String resolvePersistedUserInput(AgentRequest request) { + if (request == null) { + return null; + } + if (StringUtils.hasText(request.getHumanFeedbackContent())) { + return request.getHumanFeedbackContent(); + } + return request.getQuery(); + } + private void validateAgentStatus(Agent agent, String requestAgentId) { if (agent == null) { return; @@ -314,25 +562,59 @@ private static final class StreamTextTracker { private final Map accumulatedByNode = new LinkedHashMap<>(); + private final Map lastTextByNode = new LinkedHashMap<>(); + synchronized void record(String nodeName, String text) { if (!StringUtils.hasText(text)) { return; } - accumulatedByNode.computeIfAbsent(nodeName, key -> new StringBuilder()).append(text); + String normalizedNodeName = StringUtils.hasText(nodeName) ? nodeName : ""; + accumulatedByNode.computeIfAbsent(normalizedNodeName, key -> new StringBuilder()).append(text); + lastTextByNode.put(normalizedNodeName, text); } - synchronized boolean matchesAnyNodeAccumulation(String candidate) { + synchronized boolean containsFinalAnswer(String candidate) { if (!StringUtils.hasText(candidate)) { return false; } + String normalizedCandidate = normalize(candidate); + if (!StringUtils.hasText(normalizedCandidate)) { + return false; + } for (StringBuilder accumulated : accumulatedByNode.values()) { - if (candidate.equals(accumulated.toString())) { + String normalizedAccumulated = normalize(accumulated.toString()); + if (matchesFinalAnswer(normalizedAccumulated, normalizedCandidate)) { + return true; + } + } + for (String lastText : lastTextByNode.values()) { + String normalizedLastText = normalize(lastText); + if (matchesFinalAnswer(normalizedLastText, normalizedCandidate)) { return true; } } return false; } + private boolean matchesFinalAnswer(String existingText, String candidate) { + if (!StringUtils.hasText(existingText) || !StringUtils.hasText(candidate)) { + return false; + } + return existingText.equals(candidate) || existingText.endsWith(candidate) + || candidate.endsWith(existingText); + } + + private String normalize(String text) { + if (!StringUtils.hasText(text)) { + return ""; + } + return text.replace("\r\n", "\n") + .replace('\r', '\n') + .replaceAll("[ \\t\\x0B\\f]+", " ") + .replaceAll(" *\\n *", "\n") + .trim(); + } + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java similarity index 96% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java index 3b551852c..3b472bed4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentSessionRegistry.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentRuntimeRegistry.java @@ -21,7 +21,7 @@ import org.springframework.stereotype.Component; @Component -public class AgentSessionRegistry { +public class AgentRuntimeRegistry { private final ConcurrentHashMap> requestStatesByThreadId = new ConcurrentHashMap<>(); @@ -76,7 +76,7 @@ public void finish(String threadId, String runtimeRequestId) { private RequestExecutionState getOrCreateState(String threadId, String runtimeRequestId) { if (threadId == null || threadId.isBlank() || runtimeRequestId == null || runtimeRequestId.isBlank()) { - throw new IllegalArgumentException("threadId and runtimeRequestId must not be blank"); + throw new IllegalArgumentException("threadId 和 runtimeRequestId 不能为空"); } ConcurrentHashMap states = requestStatesByThreadId.computeIfAbsent(threadId, key -> new ConcurrentHashMap<>()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java new file mode 100644 index 000000000..3ec32090c --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeMysqlSession.java @@ -0,0 +1,167 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.session; + +import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; +import com.alibaba.cloud.ai.dataagent.mapper.ChatMessageMapper; +import com.alibaba.cloud.ai.dataagent.mapper.ChatSessionMapper; +import io.agentscope.core.session.Session; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.State; +import io.agentscope.core.state.SessionKey; +import io.agentscope.core.util.JsonCodec; +import io.agentscope.core.util.JsonUtils; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +@Slf4j +@Component +@RequiredArgsConstructor +public class AgentScopeMysqlSession implements Session { + + static final String STATE_MESSAGE_TYPE_PREFIX = "agentscope-state:"; + + private static final String STATE_KIND = "agentscope-session-state"; + + private static final String STATE_VISIBILITY = "system-hidden"; + + private static final String STATE_ROLE = "system"; + + private final ChatSessionMapper chatSessionMapper; + + private final ChatMessageMapper chatMessageMapper; + + private final JsonCodec jsonCodec = JsonUtils.getJsonCodec(); + + @Override + public void save(SessionKey sessionKey, String moduleName, State state) { + String sessionId = sessionId(sessionKey); + assertManagedChatSession(sessionId); + String messageType = messageType(moduleName); + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId, messageType); + chatMessageMapper.insert(buildStateMessage(sessionId, messageType, moduleName, state.getClass().getName(), 0, + jsonCodec.toPrettyJson(state))); + } + + @Override + public void save(SessionKey sessionKey, String moduleName, List states) { + String sessionId = sessionId(sessionKey); + assertManagedChatSession(sessionId); + String messageType = messageType(moduleName); + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId, messageType); + if (states == null || states.isEmpty()) { + return; + } + for (int index = 0; index < states.size(); index++) { + State state = states.get(index); + if (state == null) { + continue; + } + chatMessageMapper.insert(buildStateMessage(sessionId, messageType, moduleName, state.getClass().getName(), + index, jsonCodec.toJson(state))); + } + } + + @Override + public Optional get(SessionKey sessionKey, String moduleName, Class stateClass) { + List rows = chatMessageMapper.selectStateBySessionIdAndMessageType(sessionId(sessionKey), + messageType(moduleName)); + if (rows.isEmpty()) { + return Optional.empty(); + } + ChatMessage latestRow = rows.get(rows.size() - 1); + return Optional.ofNullable(jsonCodec.fromJson(latestRow.getContent(), stateClass)); + } + + @Override + public List getList(SessionKey sessionKey, String moduleName, Class stateClass) { + return chatMessageMapper.selectStateBySessionIdAndMessageType(sessionId(sessionKey), messageType(moduleName)) + .stream() + .map(ChatMessage::getContent) + .map(content -> jsonCodec.fromJson(content, stateClass)) + .collect(Collectors.toList()); + } + + @Override + public boolean exists(SessionKey sessionKey) { + return chatMessageMapper.countAgentScopeStateBySessionId(sessionId(sessionKey)) > 0; + } + + @Override + public void delete(SessionKey sessionKey) { + chatMessageMapper.deleteAgentScopeStateBySessionId(sessionId(sessionKey)); + } + + @Override + public void delete(SessionKey sessionKey, String moduleName) { + chatMessageMapper.deleteBySessionIdAndMessageType(sessionId(sessionKey), messageType(moduleName)); + } + + @Override + public Set listSessionKeys() { + return chatMessageMapper.selectSessionIdsWithAgentScopeState() + .stream() + .map(SimpleSessionKey::of) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + public Mono clearAllSessions() { + return Mono.fromCallable(chatMessageMapper::deleteAllAgentScopeStateMessages) + .subscribeOn(Schedulers.boundedElastic()); + } + + private void assertManagedChatSession(String sessionId) { + if (chatSessionMapper.selectBySessionId(sessionId) == null) { + throw new IllegalStateException( + "AgentScope session persistence requires an existing chat_session. sessionId=" + sessionId); + } + } + + private String sessionId(SessionKey sessionKey) { + return sessionKey.toIdentifier(); + } + + private String messageType(String moduleName) { + String encodedModuleName = java.util.Base64.getUrlEncoder() + .withoutPadding() + .encodeToString(moduleName.getBytes(StandardCharsets.UTF_8)); + return STATE_MESSAGE_TYPE_PREFIX + encodedModuleName; + } + + private ChatMessage buildStateMessage(String sessionId, String messageType, String moduleName, String stateClass, + int sequence, String content) { + Map metadata = Map.of("kind", STATE_KIND, "visibility", STATE_VISIBILITY, "moduleName", + moduleName, "stateClass", stateClass, "sequence", sequence); + return ChatMessage.builder() + .sessionId(sessionId) + .role(STATE_ROLE) + .content(content) + .messageType(messageType) + .metadata(jsonCodec.toJson(metadata)) + .build(); + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java new file mode 100644 index 000000000..64ad96225 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/session/AgentScopeNativeSessionService.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.session; + +import io.agentscope.core.memory.Memory; +import io.agentscope.core.session.Session; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.SessionKey; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.Objects; +import java.util.regex.Pattern; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Slf4j +@Component +public class AgentScopeNativeSessionService { + + private static final Pattern SESSION_ID_PATTERN = Pattern.compile("[A-Za-z0-9_-]+"); + + private final Session session; + + public AgentScopeNativeSessionService(AgentScopeMysqlSession session) { + this.session = session; + } + + public boolean loadMemoryIfExists(Memory memory, String sessionId) { + Objects.requireNonNull(memory, "memory must not be null"); + if (!StringUtils.hasText(sessionId)) { + return false; + } + try { + return memory.loadIfExists(session, sessionKey(sessionId)); + } + catch (RuntimeException ex) { + log.warn("Failed to load AgentScope native session state from MySQL. sessionId={}", sessionId, ex); + return false; + } + } + + public void saveMemory(Memory memory, String sessionId) { + Objects.requireNonNull(memory, "memory must not be null"); + if (!StringUtils.hasText(sessionId)) { + return; + } + memory.saveTo(session, sessionKey(sessionId)); + } + + public void deleteSessionState(String sessionId) { + if (!StringUtils.hasText(sessionId)) { + return; + } + try { + SessionKey sessionKey = sessionKey(sessionId); + if (session.exists(sessionKey)) { + session.delete(sessionKey); + } + } + catch (RuntimeException ex) { + throw new IllegalStateException("Failed to delete AgentScope session state: " + sessionId, ex); + } + } + + public void deleteSessionStates(Collection sessionIds) { + if (sessionIds == null || sessionIds.isEmpty()) { + return; + } + for (String sessionId : new LinkedHashSet<>(sessionIds)) { + deleteSessionState(sessionId); + } + } + + private SessionKey sessionKey(String sessionId) { + return SimpleSessionKey.of(normalizeSessionId(sessionId)); + } + + private String normalizeSessionId(String sessionId) { + if (!StringUtils.hasText(sessionId)) { + throw new IllegalArgumentException("sessionId 不能为空"); + } + String normalized = sessionId.trim(); + if (!SESSION_ID_PATTERN.matcher(normalized).matches()) { + throw new IllegalArgumentException("sessionId 包含非法字符: " + sessionId); + } + return normalized; + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java index 7f5faec29..ae39536d7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/CommonAgent.java @@ -21,7 +21,6 @@ import io.agentscope.core.hook.Hook; import io.agentscope.core.message.Msg; import io.agentscope.core.message.MsgRole; -import io.agentscope.core.model.ExecutionConfig; import java.time.Duration; import java.util.List; import java.util.Objects; @@ -35,18 +34,6 @@ public class CommonAgent implements ManagedAgent { private static final String SYSTEM_PROMPT = PromptLoader.loadPrompt(AGENT_TYPE, "md"); - private static final int DEFAULT_MAX_ITERS = 10; - - private static final ExecutionConfig DEFAULT_MODEL_EXECUTION_CONFIG = ExecutionConfig.builder() - .timeout(Duration.ofMinutes(2)) - .maxAttempts(2) - .build(); - - private static final ExecutionConfig DEFAULT_TOOL_EXECUTION_CONFIG = ExecutionConfig.builder() - .timeout(Duration.ofSeconds(30)) - .maxAttempts(1) - .build(); - @Override public String getAgentType() { return AGENT_TYPE; @@ -61,9 +48,9 @@ public Msg run(AgentRunContext context) { .name(AGENT_TYPE) .sysPrompt(defaultSystemPrompt(context.systemPrompt(), extensions.skillInstructions())) .model(context.model()) - .maxIters(DEFAULT_MAX_ITERS) - .modelExecutionConfig(DEFAULT_MODEL_EXECUTION_CONFIG) - .toolExecutionConfig(DEFAULT_TOOL_EXECUTION_CONFIG); + .maxIters(AgentRuntimeConstant.REACT_AGENT_DEFAULT_MAX_ITERS) + .modelExecutionConfig(AgentRuntimeConstant.DEFAULT_MODEL_EXECUTION_CONFIG) + .toolExecutionConfig(AgentRuntimeConstant.DEFAULT_TOOL_EXECUTION_CONFIG); if (extensions.toolkit() != null) { builder.toolkit(extensions.toolkit()); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java index de142e16c..8c705a0f9 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/template/ManagedAgentRegistry.java @@ -35,7 +35,7 @@ public ManagedAgentRegistry(List managedAgents) { public ManagedAgent getRequired() { ManagedAgent managedAgent = this.agentsByType.get(normalize(CommonAgent.AGENT_TYPE)); if (managedAgent == null) { - throw new IllegalStateException("ManagedAgent registry missing type: " + CommonAgent.AGENT_TYPE); + throw new IllegalStateException("ManagedAgent 注册表缺少类型:" + CommonAgent.AGENT_TYPE); } return managedAgent; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java similarity index 50% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java index 13823979a..de59baf91 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/aop/ExceptionAdvice.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolError.java @@ -13,22 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.cloud.ai.dataagent.aop; - -import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.ExceptionHandler; -import org.springframework.web.bind.annotation.RestControllerAdvice; - -@Slf4j -@RestControllerAdvice -public class ExceptionAdvice { - - @ExceptionHandler(Exception.class) - public ResponseEntity handleException(Exception e) { - log.error("An error occurred: ", e); - return ResponseEntity.internalServerError().body(ApiResponse.error("An error occurred: " + e.getMessage())); +package com.alibaba.cloud.ai.dataagent.agentscope.tool; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Builder; +import lombok.Value; + +@Value +@Builder +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolError { + + ToolErrorCode code; + + String message; + + Boolean retryable; + + public static ToolError of(ToolErrorCode code, String message) { + return ToolError.builder().code(code).message(message).build(); + } + + public static ToolError retryable(ToolErrorCode code, String message) { + return ToolError.builder().code(code).message(message).retryable(true).build(); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java similarity index 58% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java index b5e3258fb..5e732b4f5 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/EchoController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/ToolErrorCode.java @@ -13,26 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.cloud.ai.dataagent.controller; +package com.alibaba.cloud.ai.dataagent.agentscope.tool; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +public enum ToolErrorCode { -/** - * @author yingzi - * @since 2025/9/16 - */ -@RestController -@RequestMapping("/echo") -public class EchoController { - - /** - * 心跳检测 - */ - @GetMapping("ok") - public String ok() { - return "ok"; - } + INVALID_INPUT, + + UNSUPPORTED_ACTION, + + DATASOURCE_UNAVAILABLE, + + TABLE_NOT_VISIBLE, + + COLUMN_NOT_VISIBLE, + + EXECUTION_FAILED } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java index 18c473a9d..f338c1c3b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerAction.java @@ -40,7 +40,7 @@ public static DatasourceExplorerAction fromValue(String value) { return Arrays.stream(values()) .filter(action -> action.name().equalsIgnoreCase(value)) .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Unsupported datasource explorer action: " + value)); + .orElseThrow(() -> new IllegalArgumentException("不支持的数据源探索动作:" + value)); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java index 65ca23a79..eac1b7e4d 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerRequest.java @@ -15,7 +15,6 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.datasource; -import java.util.List; import lombok.Data; @Data @@ -27,8 +26,6 @@ public class DatasourceExplorerRequest { private String tableName; - private List tableNames; - private String sql; private Integer limit; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java index 69864f931..0f83c29ed 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerResult.java @@ -31,6 +31,8 @@ public class DatasourceExplorerResult { private String summary; + private Boolean searchReady; + @Builder.Default private List> tables = new ArrayList<>(); @@ -46,8 +48,22 @@ public class DatasourceExplorerResult { private String sql; @Builder.Default - private List nextSuggestedActions = new ArrayList<>(); + private List usedTables = new ArrayList<>(); + + @Builder.Default + private List usedColumns = new ArrayList<>(); + + @Builder.Default + private List> relationEvidence = new ArrayList<>(); + + @Builder.Default + private List toolDecisionReasons = new ArrayList<>(); + + @Builder.Default + private List resultScopeDetails = new ArrayList<>(); + + private String resultScope; - private boolean truncated; + private String decisionReason; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java index b207f0dac..370344f9c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerService.java @@ -17,13 +17,16 @@ import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ForeignKeyInfoBO; import com.alibaba.cloud.ai.dataagent.bo.schema.ResultSetBO; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; import com.alibaba.cloud.ai.dataagent.connector.accessor.AccessorFactory; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.Datasource; import com.alibaba.cloud.ai.dataagent.entity.LogicalRelation; +import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; import com.alibaba.cloud.ai.dataagent.service.schema.SchemaService; @@ -39,16 +42,36 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.regex.Pattern; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.AllColumns; +import net.sf.jsqlparser.statement.select.AllTableColumns; +import net.sf.jsqlparser.statement.select.FromItem; +import net.sf.jsqlparser.statement.select.Join; +import net.sf.jsqlparser.statement.select.LateralSubSelect; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.ParenthesedFromItem; +import net.sf.jsqlparser.statement.select.ParenthesedSelect; +import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; -import net.sf.jsqlparser.util.TablesNamesFinder; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.statement.select.TableFunction; +import net.sf.jsqlparser.statement.select.WithItem; import org.apache.commons.lang3.StringUtils; import org.springframework.ai.document.Document; +import org.springframework.lang.Nullable; import org.springframework.stereotype.Service; @Service @@ -66,9 +89,17 @@ public class DatasourceExplorerService { private static final Pattern FETCH_FIRST_PATTERN = Pattern .compile("(?i)\\bfetch\\s+first\\s+\\d+\\s+rows\\s+only\\b"); + private static final Pattern WHERE_PATTERN = Pattern.compile("(?i)\\bwhere\\b"); + + private static final Pattern GROUP_BY_PATTERN = Pattern.compile("(?i)\\bgroup\\s+by\\b"); + + private static final Pattern ORDER_BY_PATTERN = Pattern.compile("(?i)\\border\\s+by\\b"); + private static final TypeReference> STRING_LIST_TYPE = new TypeReference<>() { }; + private static final String HIDDEN_FIELD_INFERENCE_WARNING = " 请严格基于返回字段作答,不要根据邮箱前缀、ID、编码、别名等可见值推断任何未返回的隐藏字段。"; + private final AgentDatasourceService agentDatasourceService; private final DatasourceService datasourceService; @@ -79,140 +110,163 @@ public class DatasourceExplorerService { private final ObjectMapper objectMapper; + private final AnswerTraceExplainStore answerTraceExplainStore; + public DatasourceExplorerResult execute(String agentId, DatasourceExplorerRequest request) throws Exception { + return execute(agentId, request, null); + } + + public DatasourceExplorerResult execute(String agentId, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) throws Exception { if (request == null || request.getAction() == null) { - throw new IllegalArgumentException("Datasource explorer request.action 不能为空"); + throw new IllegalArgumentException("数据源探索请求必须提供 action"); } ExplorerContext context = resolveContext(agentId); return switch (request.getAction()) { - case LIST_TABLES -> listTables(context, request); - case FIND_TABLES -> findTables(context, request); - case GET_TABLE_SCHEMA -> getTableSchema(context, request); - case GET_RELATED_TABLES -> getRelatedTables(context, request); - case PREVIEW_ROWS -> previewRows(context, request); - case SEARCH -> search(context, request); + case LIST_TABLES -> listTables(context, request, graphRequest); + case FIND_TABLES -> findTables(context, request, graphRequest); + case GET_TABLE_SCHEMA -> getTableSchema(context, request, graphRequest); + case GET_RELATED_TABLES -> getRelatedTables(context, request, graphRequest); + case PREVIEW_ROWS -> previewRows(context, request, graphRequest); + case SEARCH -> search(context, request, graphRequest); }; } - private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceExplorerRequest request) { + private DatasourceExplorerResult listTables(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) { int limit = normalizeLimit(request.getLimit()); Map tableDocumentMap = loadTableDocumentMap(context, context.visibleTables()); List> tables = context.visibleTables() .stream() .sorted(String.CASE_INSENSITIVE_ORDER) - .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables())) + .map(tableName -> toTableEntry(context, tableName, tableDocumentMap.get(normalizeTableName(tableName)), + filterRelations(context, tableName))) .limit(limit) .toList(); - return baseResult(context, DatasourceExplorerAction.LIST_TABLES, tables.size() + " tables available") + return capture(baseResult(context, DatasourceExplorerAction.LIST_TABLES, "共发现 %d 张可见表".formatted(tables.size())) .tables(tables) - .nextSuggestedActions(List.of("get_table_schema", "preview_rows", "find_tables")) - .truncated(context.visibleTables().size() > limit) - .build(); + .searchReady(!tables.isEmpty()) + .build(), graphRequest); } - private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceExplorerRequest request) { + private DatasourceExplorerResult findTables(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) { int limit = normalizeLimit(request.getLimit()); String query = StringUtils.trimToEmpty(request.getQuery()).toLowerCase(Locale.ROOT); Map tableDocumentMap = loadTableDocumentMap(context, context.visibleTables()); List> matchedTables = context.visibleTables() .stream() - .map(tableName -> toTableEntry(tableName, tableDocumentMap.get(normalizeTableName(tableName)), - context.explicitSelectedTables())) + .map(tableName -> toTableEntry(context, tableName, tableDocumentMap.get(normalizeTableName(tableName)), + filterRelations(context, tableName))) .filter(table -> query.isEmpty() || containsQuery(table, query)) .limit(limit) .toList(); - String summary = query.isEmpty() ? "Returned visible tables without query filter" - : "Matched %d tables for query '%s'".formatted(matchedTables.size(), request.getQuery()); - return baseResult(context, DatasourceExplorerAction.FIND_TABLES, summary).tables(matchedTables) - .nextSuggestedActions(List.of("get_table_schema", "preview_rows")) - .truncated(matchedTables.size() >= limit) - .build(); - } - - private DatasourceExplorerResult getTableSchema(ExplorerContext context, DatasourceExplorerRequest request) - throws Exception { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); + String summary = query.isEmpty() ? "未提供筛选词,返回当前可见表列表" + : "针对关键词“%s”匹配到 %d 张表".formatted(request.getQuery(), matchedTables.size()); + return capture(baseResult(context, DatasourceExplorerAction.FIND_TABLES, summary).tables(matchedTables) + .searchReady(!matchedTables.isEmpty()) + .build(), graphRequest); + } + + private DatasourceExplorerResult getTableSchema(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) throws Exception { + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); List columns = context.accessor() .showColumns(context.dbConfig(), DbQueryParameter.from(context.dbConfig()) .setSchema(context.dbConfig().getSchema()) .setTable(tableName)); Document tableDocument = loadTableDocumentMap(context, List.of(tableName)).get(normalizeTableName(tableName)); - List> columnEntries = columns.stream().map(this::toColumnEntry).toList(); - List> relationEntries = filterRelations(context.logicalRelations(), tableName).stream() - .map(this::toRelationEntry) + Map columnDocumentMap = loadColumnDocumentMap(context, tableName); + List> columnEntries = applyVisibleColumnFilter(context, tableName, columns).stream() + .map(column -> toColumnEntry(column, columnDocumentMap.get(normalizeColumnName(column.getName())))) .toList(); - Map tableEntry = toTableEntry(tableName, tableDocument, context.explicitSelectedTables()); - return baseResult(context, DatasourceExplorerAction.GET_TABLE_SCHEMA, - "Loaded schema for table '%s'".formatted(tableName)) - .tables(List.of(tableEntry)) - .columns(columnEntries) - .relations(relationEntries) - .nextSuggestedActions(List.of("preview_rows", "search", "get_related_tables")) - .build(); + List relations = filterRelations(context, tableName); + List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); + Map tableEntry = toTableEntry(context, tableName, tableDocument, relations); + return capture( + baseResult(context, DatasourceExplorerAction.GET_TABLE_SCHEMA, "已加载表“%s”的结构信息".formatted(tableName)) + .tables(List.of(tableEntry)) + .columns(columnEntries) + .relations(relationEntries) + .searchReady(true) + .build(), + graphRequest); } - private DatasourceExplorerResult getRelatedTables(ExplorerContext context, DatasourceExplorerRequest request) { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); - List relations = filterRelations(context.logicalRelations(), tableName); + private DatasourceExplorerResult getRelatedTables(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) { + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); + List relations = filterRelations(context, tableName); List> relationEntries = relations.stream().map(this::toRelationEntry).toList(); Set relatedTables = relations.stream() - .flatMap(relation -> Arrays - .stream(new String[] { relation.getSourceTableName(), relation.getTargetTableName() })) + .flatMap(relation -> Arrays.stream(new String[] { relation.sourceTable(), relation.targetTable() })) .filter(candidate -> !normalizeTableName(candidate).equals(normalizeTableName(tableName))) .filter(candidate -> context.visibleTableNameSet().contains(normalizeTableName(candidate))) .collect(Collectors.toCollection(LinkedHashSet::new)); Map tableDocumentMap = loadTableDocumentMap(context, new ArrayList<>(relatedTables)); List> tableEntries = relatedTables.stream() - .map(relatedTable -> toTableEntry(relatedTable, tableDocumentMap.get(normalizeTableName(relatedTable)), - context.explicitSelectedTables())) + .map(relatedTable -> toTableEntry(context, relatedTable, + tableDocumentMap.get(normalizeTableName(relatedTable)), filterRelations(context, relatedTable))) .toList(); - return baseResult(context, DatasourceExplorerAction.GET_RELATED_TABLES, - "Found %d related tables for '%s'".formatted(tableEntries.size(), tableName)) + return capture(baseResult(context, DatasourceExplorerAction.GET_RELATED_TABLES, + "表“%s”共找到 %d 张关联表".formatted(tableName, tableEntries.size())) .tables(tableEntries) .relations(relationEntries) - .nextSuggestedActions(List.of("get_table_schema", "preview_rows")) - .build(); + .searchReady(!relationEntries.isEmpty()) + .build(), graphRequest); } - private DatasourceExplorerResult previewRows(ExplorerContext context, DatasourceExplorerRequest request) - throws Exception { - String tableName = requireSingleTableName(request); - assertVisibleTable(context, tableName); + private DatasourceExplorerResult previewRows(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) throws Exception { + String tableName = resolveVisibleTableName(context, requireSingleTableName(request)); int limit = normalizeLimit(request.getLimit()); - String sql = SqlUtil.buildSelectSql(context.dbConfig().getDialectType(), tableName, "*", limit); + String sql = SqlUtil.buildSelectSql(context.dbConfig().getDialectType(), + SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), tableName), + resolvePreviewColumnSelection(context, tableName), limit); ResultSetBO resultSet = executeSql(context, sql); - return baseResult(context, DatasourceExplorerAction.PREVIEW_ROWS, - "Previewed %d rows from '%s'".formatted(resultSet.getData().size(), tableName)) + return capture(baseResult(context, DatasourceExplorerAction.PREVIEW_ROWS, + ("已预览表“%s”的 %d 行数据".formatted(tableName, resultSet.getData().size())) + HIDDEN_FIELD_INFERENCE_WARNING) .tables(List.of(Map.of("name", tableName))) .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) .sql(sql) - .nextSuggestedActions(List.of("get_table_schema", "search")) - .truncated(resultSet.getData().size() >= limit) - .build(); + .usedTables(List.of(tableName)) + .usedColumns(toExplainColumns(resultSet)) + .resultScope("预览结果仅包含当前可见字段,并受 limit 限制。") + .resultScopeDetails(buildPreviewResultScopeDetails(tableName, resultSet, limit)) + .decisionReason("当前通过 PREVIEW_ROWS 直接预览单表数据,用于快速确认表内容。") + .toolDecisionReasons(buildPreviewDecisionReasons(tableName, limit)) + .searchReady(true) + .build(), graphRequest); } - private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplorerRequest request) - throws Exception { + private DatasourceExplorerResult search(ExplorerContext context, DatasourceExplorerRequest request, + @Nullable AgentRequest graphRequest) throws Exception { String rawSql = StringUtils.trimToNull(request.getSql()); if (rawSql == null) { throw new IllegalArgumentException("search action 必须提供 sql"); } int limit = normalizeLimit(request.getLimit()); - String guardedSql = guardReadonlySql(context, rawSql, limit); - ResultSetBO resultSet = executeSql(context, guardedSql); - return baseResult(context, DatasourceExplorerAction.SEARCH, - "Executed readonly search and returned %d rows".formatted(resultSet.getData().size())) + SqlGuardedQuery guardedQuery = guardReadonlySql(context, rawSql, limit); + ResultSetBO resultSet = filterResultSet(executeSql(context, guardedQuery.sql()), guardedQuery); + List> relationEvidence = collectRelationEvidence(context, guardedQuery.referencedTables()); + List usedTables = toExplainTables(guardedQuery.referencedTables(), context.visibleTablesByName()); + List usedColumns = toExplainColumns(resultSet); + return capture(baseResult(context, DatasourceExplorerAction.SEARCH, + ("已执行只读查询,返回 %d 行结果".formatted(resultSet.getData().size())) + HIDDEN_FIELD_INFERENCE_WARNING) .columns(toColumnHeaders(resultSet)) .rows(toRows(resultSet)) - .sql(guardedSql) - .nextSuggestedActions(List.of("get_table_schema", "preview_rows", "find_tables")) - .truncated(resultSet.getData().size() >= limit) - .build(); + .sql(guardedQuery.sql()) + .usedTables(usedTables) + .usedColumns(usedColumns) + .relationEvidence(relationEvidence) + .resultScope(buildResultScopeSummary(resultSet, limit)) + .resultScopeDetails(buildSearchResultScopeDetails(resultSet, limit, guardedQuery.allowedResultHeaders())) + .decisionReason(buildDecisionReason(guardedQuery.referencedTables(), relationEvidence)) + .toolDecisionReasons(buildSearchDecisionReasons(usedTables, relationEvidence, limit)) + .searchReady(true) + .build(), graphRequest); } private ExplorerContext resolveContext(String agentId) throws Exception { @@ -221,7 +275,7 @@ private ExplorerContext resolveContext(String agentId) throws Exception { Datasource datasource = agentDatasource.getDatasource() != null ? agentDatasource.getDatasource() : datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); if (datasource == null) { - throw new IllegalStateException("Active datasource not found for agent " + agentId); + throw new IllegalStateException("当前 Agent 未找到活动数据源:" + agentId); } DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); Accessor accessor = accessorFactory.getAccessorByDbConfig(dbConfig); @@ -229,18 +283,47 @@ private ExplorerContext resolveContext(String agentId) throws Exception { : agentDatasource.getSelectTables(); List visibleTables = explicitSelectedTables.isEmpty() ? datasourceService.getDatasourceTables(datasource.getId()) : explicitSelectedTables; + Map> visibleTablesByName = indexTablesByIdentity(visibleTables); + Map> visibleTablesByLeafName = indexTablesByLeafName(visibleTables); Set visibleTableNameSet = visibleTables.stream() .map(this::normalizeTableName) .collect(Collectors.toCollection(LinkedHashSet::new)); + Map> visibleColumnsByTable = buildVisibleColumnsByTable(agentDatasource, + visibleTablesByName, visibleTablesByLeafName); + Map> visibleColumnNameSetByTable = visibleColumnsByTable.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, + entry -> entry.getValue() + .stream() + .map(this::normalizeColumnName) + .collect(Collectors.toCollection(LinkedHashSet::new)), + (left, right) -> left, LinkedHashMap::new)); List logicalRelations = datasourceService.getLogicalRelations(datasource.getId()); + List physicalRelations = loadPhysicalRelations(accessor, dbConfig, visibleTables); + List unifiedRelations = buildUnifiedRelations(visibleTablesByName, visibleTablesByLeafName, + physicalRelations, logicalRelations == null ? List.of() : logicalRelations); return new ExplorerContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), - Set.copyOf(visibleTableNameSet), List.copyOf(explicitSelectedTables), - logicalRelations == null ? List.of() : List.copyOf(logicalRelations)); + Set.copyOf(visibleTableNameSet), toImmutableListIndex(visibleTablesByName), + toImmutableListIndex(visibleTablesByLeafName), List.copyOf(explicitSelectedTables), + Map.copyOf(visibleColumnsByTable), toImmutableSetIndex(visibleColumnNameSetByTable), + Set.copyOf(visibleColumnsByTable.keySet()), List.copyOf(unifiedRelations), + indexRelationsByTable(unifiedRelations)); + } + + private List loadPhysicalRelations(Accessor accessor, DbConfigBO dbConfig, List tables) { + try { + List foreignKeys = accessor.showForeignKeys(dbConfig, + DbQueryParameter.from(dbConfig).setSchema(dbConfig.getSchema()).setTables(tables)); + return foreignKeys == null ? List.of() : foreignKeys; + } + catch (Exception ex) { + return List.of(); + } } private Long parseAgentId(String agentId) { if (!StringUtils.isNumeric(agentId)) { - throw new IllegalArgumentException("Datasource explorer 当前仅支持数值型 agentId"); + throw new IllegalArgumentException("数据源探索当前仅支持数值型 agentId"); } return Long.valueOf(agentId); } @@ -264,25 +347,17 @@ private ResultSetBO executeSql(ExplorerContext context, String sql) throws Excep return resultSet; } - private String guardReadonlySql(ExplorerContext context, String rawSql, int limit) { + private SqlGuardedQuery guardReadonlySql(ExplorerContext context, String rawSql, int limit) { String compactSql = stripTrailingSemicolons(rawSql); if (compactSql.isEmpty()) { throw new IllegalArgumentException("SQL 不能为空"); } Statement statement = parseSingleSelectStatement(compactSql); - Set referencedTables = extractReferencedTables(statement); - if (!referencedTables.isEmpty()) { - List forbiddenTables = referencedTables.stream() - .filter(table -> !context.visibleTableNameSet().contains(normalizeTableName(table))) - .toList(); - if (!forbiddenTables.isEmpty()) { - throw new IllegalArgumentException("SQL 引用了当前 Agent 不可见的表: " + forbiddenTables); - } - } - if (hasLimit(compactSql)) { - return compactSql; - } - return wrapLimitSql(context.dbConfig().getDialectType(), compactSql, limit); + SqlValidationResult validationResult = new SqlColumnAccessValidator(context).validate((Select) statement); + String guardedSql = hasLimit(compactSql) ? compactSql + : wrapLimitSql(context.dbConfig().getDialectType(), compactSql, limit); + return new SqlGuardedQuery(guardedSql, validationResult.referencedTables(), + validationResult.allowedResultHeaders()); } private String stripTrailingSemicolons(String sql) { @@ -319,18 +394,6 @@ private Statement parseSingleSelectStatement(String sql) { return statement; } - private Set extractReferencedTables(Statement statement) { - try { - return new TablesNamesFinder().getTableList(statement) - .stream() - .map(this::normalizeTableName) - .collect(Collectors.toCollection(LinkedHashSet::new)); - } - catch (Exception ex) { - throw new IllegalArgumentException("无法从 SQL 中提取引用表", ex); - } - } - private String wrapLimitSql(String dialectType, String sql, int limit) { String normalizedDialect = StringUtils.defaultString(dialectType).toLowerCase(Locale.ROOT); if (normalizedDialect.contains("sqlserver") || normalizedDialect.contains("sql_server")) { @@ -342,19 +405,10 @@ private String wrapLimitSql(String dialectType, String sql, int limit) { return "SELECT * FROM (%s) dataagent_safe_limit LIMIT %d".formatted(sql, limit); } - private void assertVisibleTable(ExplorerContext context, String tableName) { - if (!context.visibleTableNameSet().contains(normalizeTableName(tableName))) { - throw new IllegalArgumentException("Table '%s' is not visible for current agent".formatted(tableName)); - } - } - private String requireSingleTableName(DatasourceExplorerRequest request) { if (StringUtils.isNotBlank(request.getTableName())) { return request.getTableName().trim(); } - if (request.getTableNames() != null && request.getTableNames().size() == 1) { - return StringUtils.trimToEmpty(request.getTableNames().get(0)); - } throw new IllegalArgumentException("当前 action 必须提供 tableName"); } @@ -370,7 +424,9 @@ private Map loadTableDocumentMap(ExplorerContext context, List return Collections.emptyMap(); } try { - return schemaService.getTableDocuments(context.datasource().getId(), tableNames) + return schemaService + .getTableDocuments(context.agentDatasource().getAgentId().toString(), context.datasource().getId(), + tableNames) .stream() .collect(Collectors.toMap(doc -> normalizeTableName(String.valueOf(doc.getMetadata().get("name"))), doc -> doc, (left, right) -> left, LinkedHashMap::new)); @@ -380,32 +436,70 @@ private Map loadTableDocumentMap(ExplorerContext context, List } } - private Map toTableEntry(String tableName, Document tableDocument, - List explicitSelectedTables) { + private Map loadColumnDocumentMap(ExplorerContext context, String tableName) { + try { + return schemaService + .getColumnDocumentsByTableName(context.agentDatasource().getAgentId().toString(), + context.datasource().getId(), List.of(tableName)) + .stream() + .collect(Collectors.toMap(doc -> normalizeColumnName(String.valueOf(doc.getMetadata().get("name"))), + doc -> doc, (left, right) -> left, LinkedHashMap::new)); + } + catch (Exception ex) { + return Collections.emptyMap(); + } + } + + private Map toTableEntry(ExplorerContext context, String tableName, Document tableDocument, + List relations) { Map tableEntry = new LinkedHashMap<>(); tableEntry.put("name", tableName); - tableEntry.put("selected", explicitSelectedTables.isEmpty() || explicitSelectedTables.stream() - .anyMatch(candidate -> normalizeTableName(candidate).equals(normalizeTableName(tableName)))); + tableEntry.put("selected", isSelectedTable(context, tableName)); + String unifiedForeignKeys = summarizeRelations(relations); if (tableDocument != null) { tableEntry.put("schema", tableDocument.getMetadata().getOrDefault("schema", "")); tableEntry.put("description", tableDocument.getMetadata().getOrDefault("description", "")); tableEntry.put("primaryKeys", tableDocument.getMetadata().getOrDefault("primaryKey", List.of())); - tableEntry.put("foreignKeys", tableDocument.getMetadata().getOrDefault("foreignKey", "")); + tableEntry.put("foreignKeys", StringUtils.defaultIfBlank(unifiedForeignKeys, + String.valueOf(tableDocument.getMetadata().getOrDefault("foreignKey", "")))); + } + else if (StringUtils.isNotBlank(unifiedForeignKeys)) { + tableEntry.put("foreignKeys", unifiedForeignKeys); } return tableEntry; } - private Map toColumnEntry(ColumnInfoBO columnInfo) { + private Map toColumnEntry(ColumnInfoBO columnInfo, Document columnDocument) { Map columnEntry = new LinkedHashMap<>(); columnEntry.put("name", columnInfo.getName()); columnEntry.put("type", columnInfo.getType()); - columnEntry.put("description", StringUtils.defaultString(columnInfo.getDescription())); + columnEntry.put("description", resolveColumnDescription(columnInfo, columnDocument)); columnEntry.put("primary", columnInfo.isPrimary()); columnEntry.put("notnull", columnInfo.isNotnull()); - columnEntry.put("samples", parseSamples(columnInfo.getSamples())); + columnEntry.put("samples", resolveColumnSamples(columnInfo, columnDocument)); return columnEntry; } + private String resolveColumnDescription(ColumnInfoBO columnInfo, Document columnDocument) { + if (StringUtils.isNotBlank(columnInfo.getDescription())) { + return columnInfo.getDescription(); + } + if (columnDocument == null) { + return StringUtils.EMPTY; + } + return String.valueOf(columnDocument.getMetadata().getOrDefault("description", "")); + } + + private List resolveColumnSamples(ColumnInfoBO columnInfo, Document columnDocument) { + if (StringUtils.isNotBlank(columnInfo.getSamples())) { + return parseSamples(columnInfo.getSamples()); + } + if (columnDocument == null) { + return List.of(); + } + return parseSamples(String.valueOf(columnDocument.getMetadata().getOrDefault("samples", ""))); + } + private List parseSamples(String samples) { if (StringUtils.isBlank(samples)) { return List.of(); @@ -418,30 +512,250 @@ private List parseSamples(String samples) { } } - private List filterRelations(List logicalRelations, String tableName) { - String normalizedTableName = normalizeTableName(tableName); - return logicalRelations.stream() - .filter(relation -> normalizedTableName.equals(normalizeTableName(relation.getSourceTableName())) - || normalizedTableName.equals(normalizeTableName(relation.getTargetTableName()))) - .sorted(Comparator.comparing(relation -> StringUtils.defaultString(relation.getSourceTableName()))) + private List filterRelations(ExplorerContext context, String tableName) { + return context.relationsByTable() + .getOrDefault(normalizeTableName(tableName), List.of()) + .stream() + .filter(relation -> isRelationVisible(context, relation)) .toList(); } - private Map toRelationEntry(LogicalRelation relation) { + private Map toRelationEntry(UnifiedRelation relation) { Map relationEntry = new LinkedHashMap<>(); - relationEntry.put("sourceTable", relation.getSourceTableName()); - relationEntry.put("sourceColumn", relation.getSourceColumnName()); - relationEntry.put("targetTable", relation.getTargetTableName()); - relationEntry.put("targetColumn", relation.getTargetColumnName()); - relationEntry.put("relationType", relation.getRelationType()); - relationEntry.put("description", relation.getDescription()); + relationEntry.put("sourceTable", relation.sourceTable()); + relationEntry.put("sourceColumn", relation.sourceColumn()); + relationEntry.put("targetTable", relation.targetTable()); + relationEntry.put("targetColumn", relation.targetColumn()); + relationEntry.put("relationType", relation.relationType()); + relationEntry.put("description", relation.description()); + relationEntry.put("sourceType", relation.sourceType()); + relationEntry.put("virtual", relation.virtual()); + relationEntry.put("declaredInDatabase", relation.declaredInDatabase()); return relationEntry; } + private List buildUnifiedRelations(Map> visibleTablesByName, + Map> visibleTablesByLeafName, List physicalRelations, + List logicalRelations) { + Map relationMap = new LinkedHashMap<>(); + for (ForeignKeyInfoBO physicalRelation : physicalRelations) { + UnifiedRelation relation = canonicalizeRelation(visibleTablesByName, visibleTablesByLeafName, + toUnifiedRelation(physicalRelation)); + if (relation != null) { + mergeRelation(relationMap, relation); + } + } + for (LogicalRelation logicalRelation : logicalRelations) { + UnifiedRelation relation = canonicalizeRelation(visibleTablesByName, visibleTablesByLeafName, + toUnifiedRelation(logicalRelation)); + if (relation != null) { + mergeRelation(relationMap, relation); + } + } + return relationMap.values() + .stream() + .sorted(Comparator.comparing((UnifiedRelation relation) -> normalizeTableName(relation.sourceTable())) + .thenComparing(relation -> StringUtils.defaultString(relation.sourceColumn())) + .thenComparing(relation -> normalizeTableName(relation.targetTable())) + .thenComparing(relation -> StringUtils.defaultString(relation.targetColumn()))) + .toList(); + } + + private UnifiedRelation toUnifiedRelation(ForeignKeyInfoBO relation) { + return new UnifiedRelation(relation.getTable(), relation.getColumn(), relation.getReferencedTable(), + relation.getReferencedColumn(), StringUtils.EMPTY, StringUtils.EMPTY, "physical", false, true); + } + + private UnifiedRelation toUnifiedRelation(LogicalRelation relation) { + return new UnifiedRelation(relation.getSourceTableName(), relation.getSourceColumnName(), + relation.getTargetTableName(), relation.getTargetColumnName(), + StringUtils.defaultString(relation.getRelationType()), + StringUtils.defaultString(relation.getDescription()), "logical", true, false); + } + + private UnifiedRelation canonicalizeRelation(Map> visibleTablesByName, + Map> visibleTablesByLeafName, UnifiedRelation relation) { + Optional sourceTable = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + relation.sourceTable(), true); + Optional targetTable = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + relation.targetTable(), true); + if (sourceTable.isEmpty() || targetTable.isEmpty()) { + return null; + } + return new UnifiedRelation(sourceTable.get(), relation.sourceColumn(), targetTable.get(), + relation.targetColumn(), relation.relationType(), relation.description(), relation.sourceType(), + relation.virtual(), relation.declaredInDatabase()); + } + + private void mergeRelation(Map relationMap, UnifiedRelation incoming) { + String relationKey = buildRelationKey(incoming); + UnifiedRelation existing = relationMap.get(relationKey); + if (existing == null) { + relationMap.put(relationKey, incoming); + return; + } + if (existing.declaredInDatabase() && !incoming.declaredInDatabase()) { + relationMap.put(relationKey, mergeRelation(existing, incoming)); + return; + } + if (!existing.declaredInDatabase() && incoming.declaredInDatabase()) { + relationMap.put(relationKey, mergeRelation(incoming, existing)); + return; + } + relationMap.put(relationKey, mergeRelation(existing, incoming)); + } + + private UnifiedRelation mergeRelation(UnifiedRelation preferred, UnifiedRelation supplement) { + return new UnifiedRelation(preferred.sourceTable(), preferred.sourceColumn(), preferred.targetTable(), + preferred.targetColumn(), + StringUtils.firstNonBlank(preferred.relationType(), supplement.relationType()), + StringUtils.firstNonBlank(preferred.description(), supplement.description()), preferred.sourceType(), + preferred.virtual(), preferred.declaredInDatabase()); + } + + private String buildRelationKey(UnifiedRelation relation) { + return normalizeTableName(relation.sourceTable()) + "|" + StringUtils.defaultString(relation.sourceColumn()) + + "|" + normalizeTableName(relation.targetTable()) + "|" + + StringUtils.defaultString(relation.targetColumn()); + } + + private Map> indexRelationsByTable(List relations) { + Map> relationIndex = new LinkedHashMap<>(); + for (UnifiedRelation relation : relations) { + relationIndex.computeIfAbsent(normalizeTableName(relation.sourceTable()), key -> new ArrayList<>()) + .add(relation); + String targetKey = normalizeTableName(relation.targetTable()); + if (!targetKey.equals(normalizeTableName(relation.sourceTable()))) { + relationIndex.computeIfAbsent(targetKey, key -> new ArrayList<>()).add(relation); + } + } + Map> immutableIndex = new LinkedHashMap<>(); + relationIndex + .forEach((tableName, tableRelations) -> immutableIndex.put(tableName, List.copyOf(tableRelations))); + return Map.copyOf(immutableIndex); + } + + private String summarizeRelations(List relations) { + return relations.stream() + .map(relation -> relation.sourceTable() + "." + relation.sourceColumn() + "=" + relation.targetTable() + "." + + relation.targetColumn()) + .distinct() + .collect(Collectors.joining("、")); + } + private List> toColumnHeaders(ResultSetBO resultSet) { return resultSet.getColumn().stream().map(column -> Map.of("name", column)).toList(); } + private List toExplainColumns(ResultSetBO resultSet) { + return Optional.ofNullable(resultSet.getColumn()) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + } + + private List toExplainTables(Set referencedTables, Map> visibleTablesByName) { + if (referencedTables == null || referencedTables.isEmpty()) { + return List.of(); + } + return referencedTables.stream().map(tableName -> { + List candidates = visibleTablesByName.getOrDefault(normalizeTableName(tableName), List.of()); + return candidates.isEmpty() ? tableName : candidates.get(0); + }).distinct().toList(); + } + + private List> collectRelationEvidence(ExplorerContext context, Set referencedTables) { + if (referencedTables == null || referencedTables.size() < 2) { + return List.of(); + } + Set normalizedReferencedTables = referencedTables.stream() + .map(this::normalizeTableName) + .collect(Collectors.toCollection(LinkedHashSet::new)); + return context.unifiedRelations() + .stream() + .filter(relation -> normalizedReferencedTables.contains(normalizeTableName(relation.sourceTable())) + && normalizedReferencedTables.contains(normalizeTableName(relation.targetTable()))) + .map(this::toRelationEntry) + .toList(); + } + + private String buildResultScopeSummary(ResultSetBO resultSet, int limit) { + int rowCount = resultSet.getData() == null ? 0 : resultSet.getData().size(); + int columnCount = resultSet.getColumn() == null ? 0 : resultSet.getColumn().size(); + if (rowCount >= limit) { + return "当前结果仅展示前 %d 行、%d 个返回字段;字段范围已经过可见性裁剪。".formatted(limit, columnCount); + } + return "当前结果展示 %d 行、%d 个返回字段;字段范围已经过可见性裁剪。".formatted(rowCount, columnCount); + } + + private String buildDecisionReason(Set referencedTables, List> relationEvidence) { + List usedTables = toExplainTables(referencedTables, Map.of()); + if (!relationEvidence.isEmpty()) { + return "本轮选择执行 SQL,是因为问题需要跨表查数;表间关联优先依据已配置的物理外键或逻辑关系。"; + } + if (usedTables.size() > 1) { + return "本轮选择执行 SQL,是因为问题需要联合多张表获取结构化结果。"; + } + if (usedTables.size() == 1) { + return "本轮选择执行 SQL,是因为问题需要直接从目标表提取结构化结果。"; + } + return "本轮选择执行 SQL,是因为问题需要结构化数据结果来支撑回答。"; + } + + private List buildPreviewDecisionReasons(String tableName, int limit) { + return List.of("本轮直接选择 PREVIEW_ROWS,是为了快速确认单表“%s”的样例数据。".formatted(tableName), + "预览查询会自动附带 limit=%d,避免一次返回过多行。".formatted(limit)); + } + + private List buildSearchDecisionReasons(List usedTables, List> relationEvidence, + int limit) { + List reasons = new ArrayList<>(); + reasons.add("本轮选择执行 SQL,是因为回答需要结构化结果来支撑结论。"); + if (!usedTables.isEmpty()) { + reasons.add("实际命中的表有:%s。".formatted(String.join("、", usedTables))); + } + if (usedTables.size() > 1) { + reasons.add("由于问题涉及多张表,系统需要联合查询后再生成答案。"); + } + if (!relationEvidence.isEmpty()) { + reasons.add("多表关联优先依据已配置的物理外键或逻辑关系,而不是临时猜测关联条件。"); + } + reasons.add("查询结果默认限制为最多 %d 行,避免一次返回过多数据。".formatted(limit)); + return List.copyOf(reasons); + } + + private List buildPreviewResultScopeDetails(String tableName, ResultSetBO resultSet, int limit) { + List details = new ArrayList<>(); + details.add("当前展示的是表“%s”的预览结果,不代表完整数据集。".formatted(tableName)); + details.add("本次共返回 %d 行、%d 个字段。".formatted(resultSet.getData() == null ? 0 : resultSet.getData().size(), + resultSet.getColumn() == null ? 0 : resultSet.getColumn().size())); + details.add("预览结果仅包含当前 Agent 可见字段。"); + details.add("预览查询使用了 limit=%d。".formatted(limit)); + details.add("禁止根据未返回字段推断隐藏信息。"); + return List.copyOf(details); + } + + private List buildSearchResultScopeDetails(ResultSetBO resultSet, int limit, Set allowedHeaders) { + List details = new ArrayList<>(); + int rowCount = resultSet.getData() == null ? 0 : resultSet.getData().size(); + int columnCount = resultSet.getColumn() == null ? 0 : resultSet.getColumn().size(); + details.add("当前结果共返回 %d 行、%d 个字段。".formatted(rowCount, columnCount)); + details.add("查询结果最多展示 %d 行。".formatted(limit)); + if (allowedHeaders != null && !allowedHeaders.isEmpty()) { + details.add("最终仅保留 SQL select 列表中显式声明的返回字段。"); + details.add("允许返回的字段有:%s。".formatted(String.join("、", allowedHeaders))); + } + else { + details.add("当前 SQL 未触发额外的结果列过滤。"); + } + details.add("所有结果仍受字段可见性约束限制。"); + details.add("禁止根据未返回字段推断隐藏信息。"); + return List.copyOf(details); + } + private List> toRows(ResultSetBO resultSet) { return resultSet.getData().stream().map(row -> { Map mappedRow = new LinkedHashMap<>(); @@ -450,6 +764,41 @@ private List> toRows(ResultSetBO resultSet) { }).toList(); } + private ResultSetBO filterResultSet(ResultSetBO resultSet, SqlGuardedQuery guardedQuery) { + Set allowedHeaders = guardedQuery.allowedResultHeaders(); + if (allowedHeaders == null || allowedHeaders.isEmpty()) { + return resultSet; + } + List originalColumns = Optional.ofNullable(resultSet.getColumn()).orElse(List.of()); + List keptIndexes = new ArrayList<>(); + List keptColumns = new ArrayList<>(); + for (int index = 0; index < originalColumns.size(); index++) { + String columnName = originalColumns.get(index); + if (allowedHeaders.contains(normalizeColumnName(columnName))) { + keptIndexes.add(index); + keptColumns.add(columnName); + } + } + if (keptIndexes.size() == originalColumns.size()) { + return resultSet; + } + List> filteredRows = Optional.ofNullable(resultSet.getData()) + .orElse(List.of()) + .stream() + .map(row -> { + Map filteredRow = new LinkedHashMap<>(); + for (Integer keptIndex : keptIndexes) { + String columnName = originalColumns.get(keptIndex); + filteredRow.put(columnName, row.get(columnName)); + } + return filteredRow; + }) + .toList(); + resultSet.setColumn(keptColumns); + resultSet.setData(filteredRows); + return resultSet; + } + private boolean containsQuery(Map table, String query) { return table.values() .stream() @@ -459,18 +808,177 @@ private boolean containsQuery(Map table, String query) { .anyMatch(value -> value.contains(query)); } - private String normalizeTableName(String tableName) { - String normalized = StringUtils.trimToEmpty(tableName); + private List applyVisibleColumnFilter(ExplorerContext context, String tableName, + List columns) { + return Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(column -> isColumnVisible(context, tableName, column.getName())) + .toList(); + } + + private String resolvePreviewColumnSelection(ExplorerContext context, String tableName) { + List visibleColumns = context.visibleColumnsByTable().get(normalizeTableName(tableName)); + if (visibleColumns == null) { + return "*"; + } + if (visibleColumns.isEmpty()) { + throw new IllegalArgumentException("表 '%s' 当前没有可预览字段,请先调整字段级可见性配置".formatted(tableName)); + } + return visibleColumns.stream() + .map(columnName -> SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), columnName)) + .collect(Collectors.joining(", ")); + } + + private boolean isRelationVisible(ExplorerContext context, UnifiedRelation relation) { + return isColumnVisible(context, relation.sourceTable(), relation.sourceColumn()) + && isColumnVisible(context, relation.targetTable(), relation.targetColumn()); + } + + private boolean isColumnVisible(ExplorerContext context, String tableName, String columnName) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return true; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + return visibleColumns != null && visibleColumns.contains(normalizeColumnName(columnName)); + } + + private Map> buildVisibleColumnsByTable(AgentDatasource agentDatasource, + Map> visibleTablesByName, Map> visibleTablesByLeafName) { + Map> selectedColumns = Optional.ofNullable(agentDatasource.getSelectColumns()) + .orElse(Map.of()); + Map> visibleColumnsByTable = new LinkedHashMap<>(); + selectedColumns.forEach((tableName, columns) -> { + Optional resolvedTableName = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + tableName, true); + if (resolvedTableName.isEmpty()) { + return; + } + List sanitizedColumns = Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .collect(Collectors.toCollection(LinkedHashSet::new)) + .stream() + .toList(); + if (!sanitizedColumns.isEmpty()) { + visibleColumnsByTable.put(normalizeTableName(resolvedTableName.get()), List.copyOf(sanitizedColumns)); + } + }); + return visibleColumnsByTable; + } + + private Map> toImmutableListIndex(Map> source) { + Map> immutableIndex = new LinkedHashMap<>(); + source.forEach((key, value) -> immutableIndex.put(key, List.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private Map> toImmutableSetIndex(Map> source) { + Map> immutableIndex = new LinkedHashMap<>(); + source.forEach((key, value) -> immutableIndex.put(key, Set.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private Map> indexTablesByIdentity(List tableNames) { + return indexTables(tableNames, false); + } + + private Map> indexTablesByLeafName(List tableNames) { + return indexTables(tableNames, true); + } + + private Map> indexTables(List tableNames, boolean leafOnly) { + Map> index = new LinkedHashMap<>(); + for (String tableName : Optional.ofNullable(tableNames).orElse(List.of())) { + if (StringUtils.isBlank(tableName)) { + continue; + } + String normalizedKey = leafOnly ? normalizeTableLeafName(tableName) : normalizeTableName(tableName); + index.computeIfAbsent(normalizedKey, key -> new LinkedHashSet<>()).add(tableName); + } + Map> immutableIndex = new LinkedHashMap<>(); + index.forEach((key, value) -> immutableIndex.put(key, List.copyOf(value))); + return Map.copyOf(immutableIndex); + } + + private String resolveVisibleTableName(ExplorerContext context, String tableName) { + return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) + .orElseThrow(() -> buildInvisibleTableException(context, tableName)); + } + + private Optional findVisibleTableName(Map> visibleTablesByName, + Map> visibleTablesByLeafName, String tableName, boolean allowQualifiedFallback) { + String normalizedTableName = normalizeTableName(tableName); + List exactMatches = visibleTablesByName.getOrDefault(normalizedTableName, List.of()); + if (exactMatches.size() == 1) { + return Optional.of(exactMatches.get(0)); + } + if (exactMatches.size() > 1) { + throw new IllegalArgumentException( + "表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); + } + if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { + return Optional.empty(); + } + List leafMatches = visibleTablesByLeafName.getOrDefault(normalizeTableLeafName(tableName), List.of()); + if (leafMatches.size() == 1) { + return Optional.of(leafMatches.get(0)); + } + if (leafMatches.size() > 1) { + throw new IllegalArgumentException( + "表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); + } + return Optional.empty(); + } + + private IllegalArgumentException buildInvisibleTableException(ExplorerContext context, String tableName) { + return new IllegalArgumentException( + "表 '%s' 对当前 Agent 不可见。当前可见表:%s".formatted(tableName, String.join(", ", context.visibleTables()))); + } + + private boolean isSelectedTable(ExplorerContext context, String tableName) { + if (context.explicitSelectedTables().isEmpty()) { + return true; + } + String normalizedTableName = normalizeTableName(tableName); + return context.explicitSelectedTables() + .stream() + .map(this::normalizeTableName) + .anyMatch(normalizedTableName::equals); + } + + private boolean isQualifiedIdentifier(String value) { + return normalizeIdentifier(value).contains("."); + } + + private String normalizeIdentifier(String value) { + String normalized = StringUtils.trimToEmpty(value); normalized = StringUtils.removeStart(normalized, "`"); normalized = StringUtils.removeEnd(normalized, "`"); normalized = StringUtils.removeStart(normalized, "\""); normalized = StringUtils.removeEnd(normalized, "\""); normalized = StringUtils.removeStart(normalized, "["); normalized = StringUtils.removeEnd(normalized, "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeTableName(String tableName) { + return normalizeIdentifier(tableName); + } + + private String normalizeTableLeafName(String tableName) { + String normalized = normalizeIdentifier(tableName); if (normalized.contains(".")) { - normalized = normalized.substring(normalized.lastIndexOf('.') + 1); + return normalized.substring(normalized.lastIndexOf('.') + 1); } - return normalized.toLowerCase(Locale.ROOT); + return normalized; + } + + private String normalizeColumnName(String columnName) { + return normalizeTableLeafName(columnName); } private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(ExplorerContext context, @@ -481,9 +989,361 @@ private DatasourceExplorerResult.DatasourceExplorerResultBuilder baseResult(Expl .summary(summary); } + private DatasourceExplorerResult capture(DatasourceExplorerResult result, @Nullable AgentRequest graphRequest) { + if (graphRequest != null) { + answerTraceExplainStore.recordDatasourceResult(graphRequest, result); + } + else { + answerTraceExplainStore.recordDatasourceResult(result); + } + return result; + } + private record ExplorerContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, Accessor accessor, List visibleTables, Set visibleTableNameSet, - List explicitSelectedTables, List logicalRelations) { + Map> visibleTablesByName, Map> visibleTablesByLeafName, + List explicitSelectedTables, Map> visibleColumnsByTable, + Map> visibleColumnNameSetByTable, Set columnRestrictedTables, + List unifiedRelations, Map> relationsByTable) { + } + + private record UnifiedRelation(String sourceTable, String sourceColumn, String targetTable, String targetColumn, + String relationType, String description, String sourceType, boolean virtual, boolean declaredInDatabase) { + } + + private record SqlGuardedQuery(String sql, Set referencedTables, Set allowedResultHeaders) { + } + + private record SqlValidationResult(Set referencedTables, Set allowedResultHeaders) { + } + + private record SourceBinding(String referenceName, String tableName) { + + boolean isBaseTable() { + return StringUtils.isNotBlank(tableName); + } + } + + private record SelectScope(Map sourcesByReference, List baseSources) { + + SourceBinding resolve(String reference) { + return sourcesByReference.get(reference); + } + } + + private final class SqlColumnAccessValidator { + + private final ExplorerContext context; + + private SqlColumnAccessValidator(ExplorerContext context) { + this.context = context; + } + + private SqlValidationResult validate(Select select) { + Set referencedTables = new LinkedHashSet<>(); + validateSelect(select, new LinkedHashSet<>(), referencedTables); + return new SqlValidationResult(Set.copyOf(referencedTables), extractAllowedResultHeaders(select)); + } + + private void validateSelect(Select select, Set cteNames, Set referencedTables) { + Set nextCteNames = new LinkedHashSet<>(cteNames); + if (select.getWithItemsList() != null) { + for (WithItem withItem : select.getWithItemsList()) { + if (withItem.getAlias() != null && StringUtils.isNotBlank(withItem.getAlias().getName())) { + nextCteNames.add(normalizeTableName(withItem.getAlias().getName())); + } + } + for (WithItem withItem : select.getWithItemsList()) { + validateSelect(withItem.getSelect(), nextCteNames, referencedTables); + } + } + if (select instanceof PlainSelect plainSelect) { + validatePlainSelect(plainSelect, nextCteNames, referencedTables); + return; + } + if (select instanceof SetOperationList setOperationList) { + for (Select childSelect : Optional.ofNullable(setOperationList.getSelects()).orElse(List.of())) { + validateSelect(childSelect, nextCteNames, referencedTables); + } + return; + } + if (select instanceof ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), nextCteNames, referencedTables); + return; + } + throw new IllegalArgumentException("当前 SQL 包含暂不支持的查询结构: " + select.getClass().getSimpleName()); + } + + private void validatePlainSelect(PlainSelect plainSelect, Set cteNames, Set referencedTables) { + SelectScope scope = buildScope(plainSelect, cteNames, referencedTables); + Set selectAliases = extractSelectAliases(plainSelect.getSelectItems()); + for (SelectItem selectItem : Optional.ofNullable(plainSelect.getSelectItems()).orElse(List.of())) { + validateExpression(selectItem.getExpression(), scope, cteNames, "SELECT", Set.of()); + } + validateExpression(plainSelect.getWhere(), scope, cteNames, "WHERE", Set.of()); + validateExpression(plainSelect.getHaving(), scope, cteNames, "HAVING", selectAliases); + validateExpression(plainSelect.getQualify(), scope, cteNames, "QUALIFY", selectAliases); + if (plainSelect.getGroupBy() != null && plainSelect.getGroupBy().getGroupByExpressions() != null) { + for (Object groupByExpression : Optional + .ofNullable(plainSelect.getGroupBy().getGroupByExpressions().getExpressions()) + .orElse(List.of())) { + if (groupByExpression instanceof Expression expression) { + validateExpression(expression, scope, cteNames, "GROUP BY", selectAliases); + } + } + } + for (OrderByElement orderByElement : Optional.ofNullable(plainSelect.getOrderByElements()) + .orElse(List.of())) { + validateExpression(orderByElement.getExpression(), scope, cteNames, "ORDER BY", selectAliases); + } + } + + private SelectScope buildScope(PlainSelect plainSelect, Set cteNames, Set referencedTables) { + Map sourcesByReference = new LinkedHashMap<>(); + List baseSources = new ArrayList<>(); + addFromItemSources(plainSelect.getFromItem(), cteNames, referencedTables, sourcesByReference, baseSources); + for (Join join : Optional.ofNullable(plainSelect.getJoins()).orElse(List.of())) { + addFromItemSources(join.getRightItem(), cteNames, referencedTables, sourcesByReference, baseSources); + if (join.getUsingColumns() != null && !join.getUsingColumns().isEmpty()) { + throw new IllegalArgumentException( + "当前 SQL 使用了 JOIN ... USING 语法。字段级可见性校验要求改写成显式 ON alias.column = alias.column"); + } + for (Expression onExpression : Optional.ofNullable(join.getOnExpressions()).orElse(List.of())) { + validateExpression(onExpression, + new SelectScope(Map.copyOf(sourcesByReference), List.copyOf(baseSources)), cteNames, + "JOIN ON", Set.of()); + } + } + return new SelectScope(Map.copyOf(sourcesByReference), List.copyOf(baseSources)); + } + + private void addFromItemSources(FromItem fromItem, Set cteNames, Set referencedTables, + Map sourcesByReference, List baseSources) { + if (fromItem == null) { + return; + } + if (fromItem instanceof Table table) { + String normalizedTableReference = normalizeTableName(extractTableReference(table)); + String aliasName = table.getAlias() == null ? null : normalizeTableName(table.getAlias().getName()); + if (cteNames.contains(normalizedTableReference)) { + registerSource(sourcesByReference, + new SourceBinding(StringUtils.defaultIfBlank(aliasName, normalizedTableReference), null)); + return; + } + String resolvedTableName = resolveVisibleTableName(context, extractTableReference(table)); + referencedTables.add(normalizeTableName(resolvedTableName)); + SourceBinding sourceBinding = new SourceBinding( + StringUtils.defaultIfBlank(aliasName, normalizeTableName(resolvedTableName)), + resolvedTableName); + registerSource(sourcesByReference, sourceBinding); + registerSource(sourcesByReference, + new SourceBinding(normalizeTableName(resolvedTableName), resolvedTableName)); + String tableLeafName = normalizeTableLeafName(resolvedTableName); + if (!tableLeafName.equals(normalizeTableName(resolvedTableName))) { + registerSource(sourcesByReference, new SourceBinding(tableLeafName, resolvedTableName)); + } + baseSources.add(sourceBinding); + return; + } + if (fromItem instanceof LateralSubSelect lateralSubSelect) { + validateSelect(lateralSubSelect.getSelect(), cteNames, referencedTables); + registerDerivedSource(lateralSubSelect, sourcesByReference); + return; + } + if (fromItem instanceof ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), cteNames, referencedTables); + registerDerivedSource(parenthesedSelect, sourcesByReference); + return; + } + if (fromItem instanceof ParenthesedFromItem parenthesedFromItem) { + Map nestedSources = new LinkedHashMap<>(); + List nestedBaseSources = new ArrayList<>(); + addFromItemSources(parenthesedFromItem.getFromItem(), cteNames, referencedTables, nestedSources, + nestedBaseSources); + for (Join join : Optional.ofNullable(parenthesedFromItem.getJoins()).orElse(List.of())) { + addFromItemSources(join.getRightItem(), cteNames, referencedTables, nestedSources, + nestedBaseSources); + if (join.getUsingColumns() != null && !join.getUsingColumns().isEmpty()) { + throw new IllegalArgumentException( + "当前 SQL 使用了 JOIN ... USING 语法。字段级可见性校验要求改写成显式 ON alias.column = alias.column"); + } + for (Expression onExpression : Optional.ofNullable(join.getOnExpressions()).orElse(List.of())) { + validateExpression(onExpression, + new SelectScope(Map.copyOf(nestedSources), List.copyOf(nestedBaseSources)), cteNames, + "JOIN ON", Set.of()); + } + } + if (parenthesedFromItem.getAlias() != null + && StringUtils.isNotBlank(parenthesedFromItem.getAlias().getName())) { + registerDerivedSource(parenthesedFromItem, sourcesByReference); + } + else { + nestedSources.values().forEach(binding -> registerSource(sourcesByReference, binding)); + baseSources.addAll(nestedBaseSources); + } + return; + } + if (fromItem instanceof TableFunction tableFunction) { + registerDerivedSource(tableFunction, sourcesByReference); + return; + } + throw new IllegalArgumentException("当前 SQL 包含暂不支持的 FROM 结构: " + fromItem.getClass().getSimpleName()); + } + + private void registerDerivedSource(FromItem fromItem, Map sourcesByReference) { + if (fromItem.getAlias() == null || StringUtils.isBlank(fromItem.getAlias().getName())) { + return; + } + registerSource(sourcesByReference, + new SourceBinding(normalizeTableName(fromItem.getAlias().getName()), null)); + } + + private void registerSource(Map sourcesByReference, SourceBinding sourceBinding) { + SourceBinding existingBinding = sourcesByReference.get(sourceBinding.referenceName()); + if (existingBinding != null && !Objects.equals(existingBinding.tableName(), sourceBinding.tableName())) { + throw new IllegalArgumentException( + "Table reference '%s' is ambiguous in current SQL scope; please use aliases" + .formatted(sourceBinding.referenceName())); + } + sourcesByReference.put(sourceBinding.referenceName(), sourceBinding); + } + + private void validateExpression(Expression expression, SelectScope scope, Set cteNames, String clause, + Set allowedAliases) { + if (expression == null) { + return; + } + expression.accept(new ExpressionVisitorAdapter() { + @Override + public void visit(Function function) { + if (function.isAllColumns() && !"count".equalsIgnoreCase(function.getName())) { + throw new IllegalArgumentException("子句 %s 中检测到 %s(*)。字段级可见性校验禁止使用除 COUNT(*) 外的星号聚合,请显式列出字段" + .formatted(clause, StringUtils.defaultIfBlank(function.getName(), "function"))); + } + super.visit(function); + } + + @Override + public void visit(AllColumns allColumns) { + throw new IllegalArgumentException("子句 %s 中检测到 SELECT *。请改成显式列名,避免越权读取隐藏字段".formatted(clause)); + } + + @Override + public void visit(AllTableColumns allTableColumns) { + throw new IllegalArgumentException( + "子句 %s 中检测到 %s.*。请改成显式列名,避免越权读取隐藏字段".formatted(clause, allTableColumns.getTable())); + } + + @Override + public void visit(Column column) { + validateColumnReference(scope, clause, column, allowedAliases); + } + + @Override + public void visit(ParenthesedSelect parenthesedSelect) { + validateSelect(parenthesedSelect.getSelect(), cteNames, new LinkedHashSet<>()); + } + + @Override + public void visit(Select select) { + validateSelect(select, cteNames, new LinkedHashSet<>()); + } + }); + } + + private void validateColumnReference(SelectScope scope, String clause, Column column, + Set allowedAliases) { + String normalizedColumnName = normalizeColumnName(column.getColumnName()); + if (StringUtils.isBlank(normalizedColumnName)) { + throw new IllegalArgumentException("子句 %s 中存在无法识别的字段引用: %s".formatted(clause, column)); + } + Table table = column.getTable(); + String tableReference = table == null ? StringUtils.EMPTY + : normalizeTableName(extractTableReference(table)); + if (StringUtils.isBlank(tableReference)) { + if (allowedAliases.contains(normalizedColumnName)) { + return; + } + if (scope.baseSources().isEmpty()) { + return; + } + if (scope.baseSources().size() > 1) { + throw new IllegalArgumentException("子句 %s 中的字段 '%s' 没有带表前缀。当前 SQL 涉及多张基础表,无法安全判断字段归属,请改成 alias.%s" + .formatted(clause, column.getColumnName(), column.getColumnName())); + } + SourceBinding sourceBinding = scope.baseSources().get(0); + assertVisibleColumn(sourceBinding.tableName(), column.getColumnName(), clause, column.toString()); + return; + } + SourceBinding sourceBinding = scope.resolve(tableReference); + if (sourceBinding == null) { + throw new IllegalArgumentException( + "子句 %s 中引用了未知表/别名 '%s'。请检查 SQL 中的表别名是否和 FROM/JOIN 定义一致".formatted(clause, table.getName())); + } + if (!sourceBinding.isBaseTable()) { + return; + } + assertVisibleColumn(sourceBinding.tableName(), column.getColumnName(), clause, column.toString()); + } + + private void assertVisibleColumn(String tableName, String columnName, String clause, String expression) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + if (visibleColumns == null || !visibleColumns.contains(normalizeColumnName(columnName))) { + String visibleColumnSummary = Optional + .ofNullable(context.visibleColumnsByTable().get(normalizedTableName)) + .orElse(List.of()) + .stream() + .collect(Collectors.joining(", ")); + throw new IllegalArgumentException("子句 %s 中的字段引用 '%s' 不被允许。表 '%s' 已启用字段级可见性控制,仅允许字段: [%s]" + .formatted(clause, expression, tableName, visibleColumnSummary)); + } + } + + private Set extractSelectAliases(List> selectItems) { + return Optional.ofNullable(selectItems) + .orElse(List.of()) + .stream() + .map(SelectItem::getAlias) + .filter(alias -> alias != null && StringUtils.isNotBlank(alias.getName())) + .map(alias -> normalizeColumnName(alias.getName())) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + private Set extractAllowedResultHeaders(Select select) { + if (!(select instanceof PlainSelect plainSelect)) { + return null; + } + Set headers = new LinkedHashSet<>(); + for (SelectItem selectItem : Optional.ofNullable(plainSelect.getSelectItems()).orElse(List.of())) { + Expression expression = selectItem.getExpression(); + if (expression instanceof AllColumns || expression instanceof AllTableColumns) { + return null; + } + if (selectItem.getAlias() != null && StringUtils.isNotBlank(selectItem.getAlias().getName())) { + headers.add(normalizeColumnName(selectItem.getAlias().getName())); + continue; + } + if (expression instanceof Column column) { + headers.add(normalizeColumnName(column.getColumnName())); + continue; + } + return null; + } + return headers; + } + + private String extractTableReference(Table table) { + String fullyQualifiedName = table == null ? StringUtils.EMPTY : table.getFullyQualifiedName(); + if (StringUtils.isNotBlank(fullyQualifiedName)) { + return fullyQualifiedName; + } + return table == null ? StringUtils.EMPTY : table.getName(); + } + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java index c150727ad..9b2dac44f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/datasource/DatasourceExplorerToolProvider.java @@ -15,7 +15,11 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.datasource; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.Datasource; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; @@ -48,35 +52,26 @@ public class DatasourceExplorerToolProvider implements AgentScopedToolProvider { "PREVIEW_ROWS", "SEARCH" ], - "description": "探索动作。推荐顺序:LIST_TABLES -> GET_TABLE_SCHEMA -> PREVIEW_ROWS/SEARCH" + "description": "探索动作。" }, "query": { "type": "string", - "description": "用于 FIND_TABLES 的检索关键词" + "description": "FIND_TABLES 时必填,用于按关键词查找表。" }, "tableName": { "type": "string", - "description": "目标表名。GET_TABLE_SCHEMA / GET_RELATED_TABLES / PREVIEW_ROWS 必填" - }, - "tableNames": { - "type": "array", - "items": { - "type": "string" - }, - "description": "可选表名列表。当前版本主要使用单表" + "description": "GET_TABLE_SCHEMA、GET_RELATED_TABLES、PREVIEW_ROWS 时必填的目标表名。" }, "sql": { "type": "string", - "description": "SEARCH 动作需要的只读 SQL。仅允许 SELECT/WITH" + "description": "SEARCH 时必填的只读 SQL,仅允许 SELECT/WITH。" }, "limit": { "type": "integer", - "description": "返回行数上限,默认 20,最大 200" + "description": "可选返回上限,默认 20,最大 200。" } }, - "required": [ - "action" - ] + "required": ["action"] } """; @@ -132,13 +127,15 @@ private String buildDescription(Datasource datasource, AgentDatasource agentData String visibleTables = selectedTables.isEmpty() ? "当前未显式选表,将回退到数据源全部可见表" : "当前显式选表 %d 个:%s" .formatted(selectedTables.size(), String.join(", ", selectedTables.stream().limit(8).toList())); return """ - Unified explorer for datasource '%s' (%s). - Use this tool to inspect tables, inspect schema, preview rows, and execute readonly SQL search. - Constraints: - 1. Only the current agent's active datasource is visible. - 2. Only readonly SQL is allowed for SEARCH. - 3. Recommended call order: LIST_TABLES -> GET_TABLE_SCHEMA -> PREVIEW_ROWS -> SEARCH. - 4. %s + 数据源'%s'(%s)的统一探索工具。 + 可用于查看表列表、查找表、查看单表结构、查看关系、按需预览样例数据,以及执行只读 SQL 查询。 + 约束说明: + 1. 只能访问当前 Agent 的活动数据源。 + 2. SEARCH 仅允许执行只读 SQL。 + 3. 如果只需要定位表,优先使用 LIST_TABLES 或 FIND_TABLES。 + 4. 如果需要写 SQL,先获取表结构和关系,再决定是否执行 SEARCH。 + 5. PREVIEW_ROWS 不是默认前置动作,只有样例值会实质影响 SQL 写法时才使用。 + 6. %s """.formatted(datasource.getName(), datasource.getType(), visibleTables); } @@ -167,18 +164,55 @@ public ToolDefinition getToolDefinition() { @Override public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { try { DatasourceExplorerRequest request = objectMapper.readValue(toolInput, DatasourceExplorerRequest.class); - return objectMapper.writeValueAsString(datasourceExplorerService.execute(agentId, request)); + validateRequest(request); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper + .writeValueAsString(datasourceExplorerService.execute(agentId, request, agentRequest)); } catch (Exception ex) { - throw new IllegalStateException("Datasource explorer tool failed: " + ex.getMessage(), ex); + throw new IllegalStateException( + objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, "数据源探索工具执行失败:" + ex.getMessage())), + ex); } } - @Override - public String call(String toolInput, ToolContext toolContext) { - return call(toolInput); + private void validateRequest(DatasourceExplorerRequest request) { + if (request == null || request.getAction() == null) { + throw new IllegalArgumentException( + objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, "数据源探索工具需要 action 参数"))); + } + switch (request.getAction()) { + case FIND_TABLES -> requireText(request.getQuery(), "FIND_TABLES 需要 query 参数"); + case GET_TABLE_SCHEMA, GET_RELATED_TABLES, PREVIEW_ROWS -> + requireText(request.getTableName(), request.getAction().name() + " 需要 tableName 参数"); + case SEARCH -> requireText(request.getSql(), "SEARCH 需要 sql 参数"); + case LIST_TABLES -> { + } + default -> throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.UNSUPPORTED_ACTION, "不支持的数据源探索动作:" + request.getAction()))); + } + } + + private void requireText(String value, String message) { + if (!StringUtils.isNotBlank(value)) { + throw new IllegalArgumentException(objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, message))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; + } } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java index 13ddc6291..7c2cc99bc 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolProvider.java @@ -29,10 +29,10 @@ public class DomainBusinessKnowledgeToolProvider implements AgentScopedToolProvi private static final String TOOL_NAME = "domain_business_knowledge.search"; private static final String DESCRIPTION = """ - Search the current agent's recalled business terms, FAQs, QA entries, and embedded documents on demand. - Use this tool when you need internal business definitions, metric rules, SOPs, historical cases, or terminology clarification. - Call it only when the answer depends on domain knowledge instead of general reasoning. - Do not use this tool for database table names, column names, field types, enum values, schema relations, field comments, or other table-structure interpretation questions. Those should go to datasource explorer first and semantic_model.search if supplemental semantic hints are needed. + 按需检索当前 Agent 已召回的业务术语、FAQ、问答条目和嵌入文档。 + 当回答依赖内部业务定义、指标口径、SOP、历史案例或领域术语澄清时,才使用本工具。 + 只有当答案确实依赖领域知识,而不是通用推理或数据库物理结构本身时,才调用本工具。 + 不要把本工具用于数据库表名、列名、字段类型、枚举值、表关系、字段注释或其他表结构解释问题;这些问题应先交给数据源探索工具,如仍需补充语义,再考虑 `semantic_model.search`。 """; private final DomainBusinessKnowledgeToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java index 02a773cd9..dd0dd2cf2 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/knowledge/DomainBusinessKnowledgeToolSupport.java @@ -15,6 +15,10 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.knowledge; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService; import com.alibaba.cloud.ai.dataagent.service.knowledge.DomainKnowledgeSearchService.DomainKnowledgeSearchRequest; import com.fasterxml.jackson.databind.JsonNode; @@ -40,9 +44,9 @@ public class DomainBusinessKnowledgeToolSupport { "type": "string", "description": "必填。需要检索的业务问题、指标名、术语、SOP 主题或案例主题。" }, - "knowledgeTypes": { - "type": "array", - "description": "可选。限定知识范围。支持 businessTerm、agentKnowledge、document、qa、faq、all。", + "knowledgeTypes": { + "type": "array", + "description": "可选。限定知识范围。支持 businessKnowledge、agentKnowledge、document、qa、faq、all。", "items": { "type": "string" } @@ -99,10 +103,16 @@ public ToolDefinition getToolDefinition() { @Override public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { try { JsonNode jsonNode = StringUtils.hasText(toolInput) ? objectMapper.readTree(toolInput) : objectMapper.createObjectNode(); String query = jsonNode.path("query").asText(""); + validateQuery(query); List knowledgeTypes = new ArrayList<>(); JsonNode knowledgeTypesNode = jsonNode.path("knowledgeTypes"); if (knowledgeTypesNode.isArray()) { @@ -120,16 +130,30 @@ public String call(String toolInput) { DomainKnowledgeSearchRequest request = new DomainKnowledgeSearchRequest(query, knowledgeTypes.isEmpty() ? null : List.copyOf(knowledgeTypes), topK, similarityThreshold); - return objectMapper.writeValueAsString(domainKnowledgeSearchService.search(agentId, request)); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper + .writeValueAsString(domainKnowledgeSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { - throw new IllegalStateException("Failed to search domain business knowledge: " + ex.getMessage(), ex); + throw new IllegalStateException(objectToJson(ToolError.of(ToolErrorCode.EXECUTION_FAILED, + "domain_business_knowledge.search 执行失败:" + ex.getMessage())), ex); } } - @Override - public String call(String toolInput, ToolContext toolContext) { - return call(toolInput); + private void validateQuery(String query) { + if (!StringUtils.hasText(query)) { + throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.INVALID_INPUT, "domain_business_knowledge.search 需要 query 参数"))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; + } } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchHit.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchHit.java index fa4abb850..9378ce70b 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchHit.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchHit.java @@ -40,4 +40,6 @@ public class SemanticModelSearchHit { private String matchedBy; + private Integer score; + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java index c1212d3ac..4624590b4 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchResult.java @@ -24,7 +24,7 @@ @Builder public class SemanticModelSearchResult { - private String query; + private String resolution; private String summary; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java index fb5a3eba6..34dcef71e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelSearchService.java @@ -15,8 +15,10 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.semantic; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.entity.SemanticModel; +import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.service.semantic.SemanticModelService; import java.util.ArrayList; @@ -27,6 +29,7 @@ import java.util.Objects; import java.util.Set; import lombok.RequiredArgsConstructor; +import org.springframework.lang.Nullable; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -41,44 +44,52 @@ public class SemanticModelSearchService { private final SemanticModelService semanticModelService; + private final AnswerTraceExplainStore answerTraceExplainStore; + public SemanticModelSearchResult search(String agentId, SemanticModelSearchRequest request) { + return search(agentId, request, null); + } + + public SemanticModelSearchResult search(String agentId, SemanticModelSearchRequest request, + @Nullable AgentRequest agentRequest) { if (!StringUtils.hasText(agentId)) { - return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search requires a numeric agent id."); + return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); } Long parsedAgentId; try { parsedAgentId = Long.valueOf(agentId); } catch (NumberFormatException ex) { - return emptyResult(request == null ? null : request.getQuery(), - "semantic_model.search requires a numeric agent id."); + return emptyResult(request == null ? null : request.getQuery(), "semantic_model.search 需要数值型 agentId 参数。"); } - return search(parsedAgentId, request); + return search(parsedAgentId, request, agentRequest); } public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest request) { + return search(agentId, request, null); + } + + public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest request, + @Nullable AgentRequest agentRequest) { String query = request == null ? null : request.getQuery(); if (!StringUtils.hasText(query)) { - throw new IllegalArgumentException("query is required for semantic_model.search"); + throw new IllegalArgumentException("semantic_model.search 需要 query 参数"); } AgentDatasource activeDatasource = resolveActiveDatasource(agentId); if (activeDatasource == null || activeDatasource.getDatasourceId() == null) { - return emptyResult(query, "No active datasource is available for semantic_model.search."); + return emptyResult(query, "当前没有可用于 semantic_model.search 的活动数据源。"); } TableSearchScope scope = resolveTableSearchScope(activeDatasource, request == null ? null : request.getTableNames()); if (scope.isScoped() && CollectionUtils.isEmpty(scope.getTableNames())) { - return emptyResult(query, - "Requested tables are outside the active datasource visibility scope for semantic_model.search."); + return emptyResult(query, "请求中指定的表超出了当前活动数据源对 semantic_model.search 的可见范围。"); } List candidates = scope.isUnbounded() ? semanticModelService.getEnabledByAgentIdAndDatasourceId(agentId, activeDatasource.getDatasourceId()) : semanticModelService.getEnabledByAgentIdAndDatasourceIdAndTableNames(agentId, activeDatasource.getDatasourceId(), scope.getTableNames()); if (CollectionUtils.isEmpty(candidates)) { - return emptyResult(query, - "No enabled semantic model entries matched this agent/table scope. Use datasource explorer for physical schema details."); + return emptyResult(query, "当前 Agent/表范围内没有匹配的已启用语义模型条目;物理表结构请改用数据源探索工具查看。"); } List scoredHits = candidates.stream() @@ -96,22 +107,23 @@ public SemanticModelSearchResult search(Long agentId, SemanticModelSearchRequest .toList(); if (scoredHits.isEmpty()) { - return emptyResult(query, - "No supplemental semantic hints matched the query. If datasource explorer already answers the schema question, do not call semantic_model.search."); + return emptyResult(query, "没有匹配到补充语义提示;如果数据源探索工具已能回答物理表结构问题,就不要额外调用 semantic_model.search。"); } List hits = scoredHits.stream().map(this::toHit).toList(); - return SemanticModelSearchResult.builder() - .query(query) - .summary( - "Found %d supplemental semantic hints. These are auxiliary explanations for table/column understanding, not a replacement for datasource exploration." - .formatted(hits.size())) - .hits(hits) - .build(); + String summary = "共匹配到 %d 条补充语义提示。这些结果只用于补充理解表和字段语义,不能替代数据源探索工具的物理结构探索。".formatted(hits.size()); + if (agentRequest != null) { + answerTraceExplainStore.recordSemanticSearch(agentRequest, query, "共匹配到 %d 条补充语义提示".formatted(hits.size()), + hits); + } + else { + answerTraceExplainStore.recordSemanticSearch(query, "共匹配到 %d 条补充语义提示".formatted(hits.size()), hits); + } + return SemanticModelSearchResult.builder().summary(summary).hits(hits).resolution("matched").build(); } private SemanticModelSearchResult emptyResult(String query, String summary) { - return SemanticModelSearchResult.builder().query(query).summary(summary).build(); + return SemanticModelSearchResult.builder().summary(summary).resolution("no_match").build(); } private AgentDatasource resolveActiveDatasource(Long agentId) { @@ -135,6 +147,7 @@ private SemanticModelSearchHit toHit(ScoredHit scoredHit) { .dataType(model.getDataType()) .relationHint(extractRelationHint(model)) .matchedBy(String.join(", ", scoredHit.getMatchedBy())) + .score(scoredHit.getScore()) .build(); } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java index 5beb1cad7..dbd85e1d0 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolProvider.java @@ -29,11 +29,11 @@ public class SemanticModelToolProvider implements AgentScopedToolProvider { private static final String TOOL_NAME = "semantic_model.search"; private static final String DESCRIPTION = """ - Supplemental semantic hint tool for table/column understanding only. - Use this tool when the user asks what a table/column means, asks for a business-friendly name, asks for enum meaning, asks for field usage notes, or asks for relation hints that are not explicitly stored in the physical schema. - Typical examples: "token名称类型", "status字段什么意思", "这个字段有哪些别名", "这两个表可能怎么关联". - Use datasource explorer first for physical schema, column lists, data types already in the database, previews, and readonly SQL. - Do not use this tool for SQL execution, schema discovery already covered by datasource explorer, or business definitions/metric rules/SOPs that belong to domain_business_knowledge.search. + 仅用于补充理解表和字段语义的辅助工具。 + 当用户在询问某张表或某个字段的含义、业务友好名称、枚举含义、字段使用备注,或数据库物理表结构中未显式存储的关系提示时,才使用本工具。 + 典型问题包括:“token 名称类型”“status 字段什么意思”“这个字段有哪些别名”“这两个表可能怎么关联”。 + 数据库里的物理表结构、字段列表、字段类型、样例预览和只读 SQL,应优先使用数据源探索工具获取。 + 不要把本工具用于 SQL 执行、数据源探索工具已能覆盖的表结构探索,或属于 `domain_business_knowledge.search` 的业务定义、指标口径和 SOP 检索。 """; private final SemanticModelToolSupport toolSupport; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java index 48b7feb95..f9fd09cef 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/semantic/SemanticModelToolSupport.java @@ -15,6 +15,10 @@ */ package com.alibaba.cloud.ai.dataagent.agentscope.tool.semantic; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import org.springframework.ai.chat.model.ToolContext; @@ -37,7 +41,7 @@ public class SemanticModelToolSupport { }, "tableNames": { "type": "array", - "description": "可选。将检索范围限制在这些表内;如果 datasource explorer 已能定位表结构,则不必传该工具。", + "description": "可选。将检索范围限制在这些表内;如果数据源探索工具已能定位表结构,则不必传该工具。", "items": { "type": "string" } @@ -85,20 +89,41 @@ public ToolDefinition getToolDefinition() { @Override public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { try { SemanticModelSearchRequest request = StringUtils.hasText(toolInput) ? objectMapper.readValue(toolInput, SemanticModelSearchRequest.class) : new SemanticModelSearchRequest(); - return objectMapper.writeValueAsString(semanticModelSearchService.search(agentId, request)); + validateRequest(request); + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + return objectMapper + .writeValueAsString(semanticModelSearchService.search(agentId, request, agentRequest)); } catch (Exception ex) { - throw new IllegalStateException("Failed to search semantic model hints: " + ex.getMessage(), ex); + throw new IllegalStateException(objectToJson( + ToolError.of(ToolErrorCode.EXECUTION_FAILED, "semantic_model.search 执行失败:" + ex.getMessage())), + ex); } } - @Override - public String call(String toolInput, ToolContext toolContext) { - return call(toolInput); + private void validateRequest(SemanticModelSearchRequest request) { + if (request == null || !StringUtils.hasText(request.getQuery())) { + throw new IllegalArgumentException( + objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, "semantic_model.search 需要 query 参数"))); + } + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; + } } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java index 7d70cdd15..63d7bea33 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/skilltool/BuiltinCurrentTimeSkillToolProvider.java @@ -101,10 +101,10 @@ public String call(String toolInput) { .toString(); } catch (IllegalArgumentException ex) { - throw new IllegalStateException("Invalid timezone or time format: " + ex.getMessage(), ex); + throw new IllegalStateException("时区或时间格式无效:" + ex.getMessage(), ex); } catch (Exception ex) { - throw new IllegalStateException("Failed to get current time: " + ex.getMessage(), ex); + throw new IllegalStateException("获取当前时间失败:" + ex.getMessage(), ex); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java new file mode 100644 index 000000000..1b508dd97 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckRequest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import java.util.List; +import lombok.Data; +import org.apache.commons.lang3.StringUtils; + +@Data +class SqlGuardCheckRequest { + + private String action; + + private String query; + + private String sql; + + private String humanFeedbackContent; + + private String tableName; + + private List columnNames; + + private Integer limit; + + String normalizedAction() { + return StringUtils.defaultIfBlank(action, "SQL_VERIFY").trim().toUpperCase(); + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java new file mode 100644 index 000000000..81e6e868f --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardCheckResult.java @@ -0,0 +1,54 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +@JsonInclude(JsonInclude.Include.NON_NULL) +class SqlGuardCheckResult { + + private String decision; + + private String tableName; + + private String summary; + + @JsonProperty("isAligned") + private Boolean isAligned; + + private Long totalRows; + + @Builder.Default + private List problems = new ArrayList<>(); + + @Builder.Default + private List fixSuggestions = new ArrayList<>(); + + @Builder.Default + private List ruleChecks = new ArrayList<>(); + + @Builder.Default + private List> columnProfiles = new ArrayList<>(); + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java new file mode 100644 index 000000000..afab679b2 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardProblem.java @@ -0,0 +1,43 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +class SqlGuardProblem { + + private String code; + + private String title; + + private String severity; + + private String message; + + private String why; + + private String expected; + + private String actual; + + private String evidence; + + private String repairHint; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java new file mode 100644 index 000000000..37f1b9ccc --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardRuleCheck.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +class SqlGuardRuleCheck { + + private String code; + + private String title; + + private String status; + + private String detail; + + private String evidence; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java new file mode 100644 index 000000000..f1c802eef --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlGuardToolProvider.java @@ -0,0 +1,195 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.runtime.ToolContextRequestResolver; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.AgentScopedToolProvider; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolError; +import com.alibaba.cloud.ai.dataagent.agentscope.tool.ToolErrorCode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Map; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Component +public class SqlGuardToolProvider implements AgentScopedToolProvider { + + private static final String TOOL_NAME = "sql_guard.check"; + + private static final String INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["SQL_VERIFY", "DATA_PROFILE"], + "description": "可选。默认 SQL_VERIFY。" + }, + "query": { + "type": "string", + "description": "SQL_VERIFY 时必填。用户原始问题。" + }, + "sql": { + "type": "string", + "description": "SQL_VERIFY 时必填。待校验 SQL。" + }, + "tableName": { + "type": "string", + "description": "DATA_PROFILE 时必填。目标表名。" + }, + "columnNames": { + "type": "array", + "items": { + "type": "string" + }, + "description": "DATA_PROFILE 时可选。优先只传少量关键字段。" + }, + "limit": { + "type": "integer", + "description": "DATA_PROFILE 时可选。返回上限,默认 5,最大 20。" + } + } + } + """; + + private static final String DESCRIPTION = """ + 统一 SQL 守卫工具,供所有基于 SQL 的回答使用。 + 1. `action=SQL_VERIFY`:在执行 SQL 或基于 SQL 生成最终回答前,检查候选 SQL 是否真正符合用户意图。 + 2. `action=DATA_PROFILE`:只有在完成 schema 检查后,仍有少量关键候选字段语义不明确,且这种不确定性会实质影响过滤、分组、排序、时间窗口或指标写法时,才用于补充查看字段值分布。 + 3. 不要把 DATA_PROFILE 当作每次查询的默认前置步骤;如果用户问题、schema 和列名已经足够明确,就直接跳过。 + 4. 使用 DATA_PROFILE 时,优先传少量关键 `columnNames`,不要对整张表做无差别 profile。 + 5. 如果 SQL_VERIFY 返回 `isAligned=false`,请读取 `problems`、`ruleChecks` 和 `fixSuggestions`,自行改写 SQL 后再次调用 `sql_guard.check`。 + 6. 如果使用 DATA_PROFILE,请重点读取返回的 `columnProfiles`,理解空值率、去重计数、高频值、样例值,以及字段更像枚举、数值还是时间字段。 + """; + + private final ObjectMapper objectMapper; + + private final SqlVerifyExplainService sqlVerifyExplainService; + + public SqlGuardToolProvider(ObjectMapper objectMapper, SqlVerifyExplainService sqlVerifyExplainService) { + this.objectMapper = objectMapper; + this.sqlVerifyExplainService = sqlVerifyExplainService; + } + + @Override + public Map getToolCallbacks(String agentId) { + ToolDefinition toolDefinition = ToolDefinition.builder() + .name(TOOL_NAME) + .description(DESCRIPTION) + .inputSchema(INPUT_SCHEMA) + .build(); + return Map.of(TOOL_NAME, + new SqlGuardToolCallback(agentId, toolDefinition, objectMapper, sqlVerifyExplainService)); + } + + private static final class SqlGuardToolCallback implements ToolCallback { + + private final String agentId; + + private final ToolDefinition toolDefinition; + + private final ObjectMapper objectMapper; + + private final SqlVerifyExplainService sqlVerifyExplainService; + + private SqlGuardToolCallback(String agentId, ToolDefinition toolDefinition, ObjectMapper objectMapper, + SqlVerifyExplainService sqlVerifyExplainService) { + this.agentId = agentId; + this.toolDefinition = toolDefinition; + this.objectMapper = objectMapper; + this.sqlVerifyExplainService = sqlVerifyExplainService; + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return execute(toolInput, null); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return execute(toolInput, toolContext); + } + + private String execute(String toolInput, ToolContext toolContext) { + try { + SqlGuardCheckRequest request = StringUtils.hasText(toolInput) + ? objectMapper.readValue(toolInput, SqlGuardCheckRequest.class) : new SqlGuardCheckRequest(); + enrichRequestFromToolContext(request, toolContext); + String action = request.normalizedAction(); + validateRequest(request, action); + SqlGuardCheckResult result = switch (action) { + case "DATA_PROFILE" -> sqlVerifyExplainService.inspectProfile(agentId, request); + case "SQL_VERIFY" -> sqlVerifyExplainService.explain(request); + default -> throw new IllegalArgumentException(objectToJson( + ToolError.of(ToolErrorCode.UNSUPPORTED_ACTION, "不支持的 sql_guard.check 动作:" + action))); + }; + return objectMapper.writeValueAsString(result); + } + catch (Exception ex) { + throw new IllegalStateException(objectToJson( + ToolError.of(ToolErrorCode.EXECUTION_FAILED, "sql_guard.check 执行失败:" + ex.getMessage())), ex); + } + } + + private void validateRequest(SqlGuardCheckRequest request, String action) { + if ("DATA_PROFILE".equals(action)) { + requireText(request.getTableName(), "DATA_PROFILE 需要 tableName 参数"); + return; + } + if ("SQL_VERIFY".equals(action)) { + requireText(request.getQuery(), "SQL_VERIFY 需要 query 参数"); + requireText(request.getSql(), "SQL_VERIFY 需要 sql 参数"); + } + } + + private void requireText(String value, String message) { + if (!StringUtils.hasText(value)) { + throw new IllegalArgumentException(objectToJson(ToolError.of(ToolErrorCode.INVALID_INPUT, message))); + } + } + + private void enrichRequestFromToolContext(SqlGuardCheckRequest request, ToolContext toolContext) { + if (request == null || StringUtils.hasText(request.getHumanFeedbackContent())) { + return; + } + AgentRequest agentRequest = ToolContextRequestResolver.resolveGraphRequest(toolContext); + if (agentRequest == null) { + return; + } + request.setHumanFeedbackContent(agentRequest.getHumanFeedbackContent()); + } + + private String objectToJson(Object value) { + try { + return objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + return "{\"code\":\"EXECUTION_FAILED\",\"message\":\"工具错误序列化失败\"}"; + } + } + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java new file mode 100644 index 000000000..11614b616 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/tool/sqlguard/SqlVerifyExplainService.java @@ -0,0 +1,1285 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.agentscope.tool.sqlguard; + +import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ColumnInfoBO; +import com.alibaba.cloud.ai.dataagent.bo.schema.ResultSetBO; +import com.alibaba.cloud.ai.dataagent.connector.DbQueryParameter; +import com.alibaba.cloud.ai.dataagent.connector.accessor.Accessor; +import com.alibaba.cloud.ai.dataagent.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; +import com.alibaba.cloud.ai.dataagent.entity.Datasource; +import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; +import com.alibaba.cloud.ai.dataagent.service.datasource.DatasourceService; +import com.alibaba.cloud.ai.dataagent.util.SqlUtil; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.util.TablesNamesFinder; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; + +@Service +public class SqlVerifyExplainService { + + private static final String ACTION_SQL_VERIFY = "SQL_VERIFY"; + + private static final String ACTION_DATA_PROFILE = "DATA_PROFILE"; + + private static final int DEFAULT_PROFILE_LIMIT = 5; + + private static final int MAX_PROFILE_LIMIT = 20; + + private static final int DEFAULT_PROFILE_COLUMN_COUNT = 3; + + private static final Pattern AGGREGATE_PATTERN = Pattern + .compile("(?i)\\b(count|sum|avg|average|min|max)\\s*\\(([^)]*)\\)\\s*(?:as\\s+([a-zA-Z0-9_]+))?"); + + private static final Pattern GROUP_BY_PATTERN = Pattern.compile("(?is)\\bgroup\\s+by\\b"); + + private static final Pattern ORDER_BY_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b"); + + private static final Pattern LIMIT_PATTERN = Pattern.compile("(?is)\\blimit\\s+\\d+\\b"); + + private static final Pattern TOP_PATTERN = Pattern.compile("(?is)\\bselect\\s+top\\s+\\d+\\b"); + + private static final Pattern FETCH_FIRST_PATTERN = Pattern + .compile("(?is)\\bfetch\\s+first\\s+\\d+\\s+rows\\s+only\\b"); + + private static final Pattern DISTINCT_PATTERN = Pattern + .compile("(?is)\\bselect\\s+distinct\\b|count\\s*\\(\\s*distinct\\b"); + + private static final Pattern DATE_LITERAL_PATTERN = Pattern + .compile("\\b\\d{4}[-/]\\d{1,2}[-/]\\d{1,2}\\b|\\b\\d{6,8}\\b"); + + private static final Pattern TIME_FUNCTION_PATTERN = Pattern.compile( + "(?is)\\b(current_date|current_timestamp|now\\s*\\(|curdate\\s*\\(|date\\s*\\(|date_trunc\\s*\\(|strftime\\s*\\(|to_date\\s*\\(|to_timestamp\\s*\\(|datediff\\s*\\(|dateadd\\s*\\(|timestampdiff\\s*\\(|interval\\b)"); + + private static final Pattern WHERE_PATTERN = Pattern.compile("(?is)\\bwhere\\b"); + + private static final Pattern DESC_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b.+?\\bdesc\\b"); + + private static final Pattern ASC_PATTERN = Pattern.compile("(?is)\\border\\s+by\\b.+?\\basc\\b"); + + private static final Pattern TOP_N_QUERY_PATTERN = Pattern + .compile("(?i)(?:\\btop\\s*(\\d+)\\b|前\\s*(\\d+)\\s*(?:个|名|条)?|排名前\\s*(\\d+)\\s*(?:个|名)?)"); + + private static final Pattern SQL_LIMIT_VALUE_PATTERN = Pattern.compile("(?is)\\blimit\\s+(\\d+)\\b"); + + private static final Pattern SQL_TOP_VALUE_PATTERN = Pattern.compile("(?is)\\bselect\\s+top\\s+(\\d+)\\b"); + + private static final Pattern SQL_FETCH_FIRST_VALUE_PATTERN = Pattern + .compile("(?is)\\bfetch\\s+first\\s+(\\d+)\\s+rows\\s+only\\b"); + + private static final Pattern STATUS_COLUMN_PATTERN = Pattern + .compile("(?is)\\b(status|order_status|payment_status|trade_status|state)\\b"); + + private static final Pattern NEGATIVE_STATUS_OPERATOR_PATTERN = Pattern + .compile("(?is)(<>|!=|not\\s+in\\s*\\(|not\\s+like\\b)"); + + private final AgentDatasourceService agentDatasourceService; + + private final DatasourceService datasourceService; + + private final AccessorFactory accessorFactory; + + public SqlVerifyExplainService(AgentDatasourceService agentDatasourceService, DatasourceService datasourceService, + AccessorFactory accessorFactory) { + this.agentDatasourceService = agentDatasourceService; + this.datasourceService = datasourceService; + this.accessorFactory = accessorFactory; + } + + public SqlGuardCheckResult explain(SqlGuardCheckRequest request) { + String query = StringUtils.trimToEmpty(request == null ? null : request.getQuery()); + String sql = StringUtils.trimToEmpty(request == null ? null : request.getSql()); + String humanFeedbackContent = StringUtils + .trimToEmpty(request == null ? null : request.getHumanFeedbackContent()); + if (StringUtils.isBlank(query)) { + throw new IllegalArgumentException("sql_guard.check 需要 query"); + } + if (StringUtils.isBlank(sql)) { + throw new IllegalArgumentException("sql_guard.check 需要 sql"); + } + + String effectiveIntentSource = mergeIntentSource(query, humanFeedbackContent); + QueryIntent intent = analyzeQueryIntent(effectiveIntentSource); + HumanFeedbackConstraint feedbackConstraint = analyzeHumanFeedbackConstraint(humanFeedbackContent); + + Statement statement; + try { + statement = parseSingleSelectStatement(sql); + } + catch (IllegalArgumentException ex) { + return SqlGuardCheckResult.builder() + .decision("revise_sql") + .isAligned(false) + .summary("SQL 无法通过语法解析,无法继续做结构和意图一致性校验。") + .problems(List.of(SqlGuardProblem.builder() + .code("SQL_PARSE_ERROR") + .title("SQL 语法解析失败") + .severity("high") + .message("SQL 无法解析,当前结果不能视为已校验通过。") + .why("校验器必须先把 SQL 解析成合法的 SELECT / WITH 语法树,才能继续检查聚合、分组、排序和时间窗口。") + .expected("输入应为单条可解析的只读 SELECT / WITH 查询。") + .actual("当前 SQL 在语法层面未通过解析。") + .evidence(ex.getMessage()) + .repairHint("先修复括号、关键字顺序、逗号、别名或多语句拼接问题,再重新调用 sql_guard.check。") + .build())) + .fixSuggestions(List.of("先修复 SQL 语法错误,再重新调用 sql_guard.check。")) + .ruleChecks(List.of(SqlGuardRuleCheck.builder() + .code("SQL_PARSE") + .title("SQL 语法解析") + .status("FAILED") + .detail("当前 SQL 未通过语法解析,后续结构规则无法继续执行。") + .evidence(ex.getMessage()) + .build())) + .build(); + } + + SqlShape shape = analyzeSqlShape(statement, sql, request); + List problems = new ArrayList<>(); + Set fixSuggestions = new LinkedHashSet<>(); + List ruleChecks = new ArrayList<>(); + + evaluateAggregationRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateGroupingRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeFilterRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeBucketRule(sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateTimeOrderRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateOrderingRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateLimitRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateDistinctRule(query, sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateOrderDirectionRule(sql, intent, shape, problems, fixSuggestions, ruleChecks); + evaluateHumanFeedbackRule(query, sql, humanFeedbackContent, feedbackConstraint, problems, fixSuggestions, + ruleChecks); + + boolean aligned = problems.stream().noneMatch(problem -> isBlockingSeverity(problem.getSeverity())); + String summary = aligned ? "SQL 通过了当前规则版意图一致性校验。" : "检测到 %d 个可能影响答案正确性的意图一致性问题。".formatted(problems.size()); + if (aligned) { + fixSuggestions.add("当前规则校验通过;如要进一步提高置信度,可继续核对执行结果与最终答案解释。"); + } + return SqlGuardCheckResult.builder() + .decision(aligned ? "safe_to_execute" : "revise_sql") + .isAligned(aligned) + .summary(summary) + .problems(problems) + .fixSuggestions(List.copyOf(fixSuggestions)) + .ruleChecks(ruleChecks) + .build(); + } + + public SqlGuardCheckResult inspectProfile(String agentId, SqlGuardCheckRequest request) { + String tableName = StringUtils.trimToEmpty(request == null ? null : request.getTableName()); + if (StringUtils.isBlank(tableName)) { + throw new IllegalArgumentException("sql_guard.check 在 action=DATA_PROFILE 时必须提供 tableName"); + } + ProfileContext context = resolveProfileContext(agentId); + String actualTableName = resolveVisibleTableName(context, tableName); + List availableColumns = loadTableColumns(context, actualTableName); + List visibleColumns = applyVisibleColumnRestrictions(context, actualTableName, availableColumns); + if (visibleColumns.isEmpty()) { + throw new IllegalArgumentException("表 '%s' 在当前 Agent 下没有可见字段".formatted(actualTableName)); + } + List columnsToInspect = resolveColumnsToInspect(request, actualTableName, visibleColumns); + int sampleLimit = normalizeProfileLimit(request == null ? null : request.getLimit()); + long totalRows = querySingleLong(context, + "SELECT COUNT(*) AS total_rows FROM " + quoteTable(context, actualTableName), "total_rows"); + List> columnProfiles = columnsToInspect.stream() + .map(column -> buildColumnProfile(context, actualTableName, column, totalRows, sampleLimit)) + .toList(); + String summary = "仅基于可见字段对表 '%s' 的 %d 个字段完成 profile 分析。".formatted(columnProfiles.size(), actualTableName); + return SqlGuardCheckResult.builder() + .decision("inspect_columns") + .tableName(actualTableName) + .summary(summary) + .totalRows(totalRows) + .columnProfiles(columnProfiles) + .fixSuggestions( + List.of("可优先把高频值集中的分类字段用作过滤条件或 GROUP BY 候选字段。", "可优先把具备 min/max 范围的数值或时间字段用作指标、趋势或时间窗口候选字段。")) + .build(); + } + + private ProfileContext resolveProfileContext(String agentId) { + if (!StringUtils.isNumeric(agentId)) { + throw new IllegalArgumentException("sql_guard.check 的 DATA_PROFILE 当前仅支持数值型 agentId"); + } + Long numericAgentId = Long.valueOf(agentId); + AgentDatasource agentDatasource = agentDatasourceService.getCurrentAgentDatasource(numericAgentId); + Datasource datasource = agentDatasource.getDatasource() != null ? agentDatasource.getDatasource() + : datasourceService.getDatasourceById(agentDatasource.getDatasourceId()); + if (datasource == null) { + throw new IllegalStateException("当前 Agent 未找到活动数据源:" + agentId); + } + DbConfigBO dbConfig = datasourceService.getDbConfig(datasource); + Accessor accessor = accessorFactory.getAccessorByDbConfig(dbConfig); + List explicitSelectedTables = Optional.ofNullable(agentDatasource.getSelectTables()).orElse(List.of()); + List visibleTables; + try { + visibleTables = explicitSelectedTables.isEmpty() ? datasourceService.getDatasourceTables(datasource.getId()) + : explicitSelectedTables; + } + catch (Exception ex) { + throw new IllegalStateException("加载数据源 %s 的可见表失败:%s".formatted(datasource.getId(), ex.getMessage()), ex); + } + Map> visibleTablesByName = indexTables(visibleTables, false); + Map> visibleTablesByLeafName = indexTables(visibleTables, true); + Map> visibleColumnsByTable = buildVisibleColumnsByTable(agentDatasource, + visibleTablesByName, visibleTablesByLeafName); + Map> visibleColumnNameSetByTable = new LinkedHashMap<>(); + visibleColumnsByTable.forEach((key, value) -> visibleColumnNameSetByTable.put(key, + value.stream() + .map(this::normalizeColumnName) + .collect(LinkedHashSet::new, LinkedHashSet::add, LinkedHashSet::addAll))); + return new ProfileContext(agentDatasource, datasource, dbConfig, accessor, List.copyOf(visibleTables), + Map.copyOf(visibleTablesByName), Map.copyOf(visibleTablesByLeafName), Map.copyOf(visibleColumnsByTable), + Map.copyOf(visibleColumnNameSetByTable), Set.copyOf(visibleColumnsByTable.keySet())); + } + + private List loadTableColumns(ProfileContext context, String tableName) { + try { + return Optional + .ofNullable(context.accessor() + .showColumns(context.dbConfig(), + DbQueryParameter.from(context.dbConfig()) + .setSchema(context.dbConfig().getSchema()) + .setTable(tableName))) + .orElse(List.of()); + } + catch (Exception ex) { + throw new IllegalStateException("加载表 '%s' 的字段失败:%s".formatted(tableName, ex.getMessage()), ex); + } + } + + private List applyVisibleColumnRestrictions(ProfileContext context, String tableName, + List columns) { + return Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(column -> isColumnVisible(context, tableName, column.getName())) + .toList(); + } + + private List resolveColumnsToInspect(SqlGuardCheckRequest request, String tableName, + List visibleColumns) { + Map columnsByName = new LinkedHashMap<>(); + for (ColumnInfoBO column : visibleColumns) { + columnsByName.put(normalizeColumnName(column.getName()), column); + } + List requestedColumns = Optional.ofNullable(request == null ? null : request.getColumnNames()) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .toList(); + if (requestedColumns.isEmpty()) { + return visibleColumns.stream().limit(DEFAULT_PROFILE_COLUMN_COUNT).toList(); + } + List resolvedColumns = new ArrayList<>(); + for (String requestedColumn : requestedColumns) { + ColumnInfoBO column = columnsByName.get(normalizeColumnName(requestedColumn)); + if (column == null) { + throw new IllegalArgumentException( + "字段 '%s' 在表 '%s' 中对当前 Agent 不可见".formatted(requestedColumn, tableName)); + } + resolvedColumns.add(column); + } + return resolvedColumns; + } + + private Map buildColumnProfile(ProfileContext context, String tableName, ColumnInfoBO column, + long totalRows, int sampleLimit) { + String quotedTable = quoteTable(context, tableName); + String quotedColumn = SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), column.getName()); + long nullCount = querySingleLong(context, + "SELECT COUNT(*) AS null_rows FROM %s WHERE %s IS NULL".formatted(quotedTable, quotedColumn), + "null_rows"); + Double nullRatio = totalRows <= 0 ? 0D : roundRatio((double) nullCount / (double) totalRows); + Long distinctCount = null; + if (supportsDistinctCount(column)) { + distinctCount = querySingleLong(context, + "SELECT COUNT(DISTINCT %s) AS distinct_count FROM %s".formatted(quotedColumn, quotedTable), + "distinct_count"); + } + List> topValues = supportsGroupedTopValues(column) + ? queryTopValues(context, quotedTable, quotedColumn, sampleLimit) : List.of(); + List sampleValues = querySampleValues(context, quotedTable, quotedColumn, sampleLimit, + supportsDistinctCount(column)); + String minValue = supportsMinMax(column) ? querySingleValue(context, + "SELECT MIN(%s) AS min_value FROM %s".formatted(quotedColumn, quotedTable), "min_value") : null; + String maxValue = supportsMinMax(column) ? querySingleValue(context, + "SELECT MAX(%s) AS max_value FROM %s".formatted(quotedColumn, quotedTable), "max_value") : null; + Map profile = new LinkedHashMap<>(); + profile.put("columnName", column.getName()); + profile.put("dataType", column.getType()); + profile.put("notNull", column.isNotnull()); + profile.put("nullCount", nullCount); + profile.put("nullRatio", nullRatio); + profile.put("distinctCount", distinctCount); + profile.put("sampleValues", sampleValues); + profile.put("topValues", topValues); + profile.put("min", minValue); + profile.put("max", maxValue); + profile.put("profileHints", buildProfileHints(column, nullRatio, distinctCount, totalRows, topValues)); + return profile; + } + + private List> queryTopValues(ProfileContext context, String quotedTable, String quotedColumn, + int sampleLimit) { + String sql = applyLimit(""" + SELECT %s AS profile_value, COUNT(*) AS profile_count + FROM %s + WHERE %s IS NOT NULL + GROUP BY %s + ORDER BY profile_count DESC + """.formatted(quotedColumn, quotedTable, quotedColumn, quotedColumn), + context.dbConfig().getDialectType(), sampleLimit); + ResultSetBO resultSet = executeSql(context, sql); + List> values = new ArrayList<>(); + for (Map row : Optional.ofNullable(resultSet.getData()).orElse(List.of())) { + Map entry = new LinkedHashMap<>(); + entry.put("value", row.get("profile_value")); + entry.put("count", parseLong(row.get("profile_count"))); + values.add(entry); + } + return values; + } + + private List querySampleValues(ProfileContext context, String quotedTable, String quotedColumn, + int sampleLimit, boolean distinctPreferred) { + String selectClause = distinctPreferred ? "SELECT DISTINCT %s AS sample_value".formatted(quotedColumn) + : "SELECT %s AS sample_value".formatted(quotedColumn); + String sql = applyLimit(""" + %s + FROM %s + WHERE %s IS NOT NULL + ORDER BY %s + """.formatted(selectClause, quotedTable, quotedColumn, quotedColumn), + context.dbConfig().getDialectType(), sampleLimit); + ResultSetBO resultSet = executeSql(context, sql); + return Optional.ofNullable(resultSet.getData()) + .orElse(List.of()) + .stream() + .map(row -> row.get("sample_value")) + .filter(StringUtils::isNotBlank) + .toList(); + } + + private List buildProfileHints(ColumnInfoBO column, Double nullRatio, Long distinctCount, long totalRows, + List> topValues) { + List hints = new ArrayList<>(); + if (Boolean.TRUE.equals(isLikelyCategorical(column, distinctCount, totalRows, topValues))) { + hints.add("该字段很可能是枚举或分类字段,适合用于过滤条件或 GROUP BY。"); + } + if (supportsMinMax(column)) { + hints.add("该字段很可能具备顺序语义,适合用于范围过滤、指标计算或趋势轴。"); + } + if (nullRatio != null && nullRatio >= 0.5D) { + hints.add("该字段空值比例较高,作为强过滤条件时需要谨慎。"); + } + if (hints.isEmpty()) { + hints.add("请先结合样例值和高频值判断,再决定是否将该字段写入 SQL。"); + } + return hints; + } + + private Boolean isLikelyCategorical(ColumnInfoBO column, Long distinctCount, long totalRows, + List> topValues) { + if (!supportsGroupedTopValues(column)) { + return false; + } + if (distinctCount != null && distinctCount > 0 && distinctCount <= 20) { + return true; + } + if (totalRows > 0 && distinctCount != null && distinctCount <= Math.max(10, totalRows / 10)) { + return true; + } + return !topValues.isEmpty() && topValues.size() <= 10; + } + + private long querySingleLong(ProfileContext context, String sql, String columnName) { + return parseLong(querySingleValue(context, sql, columnName)); + } + + private String querySingleValue(ProfileContext context, String sql, String columnName) { + ResultSetBO resultSet = executeSql(context, sql); + List> rows = Optional.ofNullable(resultSet.getData()).orElse(List.of()); + if (rows.isEmpty()) { + return null; + } + Map row = rows.get(0); + if (row.containsKey(columnName)) { + return row.get(columnName); + } + return row.values().stream().findFirst().orElse(null); + } + + private ResultSetBO executeSql(ProfileContext context, String sql) { + try { + ResultSetBO resultSet = context.accessor() + .executeSqlAndReturnObject(context.dbConfig(), + DbQueryParameter.from(context.dbConfig()) + .setSchema(context.dbConfig().getSchema()) + .setSql(sql)); + if (resultSet == null) { + return ResultSetBO.builder().column(List.of()).data(List.of()).build(); + } + if (StringUtils.isNotBlank(resultSet.getErrorMsg())) { + throw new IllegalStateException(resultSet.getErrorMsg()); + } + if (resultSet.getColumn() == null) { + resultSet.setColumn(List.of()); + } + if (resultSet.getData() == null) { + resultSet.setData(List.of()); + } + return resultSet; + } + catch (Exception ex) { + throw new IllegalStateException("执行 profile SQL 失败:" + ex.getMessage(), ex); + } + } + + private int normalizeProfileLimit(Integer requestedLimit) { + if (requestedLimit == null || requestedLimit <= 0) { + return DEFAULT_PROFILE_LIMIT; + } + return Math.min(requestedLimit, MAX_PROFILE_LIMIT); + } + + private boolean supportsDistinctCount(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return !containsAny(normalizedType, "blob", "clob", "text", "ntext", "image", "json", "xml", "bytea"); + } + + private boolean supportsGroupedTopValues(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return !containsAny(normalizedType, "blob", "clob", "ntext", "image", "bytea"); + } + + private boolean supportsMinMax(ColumnInfoBO column) { + String normalizedType = normalizeType(column); + return containsAny(normalizedType, "int", "number", "numeric", "decimal", "double", "float", "real", "date", + "time", "year", "timestamp"); + } + + private String normalizeType(ColumnInfoBO column) { + return StringUtils.defaultString(column == null ? null : column.getType()).toLowerCase(Locale.ROOT); + } + + private String applyLimit(String sql, String dialectType, int limit) { + String trimmed = StringUtils.trimToEmpty(sql); + if (trimmed.isEmpty()) { + return trimmed; + } + String normalizedDialect = StringUtils.defaultString(dialectType).toLowerCase(Locale.ROOT); + if (normalizedDialect.contains("sqlserver") || normalizedDialect.contains("sql_server")) { + if (trimmed.matches("(?is)^select\\s+distinct\\b.*")) { + return trimmed.replaceFirst("(?is)^select\\s+distinct\\b", "SELECT DISTINCT TOP %d".formatted(limit)); + } + return trimmed.replaceFirst("(?is)^select\\b", "SELECT TOP %d".formatted(limit)); + } + if (normalizedDialect.contains("oracle")) { + return trimmed + " FETCH FIRST " + limit + " ROWS ONLY"; + } + return trimmed + " LIMIT " + limit; + } + + private String quoteTable(ProfileContext context, String tableName) { + return SqlUtil.quoteIdentifier(context.dbConfig().getDialectType(), tableName); + } + + private double roundRatio(double value) { + return Math.round(value * 10000D) / 10000D; + } + + private long parseLong(String value) { + if (StringUtils.isBlank(value)) { + return 0L; + } + try { + return Long.parseLong(value.trim()); + } + catch (NumberFormatException ex) { + try { + return Math.round(Double.parseDouble(value.trim())); + } + catch (NumberFormatException ignored) { + return 0L; + } + } + } + + private String resolveVisibleTableName(ProfileContext context, String tableName) { + return findVisibleTableName(context.visibleTablesByName(), context.visibleTablesByLeafName(), tableName, false) + .orElseThrow(() -> new IllegalArgumentException( + "表 '%s' 对当前 Agent 不可见。当前可见表:%s".formatted(tableName, String.join(", ", context.visibleTables())))); + } + + private Optional findVisibleTableName(Map> visibleTablesByName, + Map> visibleTablesByLeafName, String tableName, boolean allowQualifiedFallback) { + String normalizedTableName = normalizeIdentifier(tableName); + List exactMatches = visibleTablesByName.getOrDefault(normalizedTableName, List.of()); + if (exactMatches.size() == 1) { + return Optional.of(exactMatches.get(0)); + } + if (exactMatches.size() > 1) { + throw new IllegalArgumentException( + "表 '%s' 映射到了多张当前可见表:%s".formatted(tableName, String.join(", ", exactMatches))); + } + if (isQualifiedIdentifier(tableName) && !allowQualifiedFallback) { + return Optional.empty(); + } + List leafMatches = visibleTablesByLeafName.getOrDefault(normalizeTableLeafName(tableName), List.of()); + if (leafMatches.size() == 1) { + return Optional.of(leafMatches.get(0)); + } + if (leafMatches.size() > 1) { + throw new IllegalArgumentException( + "表 '%s' 在当前可见表范围内存在歧义:%s".formatted(tableName, String.join(", ", leafMatches))); + } + return Optional.empty(); + } + + private Map> buildVisibleColumnsByTable(AgentDatasource agentDatasource, + Map> visibleTablesByName, Map> visibleTablesByLeafName) { + Map> selectedColumns = Optional.ofNullable(agentDatasource.getSelectColumns()) + .orElse(Map.of()); + Map> visibleColumnsByTable = new LinkedHashMap<>(); + selectedColumns.forEach((tableName, columns) -> { + Optional resolvedTableName = findVisibleTableName(visibleTablesByName, visibleTablesByLeafName, + tableName, true); + if (resolvedTableName.isEmpty()) { + return; + } + List sanitizedColumns = Optional.ofNullable(columns) + .orElse(List.of()) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + if (!sanitizedColumns.isEmpty()) { + visibleColumnsByTable.put(normalizeTableName(resolvedTableName.get()), sanitizedColumns); + } + }); + return visibleColumnsByTable; + } + + private Map> indexTables(List tableNames, boolean leafOnly) { + Map> index = new LinkedHashMap<>(); + for (String tableName : Optional.ofNullable(tableNames).orElse(List.of())) { + if (StringUtils.isBlank(tableName)) { + continue; + } + String key = leafOnly ? normalizeTableLeafName(tableName) : normalizeTableName(tableName); + index.computeIfAbsent(key, ignored -> new ArrayList<>()).add(tableName); + } + return index; + } + + private boolean isColumnVisible(ProfileContext context, String tableName, String columnName) { + String normalizedTableName = normalizeTableName(tableName); + if (!context.columnRestrictedTables().contains(normalizedTableName)) { + return true; + } + Set visibleColumns = context.visibleColumnNameSetByTable().get(normalizedTableName); + return visibleColumns != null && visibleColumns.contains(normalizeColumnName(columnName)); + } + + private boolean isQualifiedIdentifier(String value) { + return normalizeIdentifier(value).contains("."); + } + + private String normalizeIdentifier(String value) { + String normalized = StringUtils.trimToEmpty(value); + normalized = StringUtils.removeStart(normalized, "`"); + normalized = StringUtils.removeEnd(normalized, "`"); + normalized = StringUtils.removeStart(normalized, "\""); + normalized = StringUtils.removeEnd(normalized, "\""); + normalized = StringUtils.removeStart(normalized, "["); + normalized = StringUtils.removeEnd(normalized, "]"); + return normalized.toLowerCase(Locale.ROOT); + } + + private String normalizeTableName(String tableName) { + return normalizeIdentifier(tableName); + } + + private String normalizeTableLeafName(String tableName) { + String normalized = normalizeIdentifier(tableName); + int lastDot = normalized.lastIndexOf('.'); + return lastDot >= 0 ? normalized.substring(lastDot + 1) : normalized; + } + + private String normalizeColumnName(String columnName) { + return normalizeTableLeafName(columnName); + } + + private void evaluateAggregationRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresAggregation()) { + return; + } + if (!shape.hasAggregation()) { + addProblem(problems, fixSuggestions, "MISSING_AGGREGATION", "缺少聚合指标", "high", "问题看起来要求聚合指标,但 SQL 更像明细查询。", + "用户问题带有数量、金额、总数、平均值等聚合口径,但 SQL 没有检测到 count/sum/avg/min/max 等聚合函数。", "SELECT 中应包含与题目口径匹配的聚合表达式。", + "当前 SQL 未检测到聚合函数。", "query=" + query + "; sql=" + sql, + "把 count/sum/avg/min/max 等聚合逻辑补齐到 SELECT 中。"); + recordRuleCheck(ruleChecks, "AGGREGATION_REQUIRED", "聚合指标校验", "FAILED", "问题要求聚合指标,但 SQL 未检测到聚合函数。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "AGGREGATION_REQUIRED", "聚合指标校验", "PASSED", "问题要求聚合指标,SQL 已检测到聚合函数。", + "usedMetrics=" + shape.usedMetrics()); + } + + private void evaluateGroupingRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresGrouping()) { + return; + } + if (!shape.hasGroupBy()) { + addProblem(problems, fixSuggestions, "MISSING_GROUP_BY", "缺少 GROUP BY", "high", + "问题要求按维度拆分,但 SQL 缺少 GROUP BY。", "用户问题包含按地区、按用户、各品类、每月等拆分意图;没有 GROUP BY 时,要么结果被压成总计,要么数据库直接报错。", + "SQL 应按题目中的维度列做 GROUP BY。", "当前 SQL 未检测到 GROUP BY。", "query=" + query + "; sql=" + sql, + "把用户要求的维度列加入 GROUP BY,并检查 SELECT 中的非聚合列。"); + recordRuleCheck(ruleChecks, "GROUP_BY_REQUIRED", "分组维度校验", "FAILED", "问题要求分维度拆分,但 SQL 未检测到 GROUP BY。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "GROUP_BY_REQUIRED", "分组维度校验", "PASSED", "问题要求分维度拆分,SQL 已检测到 GROUP BY。", + "sql=" + sql); + } + + private void evaluateTimeFilterRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTimeFilter()) { + return; + } + if (!shape.hasTimePredicate()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_FILTER", "缺少时间过滤", "high", + "问题包含明确时间窗口,但 SQL 没有可靠的时间过滤信号。", "题目提到了今天、本月、最近30天、某年某月等时间范围,但 SQL 没有看到明确时间过滤,结果很可能回成全量数据。", + "WHERE 中应包含与题目对应的时间范围约束。", "当前 SQL 未检测到可靠的时间过滤表达式。", "query=" + query + "; sql=" + sql, + "在 WHERE 中补齐精确时间范围,不要按默认全量数据查询。"); + recordRuleCheck(ruleChecks, "TIME_FILTER_REQUIRED", "时间窗口校验", "FAILED", "问题包含明确时间窗口,但 SQL 未检测到可靠时间过滤。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_FILTER_REQUIRED", "时间窗口校验", "PASSED", "问题包含明确时间窗口,SQL 已检测到时间过滤。", + "sql=" + sql); + } + + private void evaluateTimeBucketRule(String sql, QueryIntent intent, SqlShape shape, List problems, + Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTrend()) { + return; + } + if (!shape.hasTimeBucket()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_BUCKET", "缺少时间分桶", "high", + "趋势类问题通常需要时间分桶,但 SQL 没看到明确的时间粒度表达。", "趋势分析需要先按天、周、月、年等粒度汇总;没有时间分桶,返回的往往只是总数,不是趋势。", + "SQL 应包含 DATE/DATE_TRUNC/DATE_FORMAT 等时间分桶表达式,并按同粒度分组。", "当前 SQL 未检测到明确的时间分桶表达式。", "sql=" + sql, + "用 DATE/DATE_TRUNC/DATE_FORMAT 等时间分桶表达式,并对同一时间粒度做 GROUP BY。"); + recordRuleCheck(ruleChecks, "TIME_BUCKET_REQUIRED", "趋势时间粒度校验", "FAILED", "趋势问题需要时间分桶,但 SQL 未检测到时间粒度表达。", + "sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_BUCKET_REQUIRED", "趋势时间粒度校验", "PASSED", "趋势问题所需的时间分桶表达已检测到。", "sql=" + sql); + } + + private void evaluateTimeOrderRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresTrend()) { + return; + } + if (!shape.hasOrderBy()) { + addProblem(problems, fixSuggestions, "MISSING_TIME_ORDER", "缺少时间排序", "medium", + "趋势类问题通常需要按时间排序,但 SQL 缺少 ORDER BY。", "趋势结果如果不按时间排序,输出顺序可能是乱的,后续回答和可视化都容易误导。", + "趋势 SQL 应按时间字段或时间分桶字段排序。", "当前 SQL 未检测到 ORDER BY。", "query=" + query + "; sql=" + sql, + "按时间字段或时间分桶字段补齐 ORDER BY。"); + recordRuleCheck(ruleChecks, "TIME_ORDER_REQUIRED", "趋势时间排序校验", "FAILED", "趋势问题需要时间排序,但 SQL 未检测到 ORDER BY。", + "sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "TIME_ORDER_REQUIRED", "趋势时间排序校验", "PASSED", "趋势问题所需的时间排序已检测到。", "sql=" + sql); + } + + private void evaluateOrderingRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresOrdering()) { + return; + } + if (!shape.hasOrderBy()) { + addProblem(problems, fixSuggestions, "MISSING_ORDER_BY", "缺少排序", "medium", "问题要求排序或排名,但 SQL 缺少 ORDER BY。", + "题目要求最高、最低、TopN、排名等比较关系;没有 ORDER BY 时,即使有限制行数,返回的也不一定是目标对象。", "SQL 应明确按目标指标排序。", + "当前 SQL 未检测到 ORDER BY。", "query=" + query + "; sql=" + sql, "根据问题要求补齐 ORDER BY,并明确升序还是降序。"); + recordRuleCheck(ruleChecks, "ORDER_REQUIRED", "排序要求校验", "FAILED", "问题包含排序或排名诉求,但 SQL 未检测到 ORDER BY。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "ORDER_REQUIRED", "排序要求校验", "PASSED", "问题包含排序或排名诉求,SQL 已检测到 ORDER BY。", + "sql=" + sql); + } + + private void evaluateLimitRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresLimit()) { + return; + } + if (!shape.hasLimit()) { + addProblem(problems, fixSuggestions, "MISSING_LIMIT", "缺少返回行数限制", "medium", + "问题要求 TopN / 前N / 单个最值对象,但 SQL 没有限制返回行数。", "题目明确只需要前几名或唯一最值对象;如果不限制行数,结果会混入多余记录。", + "SQL 应通过 LIMIT / TOP / FETCH FIRST 控制返回行数。", "当前 SQL 未检测到返回行数限制。", + "query=" + query + "; sql=" + sql, "补齐 LIMIT / TOP / FETCH FIRST,避免把全量结果当成 TopN。"); + recordRuleCheck(ruleChecks, "LIMIT_REQUIRED", "TopN 行数限制校验", "FAILED", + "问题要求限制返回行数,但 SQL 未检测到 LIMIT/TOP/FETCH FIRST。", "query=" + query + "; sql=" + sql); + return; + } + if (shape.limitValueKnown() && intent.expectedLimit() != null + && !intent.expectedLimit().equals(shape.limitValue())) { + String mismatchCode = shape.limitValue() > intent.expectedLimit() ? "LIMIT_TOO_LARGE" : "LIMIT_TOO_SMALL"; + addProblem(problems, fixSuggestions, mismatchCode, "TopN 数量不匹配", "medium", "问题要求的返回条数与 SQL 实际限制条数不一致。", + "题目明确给了 TopN 数量或单个最值对象,限制条数不一致会直接改变结果口径。", "返回条数应与题目要求一致。", + "题目期望 " + intent.expectedLimit() + " 条,但 SQL 当前限制为 " + shape.limitValue() + " 条。", + "query=" + query + "; sql=" + sql, "把 LIMIT/TOP/FETCH FIRST 改成与题目一致的条数。"); + recordRuleCheck(ruleChecks, "LIMIT_MATCH", "TopN 数量匹配校验", "FAILED", + "题目期望 " + intent.expectedLimit() + " 条,但 SQL 实际限制为 " + shape.limitValue() + " 条。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck( + ruleChecks, "LIMIT_MATCH", "TopN 数量匹配校验", "PASSED", intent.expectedLimit() == null + ? "题目要求限制返回行数,SQL 已检测到限制。" : "题目期望 " + intent.expectedLimit() + " 条,SQL 返回条数限制一致。", + "sql=" + sql); + } + + private void evaluateDistinctRule(String query, String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!intent.requiresDistinct()) { + return; + } + if (!shape.hasDistinct()) { + addProblem(problems, fixSuggestions, "MISSING_DISTINCT", "缺少 DISTINCT 去重", "high", + "问题要求去重口径,但 SQL 没看到 DISTINCT。", "题目要求独立用户、去重人数、唯一值等口径;不去重会重复计算。", + "SQL 应使用 SELECT DISTINCT 或 COUNT(DISTINCT ...)。", "当前 SQL 未检测到 DISTINCT。", + "query=" + query + "; sql=" + sql, "把口径改成 SELECT DISTINCT 或 COUNT(DISTINCT ...)。"); + recordRuleCheck(ruleChecks, "DISTINCT_REQUIRED", "去重口径校验", "FAILED", "问题要求去重口径,但 SQL 未检测到 DISTINCT。", + "query=" + query + "; sql=" + sql); + return; + } + recordRuleCheck(ruleChecks, "DISTINCT_REQUIRED", "去重口径校验", "PASSED", "问题要求去重口径,SQL 已检测到 DISTINCT。", + "sql=" + sql); + } + + private void evaluateOrderDirectionRule(String sql, QueryIntent intent, SqlShape shape, + List problems, Set fixSuggestions, List ruleChecks) { + if (!shape.hasOrderBy()) { + return; + } + if (intent.prefersDescending() && shape.orderDirectionKnown() && !shape.orderDescending()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_MISMATCH", "排序方向不匹配", "high", + "问题要求最高 / Top / 最多,但 SQL 排序方向不像降序。", "题目要的是最大值或靠前排名,若排序方向写成 ASC,返回的会是最小值或反向结果。", + "这类问题通常应按目标指标 DESC 排序。", "当前 SQL 的排序方向与题目诉求不一致。", "sql=" + sql, "把排序方向改成 DESC,并确认排序指标是否正确。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION", "排序方向校验", "FAILED", "题目要求高到低/最多/Top,但 SQL 当前排序方向不是 DESC。", + "sql=" + sql); + return; + } + if (intent.prefersAscending() && shape.orderDirectionKnown() && shape.orderDescending()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_MISMATCH", "排序方向不匹配", "high", + "问题要求最低 / 最少 / 最小,但 SQL 排序方向不像升序。", "题目要的是最小值或最低排名,若排序方向写成 DESC,返回的会是最大值或反向结果。", + "这类问题通常应按目标指标 ASC 排序。", "当前 SQL 的排序方向与题目诉求不一致。", "sql=" + sql, "把排序方向改成 ASC,并确认排序指标是否正确。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION", "排序方向校验", "FAILED", "题目要求低到高/最少/最小,但 SQL 当前排序方向不是 ASC。", + "sql=" + sql); + return; + } + if ((intent.prefersDescending() || intent.prefersAscending()) && !shape.orderDirectionKnown()) { + addProblem(problems, fixSuggestions, "ORDER_DIRECTION_AMBIGUOUS", "排序方向不明确", "medium", + "问题对排序方向有明确诉求,但 SQL 的 ORDER BY 没有写明 ASC / DESC。", "有些数据库默认升序,但不应该依赖默认行为承载业务口径,否则最值问题很容易答反。", + "ORDER BY 应显式写明 ASC 或 DESC。", "当前 SQL 虽然有 ORDER BY,但没有检测到明确排序方向。", "sql=" + sql, + "显式补上 ASC 或 DESC,不要依赖数据库默认排序方向。"); + recordRuleCheck(ruleChecks, "ORDER_DIRECTION_EXPLICIT", "显式排序方向校验", "FAILED", + "题目存在明确最值方向,但 SQL 的 ORDER BY 未显式写出 ASC/DESC。", "sql=" + sql); + return; + } + if (intent.prefersDescending() || intent.prefersAscending()) { + recordRuleCheck(ruleChecks, "ORDER_DIRECTION_EXPLICIT", "显式排序方向校验", "PASSED", "ORDER BY 已显式声明排序方向。", + "sql=" + sql); + } + } + + private void evaluateHumanFeedbackRule(String query, String sql, String humanFeedbackContent, + HumanFeedbackConstraint feedbackConstraint, List problems, Set fixSuggestions, + List ruleChecks) { + if (!feedbackConstraint.hasConstraints()) { + return; + } + String normalizedSql = StringUtils.trimToEmpty(sql).toLowerCase(Locale.ROOT); + List feedbackProblems = new ArrayList<>(); + + if (!feedbackConstraint.requiredStatusTokens().isEmpty()) { + boolean matchedRequiredStatus = feedbackConstraint.requiredStatusTokens() + .stream() + .anyMatch(token -> sqlContainsStatusToken(normalizedSql, token)); + if (!matchedRequiredStatus) { + feedbackProblems.add("未看到与人工反馈一致的状态过滤条件"); + addProblem(problems, fixSuggestions, "MISSING_CONFIRMED_STATUS_FILTER", "缺少人工反馈确认的状态过滤", "high", + "用户已经通过人工反馈明确了状态口径,但 SQL 里没有体现该约束。", "人工反馈属于已确认条件;如果 SQL 没落实这些条件,最终结果会与用户确认的口径不一致。", + "SQL 应显式体现用户确认过的状态范围或订单口径。", "当前 SQL 未检测到与人工反馈一致的状态条件。", + "query=" + query + "; feedback=" + humanFeedbackContent + "; sql=" + sql, + "把人工反馈里确认过的状态过滤条件补进 WHERE,例如只统计 completed / paid 等已确认状态。"); + } + } + + if (!feedbackConstraint.excludedStatusTokens().isEmpty()) { + boolean mentionsExcludedStatus = feedbackConstraint.excludedStatusTokens() + .stream() + .anyMatch(token -> sqlContainsStatusToken(normalizedSql, token)); + boolean hasStatusPredicate = STATUS_COLUMN_PATTERN.matcher(normalizedSql).find() + || feedbackConstraint.requiredStatusTokens() + .stream() + .anyMatch(token -> sqlContainsStatusToken(normalizedSql, token)); + boolean hasNegativeStatusPredicate = NEGATIVE_STATUS_OPERATOR_PATTERN.matcher(normalizedSql).find(); + if (!hasStatusPredicate && !mentionsExcludedStatus) { + feedbackProblems.add("未看到用于落实人工反馈排除条件的状态过滤"); + addProblem(problems, fixSuggestions, "MISSING_CONFIRMED_STATUS_EXCLUSION", "缺少人工反馈确认的排除条件", "high", + "用户已经通过人工反馈确认要排除某些状态,但 SQL 里没有看到对应的过滤条件。", "像“不含退款”“排除取消单”这类反馈会直接改变统计口径;如果 SQL 不落实,结果会偏大或口径错误。", + "SQL 应显式体现这些排除条件,或通过更窄的已确认状态集合覆盖它们。", "当前 SQL 未检测到相关状态过滤。", + "query=" + query + "; feedback=" + humanFeedbackContent + "; sql=" + sql, + "把人工反馈里确认的排除条件补进 WHERE,例如排除 refund / cancelled 等状态。"); + } + else if (mentionsExcludedStatus && !hasNegativeStatusPredicate + && feedbackConstraint.requiredStatusTokens().isEmpty()) { + feedbackProblems.add("SQL 提到了应排除的状态,但没有看到明确排除写法"); + addProblem(problems, fixSuggestions, "CONFIRMED_STATUS_EXCLUSION_MISMATCH", "人工反馈排除条件未落实", "high", + "人工反馈要求排除某些状态,但 SQL 里虽然出现了这些状态词,却没有看到明确的排除写法。", + "如果只是把 refund / cancelled 放进正向条件里,结果会和用户确认的口径相反。", + "这些状态应通过 <> / != / NOT IN 等方式排除,或通过更窄的正向状态集间接排除。", "当前 SQL 提到了应排除的状态,但没有检测到明确排除条件。", + "query=" + query + "; feedback=" + humanFeedbackContent + "; sql=" + sql, + "把这些状态改成显式排除条件,或改成更精确的正向状态集合。"); + } + } + + if (feedbackProblems.isEmpty()) { + recordRuleCheck(ruleChecks, "CONFIRMED_FEEDBACK_CONSTRAINTS", "人工反馈一致性校验", "PASSED", + "SQL 已体现当前人工反馈中的显式状态口径约束。", "feedback=" + humanFeedbackContent); + return; + } + recordRuleCheck(ruleChecks, "CONFIRMED_FEEDBACK_CONSTRAINTS", "人工反馈一致性校验", "FAILED", + String.join(";", feedbackProblems), "feedback=" + humanFeedbackContent + "; sql=" + sql); + } + + private void addProblem(List problems, Set fixSuggestions, String code, String title, + String severity, String message, String why, String expected, String actual, String evidence, + String fixSuggestion) { + problems.add(SqlGuardProblem.builder() + .code(code) + .title(title) + .severity(severity) + .message(message) + .why(why) + .expected(expected) + .actual(actual) + .evidence(evidence) + .repairHint(fixSuggestion) + .build()); + fixSuggestions.add(fixSuggestion); + } + + private void recordRuleCheck(List ruleChecks, String code, String title, String status, + String detail, String evidence) { + ruleChecks.add(SqlGuardRuleCheck.builder() + .code(code) + .title(title) + .status(status) + .detail(detail) + .evidence(evidence) + .build()); + } + + private boolean isBlockingSeverity(String severity) { + return "high".equalsIgnoreCase(severity) || "medium".equalsIgnoreCase(severity); + } + + private String mergeIntentSource(String query, String humanFeedbackContent) { + if (StringUtils.isBlank(humanFeedbackContent)) { + return query; + } + return query + "\n" + humanFeedbackContent; + } + + private HumanFeedbackConstraint analyzeHumanFeedbackConstraint(String humanFeedbackContent) { + if (StringUtils.isBlank(humanFeedbackContent)) { + return HumanFeedbackConstraint.empty(); + } + String feedback = humanFeedbackContent.trim(); + String normalizedFeedback = feedback.toLowerCase(Locale.ROOT); + Set requiredStatusTokens = new LinkedHashSet<>(); + Set excludedStatusTokens = new LinkedHashSet<>(); + + collectRequiredStatusTokens(feedback, normalizedFeedback, requiredStatusTokens); + collectExcludedStatusTokens(feedback, normalizedFeedback, excludedStatusTokens); + return new HumanFeedbackConstraint(feedback, Set.copyOf(requiredStatusTokens), + Set.copyOf(excludedStatusTokens)); + } + + private void collectRequiredStatusTokens(String feedback, String normalizedFeedback, Set target) { + if (containsAny(feedback, "已完成", "完成订单") || containsAny(normalizedFeedback, "completed", "complete")) { + target.add("completed"); + } + if (containsAny(feedback, "已支付", "支付成功") || containsAny(normalizedFeedback, "paid", "payment_success")) { + target.add("paid"); + } + if (containsAny(feedback, "待支付", "未支付", "待处理") || containsAny(normalizedFeedback, "pending", "unpaid")) { + target.add("pending"); + } + if (containsAny(feedback, "已取消", "取消单") || containsAny(normalizedFeedback, "cancelled", "canceled")) { + target.add("cancelled"); + } + if (containsAny(feedback, "退款", "已退款") || containsAny(normalizedFeedback, "refund", "refunded")) { + if (!containsAny(feedback, "不含退款", "不包含退款", "排除退款", "剔除退款") + && !containsAny(normalizedFeedback, "exclude refund", "without refund", "not refunded")) { + target.add("refund"); + } + } + } + + private void collectExcludedStatusTokens(String feedback, String normalizedFeedback, Set target) { + if (containsAny(feedback, "不含退款", "不包含退款", "排除退款", "剔除退款", "不看退款") + || containsAny(normalizedFeedback, "exclude refund", "without refund", "not refunded")) { + target.add("refund"); + } + if (containsAny(feedback, "不含取消", "不包含取消", "排除取消", "剔除取消", "不看取消") || containsAny(normalizedFeedback, + "exclude cancel", "without cancel", "not cancelled", "not canceled")) { + target.add("cancelled"); + } + } + + private boolean sqlContainsStatusToken(String normalizedSql, String token) { + return switch (token) { + case "completed" -> containsAny(normalizedSql, "'completed'", "\"completed\"", " completed ", "'complete'", + "已完成", "'success'", "\"success\""); + case "paid" -> containsAny(normalizedSql, "'paid'", "\"paid\"", " paid ", "已支付", "payment_success"); + case "pending" -> + containsAny(normalizedSql, "'pending'", "\"pending\"", " pending ", "待支付", "未支付", "'unpaid'"); + case "cancelled" -> containsAny(normalizedSql, "'cancelled'", "\"cancelled\"", "'canceled'", "\"canceled\"", + " cancelled ", " canceled ", "已取消", "取消"); + case "refund" -> containsAny(normalizedSql, "'refund'", "\"refund\"", "'refunded'", "\"refunded\"", + " refund ", " refunded ", "退款"); + default -> containsAny(normalizedSql, token); + }; + } + + private QueryIntent analyzeQueryIntent(String query) { + String normalized = query.toLowerCase(Locale.ROOT); + boolean requiresTrend = containsAny(query, "趋势", "走势图", "按天", "按周", "按月", "按年", "daily", "weekly", "monthly", + "trend", "over time", "环比", "同比"); + boolean requiresGrouping = requiresTrend + || containsAny(query, "按", "按照", "每个", "各", "分组", "group by", "维度", "分类", "分城市", "分地区", "分品类", "分渠道"); + boolean requiresTimeFilter = containsAny(query, "今天", "昨日", "昨天", "本周", "上周", "本月", "上月", "本季度", "上季度", "今年", + "去年", "近", "最近", "最近的", "latest", "recent", "last ", "past ", "today", "yesterday", "this month", + "this year", "202", "2024", "2025", "2026", "Q1", "Q2", "Q3", "Q4"); + boolean requiresOrdering = requiresTrend || containsAny(normalized, "top ", "rank", "ranking", "highest", + "lowest", "best", "worst", "most", "least") + || containsAny(query, "排名", "排行", "最高", "最低", "最多", "最少", "前", "后"); + boolean prefersDescending = containsAny(normalized, "top ", "highest", "best", "most", "largest") + || containsAny(query, "最高", "最多", "最大", "前"); + boolean prefersAscending = containsAny(normalized, "lowest", "least", "smallest", "worst") + || containsAny(query, "最低", "最少", "最小"); + boolean explicitDescendingDirection = containsAny(normalized, " desc", "descending") + || containsAny(query, "降序", "从高到低"); + boolean explicitAscendingDirection = containsAny(normalized, " asc", "ascending") + || containsAny(query, "升序", "从低到高"); + requiresOrdering = requiresOrdering || explicitDescendingDirection || explicitAscendingDirection; + prefersDescending = prefersDescending || explicitDescendingDirection; + prefersAscending = prefersAscending || explicitAscendingDirection; + Integer expectedLimit = extractExpectedLimit(query, prefersDescending, prefersAscending); + boolean requiresLimit = expectedLimit != null; + boolean requiresDistinct = containsAny(normalized, "distinct", "deduplicate", "unique", "uv") + || containsAny(query, "去重", "独立用户", "唯一"); + boolean requiresAggregation = requiresTrend || requiresDistinct + || containsAny(normalized, "count", "sum", "avg", "average", "total", "amount", "sales", "revenue") + || containsAny(query, "数量", "总数", "总额", "金额", "销量", "销售额", "订单数", "人数", "平均", "占比", "比例", "贡献", "多少"); + return new QueryIntent(requiresAggregation, requiresGrouping, requiresTimeFilter, requiresOrdering, + requiresLimit, requiresDistinct, requiresTrend, prefersDescending, prefersAscending, expectedLimit); + } + + private SqlShape analyzeSqlShape(Statement statement, String sql, SqlGuardCheckRequest request) { + String normalizedSql = StringUtils.trimToEmpty(sql).toLowerCase(Locale.ROOT); + Set knownTimeColumns = extractKnownTimeColumns(request); + List usedTables = extractReferencedTables(statement); + List usedMetrics = extractUsedMetrics(sql); + boolean hasAggregation = AGGREGATE_PATTERN.matcher(sql).find(); + boolean hasGroupBy = GROUP_BY_PATTERN.matcher(normalizedSql).find(); + boolean hasOrderBy = ORDER_BY_PATTERN.matcher(normalizedSql).find(); + boolean hasLimit = LIMIT_PATTERN.matcher(normalizedSql).find() || TOP_PATTERN.matcher(normalizedSql).find() + || FETCH_FIRST_PATTERN.matcher(normalizedSql).find(); + boolean hasDistinct = DISTINCT_PATTERN.matcher(normalizedSql).find(); + boolean hasTimePredicate = detectTimePredicate(normalizedSql, knownTimeColumns); + boolean hasTimeBucket = detectTimeBucket(normalizedSql, knownTimeColumns); + boolean orderDescending = DESC_PATTERN.matcher(normalizedSql).find(); + boolean orderDirectionKnown = DESC_PATTERN.matcher(normalizedSql).find() + || ASC_PATTERN.matcher(normalizedSql).find(); + Integer limitValue = extractSqlLimitValue(normalizedSql); + boolean limitValueKnown = limitValue != null; + return new SqlShape(List.copyOf(usedTables), List.copyOf(usedMetrics), hasAggregation, hasGroupBy, hasOrderBy, + hasLimit, hasDistinct, hasTimePredicate, hasTimeBucket, orderDescending, orderDirectionKnown, + limitValueKnown, limitValue); + } + + private Integer extractSqlLimitValue(String normalizedSql) { + Integer limitValue = extractFirstInt(SQL_LIMIT_VALUE_PATTERN, normalizedSql); + if (limitValue != null) { + return limitValue; + } + limitValue = extractFirstInt(SQL_TOP_VALUE_PATTERN, normalizedSql); + if (limitValue != null) { + return limitValue; + } + return extractFirstInt(SQL_FETCH_FIRST_VALUE_PATTERN, normalizedSql); + } + + private boolean detectTimePredicate(String normalizedSql, Set knownTimeColumns) { + if (!WHERE_PATTERN.matcher(normalizedSql).find()) { + return false; + } + if (DATE_LITERAL_PATTERN.matcher(normalizedSql).find() || TIME_FUNCTION_PATTERN.matcher(normalizedSql).find()) { + return true; + } + if (containsAny(normalizedSql, " created_at ", " create_time ", " updated_at ", " update_time ", " order_date ", + " biz_date ", " stat_date ", " date ", " time ", " month ", " year ", " day ")) { + return true; + } + return knownTimeColumns.stream().anyMatch(column -> normalizedSql.contains(column.toLowerCase(Locale.ROOT))); + } + + private boolean detectTimeBucket(String normalizedSql, Set knownTimeColumns) { + if (!GROUP_BY_PATTERN.matcher(normalizedSql).find()) { + return false; + } + if (containsAny(normalizedSql, "date_trunc(", "date(", "strftime(", "to_date(", "extract(", "year(", "month(", + "day(")) { + return true; + } + if (containsAny(normalizedSql, " by day", " by month", " by week", " by year")) { + return true; + } + return knownTimeColumns.stream().anyMatch(column -> normalizedSql.contains(column.toLowerCase(Locale.ROOT))); + } + + private Set extractKnownTimeColumns(SqlGuardCheckRequest request) { + return new LinkedHashSet<>(); + } + + private boolean isLikelyTimeColumn(String value) { + String normalized = StringUtils.trimToEmpty(value).toLowerCase(Locale.ROOT); + return containsAny(normalized, "date", "time", "day", "week", "month", "year", "created", "updated", "dt", + "biz_date", "stat_date", "order_date"); + } + + private List extractReferencedTables(Statement statement) { + try { + return new TablesNamesFinder().getTableList(statement) + .stream() + .filter(StringUtils::isNotBlank) + .map(String::trim) + .distinct() + .toList(); + } + catch (Exception ex) { + return List.of(); + } + } + + private List extractUsedMetrics(String sql) { + Set metrics = new LinkedHashSet<>(); + Matcher matcher = AGGREGATE_PATTERN.matcher(StringUtils.defaultString(sql)); + while (matcher.find()) { + String alias = matcher.group(3); + if (StringUtils.isNotBlank(alias)) { + metrics.add(alias.trim()); + continue; + } + String functionName = Objects.toString(matcher.group(1), "").toUpperCase(Locale.ROOT); + String argument = Objects.toString(matcher.group(2), "").trim(); + metrics.add(functionName + "(" + argument + ")"); + } + return List.copyOf(metrics); + } + + private Statement parseSingleSelectStatement(String sql) { + String normalizedSql = stripTrailingSemicolons(sql); + if (normalizedSql.isEmpty()) { + throw new IllegalArgumentException("SQL 不能为空"); + } + try { + List statements = CCJSqlParserUtil.parseStatements(normalizedSql).getStatements(); + if (statements == null || statements.isEmpty()) { + throw new IllegalArgumentException("SQL 不能为空"); + } + if (statements.size() > 1) { + throw new IllegalArgumentException("仅支持单条 SELECT / WITH 查询"); + } + Statement statement = statements.get(0); + if (!(statement instanceof Select)) { + throw new IllegalArgumentException("sql_guard.check 仅校验 SELECT / WITH 查询"); + } + return statement; + } + catch (IllegalArgumentException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalArgumentException("SQL 解析失败,请检查语法后重试", ex); + } + } + + private String stripTrailingSemicolons(String sql) { + String trimmed = StringUtils.trimToEmpty(sql); + while (trimmed.endsWith(";")) { + trimmed = trimmed.substring(0, trimmed.length() - 1).trim(); + } + return trimmed; + } + + private String buildIntentExplanation(QueryIntent intent) { + List fragments = new ArrayList<>(); + if (intent.requiresAggregation()) { + fragments.add("问题包含聚合指标诉求"); + } + if (intent.requiresGrouping()) { + fragments.add("问题包含按维度拆分诉求"); + } + if (intent.requiresTimeFilter()) { + fragments.add("问题包含明确时间窗口"); + } + if (intent.requiresTrend()) { + fragments.add("问题包含趋势或时间序列分析"); + } + if (intent.requiresOrdering()) { + fragments.add("问题包含排序或排名要求"); + } + if (intent.requiresLimit()) { + fragments.add("问题包含 TopN / Top1 行数限制"); + } + if (intent.requiresDistinct()) { + fragments.add("问题包含去重口径"); + } + if (fragments.isEmpty()) { + return "当前规则没有识别到强约束口径,主要执行基础 SQL 结构检查。"; + } + return String.join(";", fragments) + "。"; + } + + private boolean containsAny(String value, String... needles) { + if (value == null) { + return false; + } + for (String needle : needles) { + if (needle != null && value.contains(needle)) { + return true; + } + } + return false; + } + + private Integer extractExpectedLimit(String query, boolean prefersDescending, boolean prefersAscending) { + String safeQuery = StringUtils.defaultString(query); + Matcher matcher = TOP_N_QUERY_PATTERN.matcher(safeQuery); + if (matcher.find()) { + for (int index = 1; index <= matcher.groupCount(); index++) { + String group = matcher.group(index); + if (StringUtils.isNotBlank(group)) { + try { + return Integer.parseInt(group.trim()); + } + catch (NumberFormatException ex) { + return null; + } + } + } + } + boolean asksSingleExtreme = containsAny(safeQuery, "最多", "最高", "最低", "最少", "第一", "首位", "top1", "top 1", + "highest", "lowest", "most", "least", "first"); + if (asksSingleExtreme && (prefersDescending || prefersAscending)) { + return 1; + } + boolean asksSingleTarget = containsAny(safeQuery, "哪个", "哪位", "哪一个", "谁"); + if (asksSingleTarget && (prefersDescending || prefersAscending)) { + return 1; + } + return null; + } + + private Integer extractFirstInt(Pattern pattern, String value) { + Matcher matcher = pattern.matcher(StringUtils.defaultString(value)); + if (!matcher.find()) { + return null; + } + String group = matcher.group(1); + if (StringUtils.isBlank(group)) { + return null; + } + try { + return Integer.parseInt(group.trim()); + } + catch (NumberFormatException ex) { + return null; + } + } + + private record ProfileContext(AgentDatasource agentDatasource, Datasource datasource, DbConfigBO dbConfig, + Accessor accessor, List visibleTables, Map> visibleTablesByName, + Map> visibleTablesByLeafName, Map> visibleColumnsByTable, + Map> visibleColumnNameSetByTable, Set columnRestrictedTables) { + } + + private record QueryIntent(boolean requiresAggregation, boolean requiresGrouping, boolean requiresTimeFilter, + boolean requiresOrdering, boolean requiresLimit, boolean requiresDistinct, boolean requiresTrend, + boolean prefersDescending, boolean prefersAscending, Integer expectedLimit) { + } + + private record HumanFeedbackConstraint(String feedbackContent, Set requiredStatusTokens, + Set excludedStatusTokens) { + + private static HumanFeedbackConstraint empty() { + return new HumanFeedbackConstraint("", Set.of(), Set.of()); + } + + private boolean hasConstraints() { + return StringUtils.isNotBlank(feedbackContent) + && (!requiredStatusTokens.isEmpty() || !excludedStatusTokens.isEmpty()); + } + } + + private record SqlShape(List usedTables, List usedMetrics, boolean hasAggregation, + boolean hasGroupBy, boolean hasOrderBy, boolean hasLimit, boolean hasDistinct, boolean hasTimePredicate, + boolean hasTimeBucket, boolean orderDescending, boolean orderDirectionKnown, boolean limitValueKnown, + Integer limitValue) { + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java similarity index 81% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java index 181417a18..7c2d7709e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/GraphNodeResponse.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/agentscope/vo/AgentResponse.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.dataagent.agentscope.vo; import com.alibaba.cloud.ai.dataagent.enums.TextType; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -25,7 +26,7 @@ @AllArgsConstructor @NoArgsConstructor @Builder -public class GraphNodeResponse { +public class AgentResponse { private String agentId; @@ -37,14 +38,16 @@ public class GraphNodeResponse { private String text; + private Map metadata; + @Builder.Default private boolean error = false; @Builder.Default private boolean complete = false; - public static GraphNodeResponse error(String agentId, String threadId, String text) { - return GraphNodeResponse.builder() + public static AgentResponse error(String agentId, String threadId, String text) { + return AgentResponse.builder() .agentId(agentId) .threadId(threadId) .text(text) @@ -53,8 +56,8 @@ public static GraphNodeResponse error(String agentId, String threadId, String te .build(); } - public static GraphNodeResponse complete(String agentId, String threadId) { - return GraphNodeResponse.builder() + public static AgentResponse complete(String agentId, String threadId) { + return AgentResponse.builder() .agentId(agentId) .threadId(threadId) .complete(true) diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java new file mode 100644 index 000000000..da2d5a3de --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/AgentScopeTracingConfiguration.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.config; + +import com.alibaba.cloud.ai.dataagent.properties.AgentScopeObservabilityProperties; +import io.agentscope.core.tracing.TracerRegistry; +import io.agentscope.core.tracing.telemetry.TelemetryTracer; +import io.opentelemetry.api.trace.Tracer; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; + +@Slf4j +@Configuration +@RequiredArgsConstructor +@EnableConfigurationProperties(AgentScopeObservabilityProperties.class) +public class AgentScopeTracingConfiguration implements SmartInitializingSingleton { + + private final AgentScopeObservabilityProperties properties; + + private final OpenTelemetryConfig openTelemetryConfig; + + @Qualifier("langfuseTracer") + private final Tracer langfuseTracer; + + @Qualifier("agentScopeLocalTracer") + private final Tracer agentScopeLocalTracer; + + @Bean("agentScopeTracer") + @Primary + public Tracer agentScopeTracer() { + return selectTracer(); + } + + @Override + public void afterSingletonsInstantiated() { + if (!properties.isEnabled()) { + log.info("AgentScope native tracing is disabled by configuration."); + return; + } + + TracerRegistry.register(TelemetryTracer.builder().tracer(selectTracer()).build()); + if (properties.isUseLangfuseTracer() && openTelemetryConfig.isEnabled()) { + log.info( + "AgentScope native tracing initialized with local trace cache and Langfuse OpenTelemetry exporter."); + return; + } + + log.info("AgentScope native tracing initialized with local trace cache only."); + } + + private Tracer selectTracer() { + if (properties.isUseLangfuseTracer() && openTelemetryConfig.isEnabled()) { + return langfuseTracer; + } + return agentScopeLocalTracer; + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java index 06e9a2038..2cf480529 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/DataAgentConfiguration.java @@ -15,8 +15,8 @@ */ package com.alibaba.cloud.ai.dataagent.config; -import com.alibaba.cloud.ai.dataagent.properties.CodeExecutorProperties; import com.alibaba.cloud.ai.dataagent.properties.AgentSkillProperties; +import com.alibaba.cloud.ai.dataagent.properties.CodeExecutorProperties; import com.alibaba.cloud.ai.dataagent.properties.DataAgentProperties; import com.alibaba.cloud.ai.dataagent.properties.FileStorageProperties; import com.alibaba.cloud.ai.dataagent.service.vectorstore.SimpleVectorStoreInitialization; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java index 828bb7ad8..44ad452fb 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/config/OpenTelemetryConfig.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.dataagent.config; import com.alibaba.cloud.ai.dataagent.constant.Constant; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; @@ -24,10 +25,13 @@ import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.resources.Resource; import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.trace.SdkTracerProviderBuilder; import io.opentelemetry.sdk.trace.export.BatchSpanProcessor; +import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; import jakarta.annotation.PreDestroy; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -57,50 +61,77 @@ public class OpenTelemetryConfig { private String secretKey; - private SdkTracerProvider tracerProvider; + private SdkTracerProvider langfuseTracerProvider; - @Bean - public OpenTelemetry openTelemetry() { - if (!enabled) { - return OpenTelemetry.noop(); + private SdkTracerProvider localTraceTracerProvider; + + @Bean("langfuseOpenTelemetry") + public OpenTelemetry langfuseOpenTelemetry(SessionTraceStore sessionTraceStore) { + langfuseTracerProvider = buildTracerProvider(sessionTraceStore, enabled); + OpenTelemetrySdk openTelemetrySdk = OpenTelemetrySdk.builder() + .setTracerProvider(langfuseTracerProvider) + .build(); + + if (enabled) { + log.info("OpenTelemetry initialized with local trace cache and Langfuse OTLP HTTP exporter."); + } + else { + log.info("OpenTelemetry initialized with local trace cache only."); } - String auth = publicKey + ":" + secretKey; - String encodedAuth = Base64.getEncoder().encodeToString(auth.getBytes(StandardCharsets.UTF_8)); + return openTelemetrySdk; + } - OtlpHttpSpanExporter spanExporter = OtlpHttpSpanExporter.builder() - .setEndpoint(host + "/api/public/otel/v1/traces") - .addHeader("Authorization", "Basic " + encodedAuth) - .setTimeout(10, TimeUnit.SECONDS) - .build(); + @Bean("agentScopeLocalOpenTelemetry") + public OpenTelemetry agentScopeLocalOpenTelemetry(SessionTraceStore sessionTraceStore) { + localTraceTracerProvider = buildTracerProvider(sessionTraceStore, false); + return OpenTelemetrySdk.builder().setTracerProvider(localTraceTracerProvider).build(); + } + private SdkTracerProvider buildTracerProvider(SessionTraceStore sessionTraceStore, boolean withLangfuseExporter) { Resource resource = Resource.getDefault() .merge(Resource.create(Attributes.of(AttributeKey.stringKey("service.name"), SERVICE_NAME))); - tracerProvider = SdkTracerProvider.builder() - .addSpanProcessor(BatchSpanProcessor.builder(spanExporter) - .setScheduleDelay(1, TimeUnit.SECONDS) - .setMaxExportBatchSize(100) - .build()) - .setResource(resource) - .build(); + SdkTracerProviderBuilder builder = SdkTracerProvider.builder() + .addSpanProcessor(SimpleSpanProcessor.create(sessionTraceStore)) + .setResource(resource); - OpenTelemetrySdk openTelemetrySdk = OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + if (withLangfuseExporter) { + String auth = publicKey + ":" + secretKey; + String encodedAuth = Base64.getEncoder().encodeToString(auth.getBytes(StandardCharsets.UTF_8)); - log.info("OpenTelemetry initialized with Langfuse OTLP HTTP exporter"); + OtlpHttpSpanExporter spanExporter = OtlpHttpSpanExporter.builder() + .setEndpoint(host + "/api/public/otel/v1/traces") + .addHeader("Authorization", "Basic " + encodedAuth) + .setTimeout(10, TimeUnit.SECONDS) + .build(); - return openTelemetrySdk; + builder.addSpanProcessor(BatchSpanProcessor.builder(spanExporter) + .setScheduleDelay(1, TimeUnit.SECONDS) + .setMaxExportBatchSize(100) + .build()); + } + + return builder.build(); } - @Bean - public Tracer langfuseTracer(OpenTelemetry openTelemetry) { + @Bean("langfuseTracer") + public Tracer langfuseTracer(@Qualifier("langfuseOpenTelemetry") OpenTelemetry openTelemetry) { return openTelemetry.getTracer(SERVICE_NAME); } + @Bean("agentScopeLocalTracer") + public Tracer agentScopeLocalTracer(@Qualifier("agentScopeLocalOpenTelemetry") OpenTelemetry openTelemetry) { + return openTelemetry.getTracer(SERVICE_NAME + ".agentscope.local"); + } + @PreDestroy public void shutdown() { - if (tracerProvider != null) { - tracerProvider.close(); + if (langfuseTracerProvider != null) { + langfuseTracerProvider.close(); + } + if (localTraceTracerProvider != null) { + localTraceTracerProvider.close(); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java index 2ef0ebe66..9f0f16ab6 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/accessor/AbstractAccessor.java @@ -29,6 +29,8 @@ import com.alibaba.cloud.ai.dataagent.bo.DbConfigBO; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import java.sql.Connection; import java.util.List; @@ -63,7 +65,8 @@ public T accessDb(DbConfigBO dbConfig, String method, DbQueryParameter param case "showColumns": return (T) ddlExecutor.showColumns(connection, param.getSchema(), param.getTable()); case "showForeignKeys": - return (T) ddlExecutor.showForeignKeys(connection, param.getSchema(), param.getTables()); + return (T) ddlExecutor.showForeignKeys(connection, param.getSchema(), + resolveTablesForForeignKeys(connection, ddlExecutor, param)); case "sampleColumn": return (T) ddlExecutor.sampleColumn(connection, param.getSchema(), param.getTable(), param.getColumn()); @@ -122,4 +125,19 @@ public Connection getConnection(DbConfigBO config) { return this.dbConnectionPool.getConnection(config); } + private List resolveTablesForForeignKeys(Connection connection, AbstractJdbcDdl ddlExecutor, + DbQueryParameter param) { + if (param == null) { + return List.of(); + } + if (!CollectionUtils.isEmpty(param.getTables())) { + return param.getTables(); + } + List tables = ddlExecutor.showTables(connection, param.getSchema(), param.getTablePattern()); + if (CollectionUtils.isEmpty(tables)) { + return List.of(); + } + return tables.stream().map(TableInfoBO::getName).filter(StringUtils::hasText).distinct().toList(); + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java index acd846b8f..576d8bd83 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/impls/hive/HiveJdbcConnectionPool.java @@ -20,10 +20,12 @@ import com.alibaba.cloud.ai.dataagent.enums.BizDataSourceTypeEnum; import com.alibaba.cloud.ai.dataagent.enums.ErrorCodeEnum; import com.alibaba.druid.pool.DruidDataSourceFactory; +import com.alibaba.druid.pool.DruidDataSource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import javax.sql.DataSource; +import java.sql.DriverManager; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -93,7 +95,13 @@ public DataSource createdDataSource(String url, String username, String password log.info("Creating Hive DataSource with custom configuration"); String driver = getDriver(); Map props = new HiveDruidProperties(driver, url, username, password, "stat").toMap(); - return DruidDataSourceFactory.createDataSource(props); + DruidDataSource dataSource = (DruidDataSource) DruidDataSourceFactory.createDataSource(props); + dataSource.setInitialSize(0); + dataSource.setMinIdle(0); + dataSource.setBreakAfterAcquireFailure(true); + dataSource.setConnectionErrorRetryAttempts(2); + dataSource.setTestWhileIdle(false); + return dataSource; } private static final class HiveDruidProperties { @@ -123,14 +131,14 @@ private Map toMap() { props.put(DruidDataSourceFactory.PROP_USERNAME, this.username); props.put(DruidDataSourceFactory.PROP_PASSWORD, this.password); props.put(DruidDataSourceFactory.PROP_FILTERS, this.filters); - props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "5"); - props.put(DruidDataSourceFactory.PROP_MINIDLE, "5"); + props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "0"); + props.put(DruidDataSourceFactory.PROP_MINIDLE, "0"); props.put(DruidDataSourceFactory.PROP_MAXACTIVE, "20"); props.put(DruidDataSourceFactory.PROP_MAXWAIT, "60000"); props.put(DruidDataSourceFactory.PROP_TIMEBETWEENEVICTIONRUNSMILLIS, "60000"); props.put(DruidDataSourceFactory.PROP_MINEVICTABLEIDLETIMEMILLIS, "300000"); props.put(DruidDataSourceFactory.PROP_VALIDATIONQUERY, "SELECT 1"); - props.put(DruidDataSourceFactory.PROP_TESTWHILEIDLE, "true"); + props.put(DruidDataSourceFactory.PROP_TESTWHILEIDLE, "false"); props.put(DruidDataSourceFactory.PROP_TESTONBORROW, "false"); props.put(DruidDataSourceFactory.PROP_TESTONRETURN, "false"); return props; @@ -141,7 +149,8 @@ private Map toMap() { @Override public ErrorCodeEnum ping(DbConfigBO config) { log.info("Hive ping method called, url: {}", config.getUrl()); - try (Connection connection = getConnection(config); Statement stmt = connection.createStatement()) { + try (Connection connection = DriverManager.getConnection(config.getUrl(), config.getUsername(), + config.getPassword()); Statement stmt = connection.createStatement()) { log.info("Hive connection obtained, executing SELECT 1"); ResultSet rs = stmt.executeQuery("SELECT 1"); if (rs.next()) { diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java index c720a1baa..4f6aba33f 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/AbstractDBConnectionPool.java @@ -117,6 +117,7 @@ public Connection getConnection(DbConfigBO config) { log.warn("Attempt {} to get database connection failed: {}", attempt, e.getMessage()); if (attempt == maxRetries) { + evict(config); log.error("Failed to get database connection after {} attempts, URL: {}", maxRetries, jdbcUrl, e); throw new RuntimeException("Failed to get database connection after " + maxRetries + " attempts", e); @@ -141,15 +142,27 @@ public Connection getConnection(DbConfigBO config) { * @param password the database password * @return the cache key */ - private String generateCacheKey(String url, String username, String password) { + protected String generateCacheKey(String url, String username, String password) { return url + "|" + username + "|" + Objects.hashCode(password); } + @Override + public void evict(DbConfigBO config) { + if (config == null || config.getUrl() == null) { + return; + } + String cacheKey = generateCacheKey(config.getUrl(), config.getUsername(), config.getPassword()); + DataSource dataSource = DATA_SOURCE_CACHE.remove(cacheKey); + if (dataSource instanceof DruidDataSource druidDataSource) { + druidDataSource.close(); + } + } + @Override public void close() { DATA_SOURCE_CACHE.values().forEach(dataSource -> { - if (dataSource instanceof DruidDataSource) { - ((DruidDataSource) dataSource).close(); + if (dataSource instanceof DruidDataSource druidDataSource) { + druidDataSource.close(); } }); DATA_SOURCE_CACHE.clear(); @@ -175,20 +188,23 @@ public DataSource createdDataSource(String url, String username, String password props.put(DruidDataSourceFactory.PROP_URL, url); props.put(DruidDataSourceFactory.PROP_USERNAME, username); props.put(DruidDataSourceFactory.PROP_PASSWORD, password); - props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "5"); - props.put(DruidDataSourceFactory.PROP_MINIDLE, "5"); + props.put(DruidDataSourceFactory.PROP_INITIALSIZE, "0"); + props.put(DruidDataSourceFactory.PROP_MINIDLE, "0"); props.put(DruidDataSourceFactory.PROP_MAXACTIVE, "20"); props.put(DruidDataSourceFactory.PROP_MAXWAIT, "10000"); props.put(DruidDataSourceFactory.PROP_TIMEBETWEENEVICTIONRUNSMILLIS, "60000"); props.put(DruidDataSourceFactory.PROP_FILTERS, filters); DruidDataSource dataSource = (DruidDataSource) DruidDataSourceFactory.createDataSource(props); + dataSource.setInitialSize(0); + dataSource.setMinIdle(0); dataSource.setBreakAfterAcquireFailure(Boolean.TRUE); dataSource.setConnectionErrorRetryAttempts(2); + dataSource.setTestWhileIdle(false); // 记录数据源创建信息 log.info( - "Created new DataSource with optimized parameters - InitialSize: 5, MinIdle: 5, MaxActive: 20, MaxWait: 10000ms"); + "Created new DataSource with optimized parameters - InitialSize: 0, MinIdle: 0, MaxActive: 20, MaxWait: 10000ms"); return dataSource; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java index 5876a275a..7380850b6 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/connector/pool/DBConnectionPool.java @@ -41,6 +41,13 @@ public interface DBConnectionPool extends AutoCloseable { */ Connection getConnection(DbConfigBO config); + /** + * Evict cached data source for the given configuration if present. + * @param config the database configuration + */ + default void evict(DbConfigBO config) { + } + boolean supportedDataSourceType(String type); String getConnectionPoolType(); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java index 7d8afd7b7..84fe82250 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentRuntimeConstant.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.dataagent.constant; +import io.agentscope.core.model.ExecutionConfig; import java.time.Duration; /** @@ -22,6 +23,26 @@ */ public final class AgentRuntimeConstant { + public static final int REACT_AGENT_DEFAULT_MAX_ITERS = 10; + + public static final Duration MODEL_EXECUTION_TIMEOUT = Duration.ofMinutes(2); + + public static final int MODEL_MAX_ATTEMPTS = 2; + + public static final Duration TOOL_EXECUTION_TIMEOUT = Duration.ofSeconds(30); + + public static final int TOOL_MAX_ATTEMPTS = 1; + + public static final ExecutionConfig DEFAULT_MODEL_EXECUTION_CONFIG = ExecutionConfig.builder() + .timeout(MODEL_EXECUTION_TIMEOUT) + .maxAttempts(MODEL_MAX_ATTEMPTS) + .build(); + + public static final ExecutionConfig DEFAULT_TOOL_EXECUTION_CONFIG = ExecutionConfig.builder() + .timeout(TOOL_EXECUTION_TIMEOUT) + .maxAttempts(TOOL_MAX_ATTEMPTS) + .build(); + private AgentRuntimeConstant() { } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java new file mode 100644 index 000000000..1ec9049b9 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/constant/AgentSessionConstant.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.constant; + +import java.time.Duration; + +/** + * Agent 会话相关的共享默认配置常量。 + */ +public final class AgentSessionConstant { + + public static final String DEFAULT_SESSION_TITLE = "新会话"; + + public static final int SESSION_TITLE_MAX_LENGTH = 20; + + public static final Duration SESSION_TITLE_GENERATION_TIMEOUT = Duration.ofSeconds(15); + + public static final Duration SESSION_EVENT_HEARTBEAT_INTERVAL = Duration.ofSeconds(2); + + private AgentSessionConstant() { + + } + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java index f9962af1e..81246cb2e 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/AgentDatasourceController.java @@ -15,23 +15,35 @@ */ package com.alibaba.cloud.ai.dataagent.controller; +import com.alibaba.cloud.ai.dataagent.dto.datasource.TableColumnsSelectionDTO; import com.alibaba.cloud.ai.dataagent.dto.datasource.ToggleDatasourceDTO; +import com.alibaba.cloud.ai.dataagent.dto.datasource.UpdateDatasourceColumnsDTO; import com.alibaba.cloud.ai.dataagent.dto.datasource.UpdateDatasourceTablesDTO; import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource; import com.alibaba.cloud.ai.dataagent.exception.InternalServerException; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; +import jakarta.validation.Valid; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.*; +import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; /** - * Agent Schema Initialization Controller Handles agent's database Schema initialization - * to vector storage + * Agent datasource controller. */ @Slf4j @RestController @@ -42,60 +54,50 @@ public class AgentDatasourceController { private final AgentDatasourceService agentDatasourceService; - /** - * Initialize agent's database Schema to vector storage Corresponds to the "Initialize - * Information Source" function on the frontend - */ @PostMapping("/init") public ApiResponse initSchema(@PathVariable Long agentId) { - // 防止前端恶意请求,dto数据应该在后端获取 try { AgentDatasource agentDatasource = agentDatasourceService.getCurrentAgentDatasource(agentId); log.info("Initializing schema for agent: {}", agentId); - // Extract data source ID and table list from request Integer datasourceId = agentDatasource.getDatasourceId(); List tables = Optional.ofNullable(agentDatasource.getSelectTables()).orElse(List.of()); - // Validate request parameters if (datasourceId == null) { - throw new InvalidInputException("数据源ID不能为空"); + throw new InvalidInputException("datasourceId cannot be null"); } - if (tables.isEmpty()) { - throw new InvalidInputException("表列表不能为空"); + throw new InvalidInputException("tables cannot be empty"); } - // Execute Schema initialization Boolean result = agentDatasourceService.initializeSchemaForAgentWithDatasource(agentId, datasourceId, tables); - - if (result) { + if (Boolean.TRUE.equals(result)) { log.info("Successfully initialized schema for agent: {}, tables: {}", agentId, tables.size()); - return ApiResponse.success("Schema初始化成功"); - } - else { - throw new InternalServerException("Schema初始化失败"); + return ApiResponse.success("Schema initialized successfully"); } + throw new InternalServerException("Schema initialization failed"); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { log.error("Failed to initialize schema for agent: {}", agentId, e); - throw new InternalServerException("Schema初始化失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Schema initialization failed: %s".formatted(e.getMessage())); } } - /** Get list of data sources configured for agent */ @GetMapping public ApiResponse> getAgentDatasource(@PathVariable Long agentId) { try { log.info("Getting datasources for agent: {}", agentId); List datasources = agentDatasourceService.getAgentDatasource(agentId); log.info("Successfully retrieved {} datasources for agent: {}", datasources.size(), agentId); - return ApiResponse.success("操作成功", datasources); + return ApiResponse.success("success", datasources); } catch (Exception e) { log.error("Failed to get datasources for agent: {}", agentId, e); - throw new InvalidInputException("获取数据源失败:%s".formatted(e.getMessage()), List.of()); + throw new InvalidInputException("Failed to get datasources: %s".formatted(e.getMessage()), List.of()); } } @@ -104,59 +106,110 @@ public ApiResponse getActiveAgentDatasource(@PathVariable Long try { log.info("Getting active datasource for agent: {}", agentId); AgentDatasource datasource = agentDatasourceService.getCurrentAgentDatasource(agentId); - return ApiResponse.success("操作成功", datasource); + return ApiResponse.success("success", datasource); } catch (Exception e) { log.error("Failed to get active datasource for agent: {}", agentId, e); - throw new InvalidInputException("获取数据源失败:%s".formatted(e.getMessage()), List.of()); + throw new InvalidInputException("Failed to get active datasource: %s".formatted(e.getMessage()), List.of()); } } - /** Add data source for agent */ @PostMapping("/{datasourceId}") public ApiResponse addDatasourceToAgent(@PathVariable Long agentId, @PathVariable Integer datasourceId) { try { if (datasourceId == null) { - throw new InvalidInputException("数据源ID不能为空"); + throw new InvalidInputException("datasourceId cannot be null"); } - AgentDatasource agentDatasource = agentDatasourceService.addDatasourceToAgent(agentId, datasourceId); - return ApiResponse.success("数据源添加成功", agentDatasource); + return ApiResponse.success("Datasource added successfully", agentDatasource); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { - throw new InternalServerException("数据源添加失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Failed to add datasource: %s".formatted(e.getMessage())); } } - // 更新选择的数据表 @PostMapping("/tables") - public ApiResponse updateDatasourceTables(@PathVariable Long agentId, + public ApiResponse updateDatasourceTables(@PathVariable Long agentId, @RequestBody @Validated UpdateDatasourceTablesDTO dto) { try { dto.setTables(Optional.ofNullable(dto.getTables()).orElse(List.of())); - agentDatasourceService.updateDatasourceTables(agentId, dto.getDatasourceId(), dto.getTables()); - return ApiResponse.success("更新成功"); + AgentDatasource agentDatasource = agentDatasourceService.updateDatasourceTables(agentId, + dto.getDatasourceId(), dto.getTables()); + return ApiResponse.success("Update successful", agentDatasource); + } + catch (IllegalArgumentException e) { + log.warn("Invalid datasource tables update request, agentId={}, datasourceId={}, message={}", agentId, + dto.getDatasourceId(), e.getMessage()); + throw new InvalidInputException(e.getMessage()); } catch (Exception e) { - log.error("Error: ", e); - throw new InternalServerException("更新失败:%s".formatted(e.getMessage())); + log.error("Error updating datasource tables", e); + throw new InternalServerException("Update failed: %s".formatted(e.getMessage())); + } + } + + @PostMapping("/columns") + public ApiResponse updateDatasourceColumns(@PathVariable Long agentId, + @RequestBody @Valid UpdateDatasourceColumnsDTO dto) { + try { + List tableSelections = Optional.ofNullable(dto.getTables()).orElse(List.of()); + Map> columnsByTable = new LinkedHashMap<>(); + for (TableColumnsSelectionDTO tableSelection : tableSelections) { + if (tableSelection == null) { + continue; + } + columnsByTable.put(tableSelection.getTableName(), + Optional.ofNullable(tableSelection.getColumns()).orElse(List.of())); + } + AgentDatasource agentDatasource = agentDatasourceService.updateDatasourceColumns(agentId, + dto.getDatasourceId(), columnsByTable); + return ApiResponse.success("Update successful", agentDatasource); + } + catch (IllegalArgumentException e) { + log.warn("Invalid datasource columns update request, agentId={}, datasourceId={}, message={}", agentId, + dto.getDatasourceId(), e.getMessage()); + throw new InvalidInputException(e.getMessage()); + } + catch (Exception e) { + log.error("Error updating datasource columns", e); + throw new InternalServerException("Update failed: %s".formatted(e.getMessage())); + } + } + + @GetMapping("/{datasourceId}/tables/{tableName}/columns") + public ApiResponse> getVisibleTableColumns(@PathVariable Long agentId, + @PathVariable Integer datasourceId, @PathVariable String tableName) { + try { + List columns = agentDatasourceService.getVisibleTableColumns(agentId, datasourceId, tableName); + return ApiResponse.success("success", columns); + } + catch (IllegalArgumentException e) { + log.warn("Invalid visible columns request, agentId={}, datasourceId={}, tableName={}, message={}", agentId, + datasourceId, tableName, e.getMessage()); + throw new InvalidInputException(e.getMessage()); + } + catch (Exception e) { + log.error("Error loading visible columns, agentId={}, datasourceId={}, tableName={}", agentId, datasourceId, + tableName, e); + throw new InternalServerException("Load columns failed: %s".formatted(e.getMessage())); } } - /** Remove data source association from agent */ @DeleteMapping("/{datasourceId}") public ApiResponse removeDatasourceFromAgent(@PathVariable Long agentId, @PathVariable Integer datasourceId) { try { agentDatasourceService.removeDatasourceFromAgent(agentId, datasourceId); - return ApiResponse.success("数据源已移除"); + return ApiResponse.success("Datasource removed"); } catch (Exception e) { - throw new InternalServerException("移除失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Remove failed: %s".formatted(e.getMessage())); } } - /** 启用/禁用智能体的数据源 */ @PutMapping("/toggle") public ApiResponse toggleDatasourceForAgent(@PathVariable Long agentId, @RequestBody ToggleDatasourceDTO dto) { @@ -164,14 +217,17 @@ public ApiResponse toggleDatasourceForAgent(@PathVariable Long Boolean isActive = dto.getIsActive(); Integer datasourceId = dto.getDatasourceId(); if (isActive == null || datasourceId == null) { - throw new InvalidInputException("激活状态不能为空"); + throw new InvalidInputException("isActive and datasourceId cannot be null"); } AgentDatasource agentDatasource = agentDatasourceService.toggleDatasourceForAgent(agentId, dto.getDatasourceId(), isActive); - return ApiResponse.success(isActive ? "数据源已启用" : "数据源已禁用", agentDatasource); + return ApiResponse.success(isActive ? "Datasource enabled" : "Datasource disabled", agentDatasource); + } + catch (InvalidInputException e) { + throw e; } catch (Exception e) { - throw new InternalServerException("操作失败:%s".formatted(e.getMessage())); + throw new InternalServerException("Operation failed: %s".formatted(e.getMessage())); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java index c02a9800b..c1894509a 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/ChatController.java @@ -19,11 +19,15 @@ import com.alibaba.cloud.ai.dataagent.entity.ChatMessage; import com.alibaba.cloud.ai.dataagent.entity.ChatSession; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; +import com.alibaba.cloud.ai.dataagent.observability.AnswerTraceExplainStore; +import com.alibaba.cloud.ai.dataagent.observability.SessionTraceStore; import com.alibaba.cloud.ai.dataagent.service.chat.ChatMessageService; import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; import com.alibaba.cloud.ai.dataagent.service.chat.SessionTitleService; import com.alibaba.cloud.ai.dataagent.util.ReportTemplateUtil; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; @@ -34,6 +38,7 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.*; +import org.springframework.web.server.ResponseStatusException; import java.util.List; import java.util.Map; @@ -48,6 +53,8 @@ @RequiredArgsConstructor public class ChatController { + private static final String ANSWER_EXPLAIN_MESSAGE_TYPE = "answer-explain"; + private final ChatSessionService chatSessionService; private final ChatMessageService chatMessageService; @@ -56,6 +63,12 @@ public class ChatController { private final ReportTemplateUtil reportTemplateUtil; + private final SessionTraceStore sessionTraceStore; + + private final AnswerTraceExplainStore answerTraceExplainStore; + + private final ObjectMapper objectMapper; + /** * Get session list for an agent */ @@ -113,17 +126,98 @@ public ResponseEntity clearAgentSessions(@PathVariable(value = "id" * Get message list for a session */ @GetMapping("/sessions/{sessionId}/messages") - public ResponseEntity> getSessionMessages(@PathVariable(value = "sessionId") String sessionId) { - List messages = chatMessageService.findVisibleBySessionId(sessionId); + public ResponseEntity> getSessionMessages(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + List messages = chatMessageService.findVisibleBySessionId(sessionId, agentId); return ResponseEntity.ok(messages); } + @GetMapping("/sessions/{sessionId}/trace") + public ResponseEntity getLatestSessionTrace(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); + return sessionTraceStore.getLatestTrace(sessionId) + .>map(ResponseEntity::ok) + .orElseGet(() -> ResponseEntity.notFound().build()); + } + + @GetMapping("/sessions/{sessionId}/answers/latest/explain") + public ResponseEntity getLatestAnswerExplain(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); + return answerTraceExplainStore.getLatestExplain(sessionId) + .>map(ResponseEntity::ok) + .or(() -> loadLatestPersistedAnswerExplain(sessionId, agentId).map(ResponseEntity::ok)) + .orElseGet(() -> ResponseEntity.notFound().build()); + } + + @GetMapping("/sessions/{sessionId}/answers/{runtimeRequestId}/explain") + public ResponseEntity getAnswerExplain(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId, + @PathVariable(value = "runtimeRequestId") String runtimeRequestId) { + chatSessionService.requireSessionForAgent(sessionId, agentId); + return answerTraceExplainStore.getExplain(sessionId, runtimeRequestId) + .>map(ResponseEntity::ok) + .or(() -> loadPersistedAnswerExplain(sessionId, runtimeRequestId, agentId).map(ResponseEntity::ok)) + .orElseGet(() -> ResponseEntity.notFound().build()); + } + + private java.util.Optional loadLatestPersistedAnswerExplain(String sessionId, Long agentId) { + List snapshots = chatMessageService.findBySessionIdAndMessageType(sessionId, + ANSWER_EXPLAIN_MESSAGE_TYPE, agentId); + JsonNode latestExplainNode = null; + long latestUpdatedAt = Long.MIN_VALUE; + for (ChatMessage snapshot : snapshots) { + if (snapshot == null || !StringUtils.hasText(snapshot.getContent())) { + continue; + } + try { + JsonNode explainNode = objectMapper.readTree(snapshot.getContent()); + long updatedAt = explainNode.path("updatedAt").asLong(Long.MIN_VALUE); + if (latestExplainNode == null || updatedAt >= latestUpdatedAt) { + latestExplainNode = explainNode; + latestUpdatedAt = updatedAt; + } + } + catch (Exception ex) { + log.warn("Failed to parse persisted answer explain snapshot. sessionId={}, messageId={}", sessionId, + snapshot.getId(), ex); + } + } + return java.util.Optional.ofNullable(latestExplainNode); + } + + private java.util.Optional loadPersistedAnswerExplain(String sessionId, String runtimeRequestId, + Long agentId) { + if (!StringUtils.hasText(sessionId) || !StringUtils.hasText(runtimeRequestId)) { + return java.util.Optional.empty(); + } + List snapshots = chatMessageService.findBySessionIdAndMessageType(sessionId, + ANSWER_EXPLAIN_MESSAGE_TYPE, agentId); + for (ChatMessage snapshot : snapshots) { + if (snapshot == null || !StringUtils.hasText(snapshot.getContent())) { + continue; + } + try { + JsonNode explainNode = objectMapper.readTree(snapshot.getContent()); + if (runtimeRequestId.equals(explainNode.path("runtimeRequestId").asText())) { + return java.util.Optional.of(explainNode); + } + } + catch (Exception ex) { + log.warn("Failed to parse persisted answer explain snapshot. sessionId={}, messageId={}", sessionId, + snapshot.getId(), ex); + } + } + return java.util.Optional.empty(); + } + /** * Save message to session */ @PostMapping("/sessions/{sessionId}/messages") public ResponseEntity saveMessage(@PathVariable(value = "sessionId") String sessionId, - @RequestBody ChatMessageDTO request) { + @RequestParam(value = "agentId") Long agentId, @RequestBody ChatMessageDTO request) { try { if (request == null) { return ResponseEntity.badRequest().build(); @@ -136,17 +230,20 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" .metadata(request.getMetadata()) .build(); - ChatMessage savedMessage = chatMessageService.saveMessage(message); + ChatMessage savedMessage = chatMessageService.saveMessage(message, agentId); // Update session activity time - chatSessionService.updateSessionTime(sessionId); + chatSessionService.updateSessionTime(sessionId, agentId); - if (request.isTitleNeeded()) { + if (shouldGenerateTitle(request, savedMessage)) { sessionTitleService.scheduleTitleGeneration(sessionId, message.getContent()); } return ResponseEntity.ok(savedMessage); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Save message error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().build(); @@ -158,12 +255,15 @@ public ResponseEntity saveMessage(@PathVariable(value = "sessionId" */ @PutMapping("/sessions/{sessionId}/pin") public ResponseEntity pinSession(@PathVariable(value = "sessionId") String sessionId, - @RequestParam(value = "isPinned") Boolean isPinned) { + @RequestParam(value = "agentId") Long agentId, @RequestParam(value = "isPinned") Boolean isPinned) { try { - chatSessionService.pinSession(sessionId, isPinned); + chatSessionService.pinSession(sessionId, isPinned, agentId); String message = isPinned ? "会话已置顶" : "会话已取消置顶"; return ResponseEntity.ok(ApiResponse.success(message)); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Pin session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("操作失败")); @@ -175,15 +275,18 @@ public ResponseEntity pinSession(@PathVariable(value = "sessionId") */ @PutMapping("/sessions/{sessionId}/rename") public ResponseEntity renameSession(@PathVariable(value = "sessionId") String sessionId, - @RequestParam(value = "title") String title) { + @RequestParam(value = "agentId") Long agentId, @RequestParam(value = "title") String title) { try { if (!StringUtils.hasText(title)) { return ResponseEntity.badRequest().body(ApiResponse.error("标题不能为空")); } - chatSessionService.renameSession(sessionId, title.trim()); + chatSessionService.renameSession(sessionId, title.trim(), agentId); return ResponseEntity.ok(ApiResponse.success("会话已重命名")); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Rename session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("重命名失败")); @@ -194,11 +297,15 @@ public ResponseEntity renameSession(@PathVariable(value = "sessionI * Delete a single session */ @DeleteMapping("/sessions/{sessionId}") - public ResponseEntity deleteSession(@PathVariable(value = "sessionId") String sessionId) { + public ResponseEntity deleteSession(@PathVariable(value = "sessionId") String sessionId, + @RequestParam(value = "agentId") Long agentId) { try { - chatSessionService.deleteSession(sessionId); + chatSessionService.deleteSession(sessionId, agentId); return ResponseEntity.ok(ApiResponse.success("会话已删除")); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Delete session error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().body(ApiResponse.error("删除失败")); @@ -210,11 +317,12 @@ public ResponseEntity deleteSession(@PathVariable(value = "sessionI */ @PostMapping("/sessions/{sessionId}/reports/html") public ResponseEntity convertAndDownloadHtml(@PathVariable(value = "sessionId") String sessionId, - @RequestBody String content) { + @RequestParam(value = "agentId") Long agentId, @RequestBody String content) { try { if (!StringUtils.hasText(content)) { return ResponseEntity.badRequest().build(); } + chatSessionService.requireSessionForAgent(sessionId, agentId); log.debug("Download HTML report for session {}", sessionId); StringBuilder htmlContent = new StringBuilder(); htmlContent.append(reportTemplateUtil.getHeader()); @@ -227,10 +335,27 @@ public ResponseEntity convertAndDownloadHtml(@PathVariable(value = "sess headers.setContentDispositionFormData("attachment", filename); return ResponseEntity.ok().headers(headers).body(htmlContent.toString().getBytes(StandardCharsets.UTF_8)); } + catch (ResponseStatusException ex) { + throw ex; + } catch (Exception e) { log.error("Download HTML report error for session {}: {}", sessionId, e.getMessage(), e); return ResponseEntity.internalServerError().build(); } } + private boolean shouldGenerateTitle(ChatMessageDTO request, ChatMessage savedMessage) { + if (request == null || savedMessage == null) { + return false; + } + if (request.isTitleNeeded()) { + return true; + } + if (!"user".equalsIgnoreCase(request.getRole())) { + return false; + } + List sessionMessages = chatMessageService.findBySessionId(savedMessage.getSessionId()); + return sessionMessages.size() == 1; + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java index ebd9e221c..8658dc49c 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/DataAgentController.java @@ -15,9 +15,9 @@ */ package com.alibaba.cloud.ai.dataagent.controller; -import com.alibaba.cloud.ai.dataagent.agentscope.dto.GraphRequest; +import com.alibaba.cloud.ai.dataagent.agentscope.dto.AgentRequest; import com.alibaba.cloud.ai.dataagent.agentscope.service.AgentService; -import com.alibaba.cloud.ai.dataagent.agentscope.vo.GraphNodeResponse; +import com.alibaba.cloud.ai.dataagent.agentscope.vo.AgentResponse; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; @@ -28,8 +28,11 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ResponseStatusException; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; +import com.alibaba.cloud.ai.dataagent.service.chat.ChatSessionService; +import org.springframework.http.HttpStatus; import static com.alibaba.cloud.ai.dataagent.constant.Constant.STREAM_EVENT_COMPLETE; import static com.alibaba.cloud.ai.dataagent.constant.Constant.STREAM_EVENT_ERROR; @@ -43,26 +46,33 @@ public class DataAgentController { private final AgentService agentService; + private final ChatSessionService chatSessionService; + @GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux> streamSearch(@RequestParam("agentId") String agentId, - @RequestParam(value = "threadId", required = false) String threadId, @RequestParam("query") String query, + public Flux> streamSearch(@RequestParam("agentId") String agentId, + @RequestParam("threadId") String threadId, + @RequestParam(value = "runtimeRequestId", required = false) String runtimeRequestId, + @RequestParam("query") String query, + @RequestParam(value = "clarifyCheckEnabled", required = false) boolean clarifyCheckEnabled, @RequestParam(value = "humanFeedback", required = false) boolean humanFeedback, @RequestParam(value = "humanFeedbackContent", required = false) String humanFeedbackContent, - @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, - @RequestParam(value = "nl2sqlOnly", required = false) boolean nl2sqlOnly, ServerHttpResponse response) { + @RequestParam(value = "rejectedPlan", required = false) boolean rejectedPlan, ServerHttpResponse response) { + Long numericAgentId = parseAgentId(agentId); + chatSessionService.requireSessionForAgent(threadId, numericAgentId); response.getHeaders().add("Cache-Control", "no-cache"); response.getHeaders().add("Connection", "keep-alive"); response.getHeaders().add("Access-Control-Allow-Origin", "*"); - Sinks.Many> sink = Sinks.many().unicast().onBackpressureBuffer(); - GraphRequest request = GraphRequest.builder() + Sinks.Many> sink = Sinks.many().unicast().onBackpressureBuffer(); + AgentRequest request = AgentRequest.builder() .agentId(agentId) .threadId(threadId) + .runtimeRequestId(runtimeRequestId) .query(query) + .clarifyCheckEnabled(clarifyCheckEnabled) .humanFeedback(humanFeedback) .humanFeedbackContent(humanFeedbackContent) .rejectedPlan(rejectedPlan) - .nl2sqlOnly(nl2sqlOnly) .build(); agentService.graphStreamProcess(sink, request); @@ -89,4 +99,13 @@ public Flux> streamSearch(@RequestParam("agen .doOnComplete(() -> log.info("Aiagent stream completed successfully, threadId: {}", request.getThreadId())); } + private Long parseAgentId(String agentId) { + try { + return Long.valueOf(agentId); + } + catch (NumberFormatException ex) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "agentId must be numeric", ex); + } + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java index 865aed9d1..40eae5b21 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/GlobalExceptionHandler.java @@ -18,14 +18,17 @@ import com.alibaba.cloud.ai.dataagent.exception.InternalServerException; import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; +import org.springframework.web.bind.MethodArgumentNotValidException; +import org.springframework.web.bind.support.WebExchangeBindException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.bind.annotation.RestControllerAdvice; /** - * 全局异常处理器 (WebFlux 版本) + * Global exception handler. */ @Slf4j @RestControllerAdvice @@ -38,6 +41,29 @@ public ApiResponse handleInvalidInputException(InvalidInputException e) return ApiResponse.error(e.getMessage(), e.getData()); } + @ExceptionHandler(MethodArgumentNotValidException.class) + @ResponseStatus(HttpStatus.BAD_REQUEST) + public ApiResponse handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { + String message = buildValidationMessage(e.getBindingResult() + .getFieldErrors() + .stream() + .map(error -> error.getField() + ": " + error.getDefaultMessage()) + .collect(Collectors.toList())); + log.warn("Method argument not valid: {}", message); + return ApiResponse.error(message); + } + + @ExceptionHandler(WebExchangeBindException.class) + @ResponseStatus(HttpStatus.BAD_REQUEST) + public ApiResponse handleWebExchangeBindException(WebExchangeBindException e) { + String message = buildValidationMessage(e.getFieldErrors() + .stream() + .map(error -> error.getField() + ": " + error.getDefaultMessage()) + .collect(Collectors.toList())); + log.warn("Web exchange bind not valid: {}", message); + return ApiResponse.error(message); + } + @ExceptionHandler(InternalServerException.class) @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) public ApiResponse handleInternalServerException(InternalServerException e) { @@ -49,7 +75,17 @@ public ApiResponse handleInternalServerException(InternalServerException @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) public ApiResponse handleGenericException(Exception e) { log.error("Unexpected error: {}", e.getMessage(), e); - return ApiResponse.error("服务器内部错误"); + return ApiResponse.error("Internal server error"); + } + + private String buildValidationMessage(java.util.List messages) { + String message = messages.stream() + .filter(item -> item != null && !item.isBlank()) + .collect(Collectors.joining("; ")); + if (message.isBlank()) { + return "Request validation failed"; + } + return message; } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java index 4ce636377..ad4785497 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/controller/SemanticModelController.java @@ -17,7 +17,9 @@ import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelAddDTO; import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelBatchImportDTO; +import com.alibaba.cloud.ai.dataagent.dto.schema.SemanticModelUpdateDTO; import com.alibaba.cloud.ai.dataagent.entity.SemanticModel; +import com.alibaba.cloud.ai.dataagent.exception.InvalidInputException; import com.alibaba.cloud.ai.dataagent.service.semantic.SemanticModelService; import com.alibaba.cloud.ai.dataagent.vo.ApiResponse; import com.alibaba.cloud.ai.dataagent.vo.BatchImportResult; @@ -83,23 +85,26 @@ public ApiResponse get(@PathVariable(value = "id") Long id) { @PostMapping public ApiResponse create(@RequestBody @Validated SemanticModelAddDTO semanticModelAddDto) { - boolean success = semanticModelService.addSemanticModel(semanticModelAddDto); - if (success) { - return ApiResponse.success("Semantic model created successfully", true); - } - else { + try { + boolean success = semanticModelService.addSemanticModel(semanticModelAddDto); + if (success) { + return ApiResponse.success("Semantic model created successfully", true); + } return ApiResponse.error("Failed to create semantic model"); } + catch (IllegalArgumentException e) { + throw new InvalidInputException(e.getMessage()); + } } @PutMapping("/{id}") - public ApiResponse update(@PathVariable(value = "id") Long id, @RequestBody SemanticModel model) { + public ApiResponse update(@PathVariable(value = "id") Long id, + @RequestBody @Validated SemanticModelUpdateDTO semanticModelUpdateDto) { if (semanticModelService.getById(id) == null) { return ApiResponse.error("Semantic model not found"); } - model.setId(id); - semanticModelService.updateSemanticModel(id, model); - return ApiResponse.success("Semantic model updated successfully", model); + semanticModelService.updateSemanticModel(id, semanticModelUpdateDto); + return ApiResponse.success("Semantic model updated successfully", semanticModelService.getById(id)); } @DeleteMapping("/{id}") @@ -172,8 +177,9 @@ public ResponseEntity downloadTemplate() { @PostMapping(value = "/import/excel", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) public Mono> importExcel(@RequestPart("file") FilePart file, - @RequestPart("agentId") String agentId) { + @RequestPart("agentId") String agentId, @RequestPart("datasourceId") String datasourceId) { Long agentIdLong = Long.parseLong(agentId); + Integer datasourceIdInt = Integer.parseInt(datasourceId); String filename = file.filename(); return DataBufferUtils.join(file.content()).flatMap(dataBuffer -> { @@ -183,7 +189,8 @@ public Mono> importExcel(@RequestPart("file") Fil return Mono.fromCallable(() -> { try (InputStream inputStream = new ByteArrayInputStream(bytes)) { - BatchImportResult result = semanticModelService.importFromExcel(inputStream, filename, agentIdLong); + BatchImportResult result = semanticModelService.importFromExcel(inputStream, filename, agentIdLong, + datasourceIdInt); return ApiResponse.success("Excel导入完成", result); } }).subscribeOn(Schedulers.boundedElastic()); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java index 48f766e2f..cde7e64fa 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/SchemaInitRequest.java @@ -19,14 +19,19 @@ import java.io.Serializable; import java.util.List; +import java.util.Map; import java.util.Objects; public class SchemaInitRequest implements Serializable { private DbConfigBO dbConfig; + private Long agentId; + private List tables; + private Map> visibleColumnsByTable; + public DbConfigBO getDbConfig() { return dbConfig; } @@ -35,6 +40,14 @@ public void setDbConfig(DbConfigBO dbConfig) { this.dbConfig = dbConfig; } + public Long getAgentId() { + return agentId; + } + + public void setAgentId(Long agentId) { + this.agentId = agentId; + } + public List getTables() { return tables; } @@ -43,9 +56,18 @@ public void setTables(List tables) { this.tables = tables; } + public Map> getVisibleColumnsByTable() { + return visibleColumnsByTable; + } + + public void setVisibleColumnsByTable(Map> visibleColumnsByTable) { + this.visibleColumnsByTable = visibleColumnsByTable; + } + @Override public String toString() { - return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", tables=" + tables + '}'; + return "SchemaInitRequest{" + "dbConfig=" + dbConfig + ", agentId=" + agentId + ", tables=" + tables + + ", visibleColumnsByTable=" + visibleColumnsByTable + '}'; } @Override @@ -55,12 +77,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; SchemaInitRequest that = (SchemaInitRequest) o; - return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(tables, that.tables); + return Objects.equals(dbConfig, that.dbConfig) && Objects.equals(agentId, that.agentId) + && Objects.equals(tables, that.tables) + && Objects.equals(visibleColumnsByTable, that.visibleColumnsByTable); } @Override public int hashCode() { - return Objects.hash(dbConfig, tables); + return Objects.hash(dbConfig, agentId, tables, visibleColumnsByTable); } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java similarity index 69% rename from data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java rename to data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java index 5137496b5..a3d531bf7 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/bo/schema/ResultBO.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/TableColumnsSelectionDTO.java @@ -13,21 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.cloud.ai.dataagent.bo.schema; +package com.alibaba.cloud.ai.dataagent.dto.datasource; -import lombok.AllArgsConstructor; -import lombok.Builder; +import jakarta.validation.constraints.NotBlank; +import java.util.List; import lombok.Data; -import lombok.NoArgsConstructor; @Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class ResultBO { +public class TableColumnsSelectionDTO { - private ResultSetBO resultSet; + @NotBlank(message = "tableName cannot be blank") + private String tableName; - private DisplayStyleBO displayStyle; + private List columns; } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java new file mode 100644 index 000000000..7bc5247f4 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/datasource/UpdateDatasourceColumnsDTO.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.dto.datasource; + +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import java.util.List; +import lombok.Data; + +@Data +public class UpdateDatasourceColumnsDTO { + + @NotNull(message = "datasourceId cannot be null") + private Integer datasourceId; + + @Valid + private List tables; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java index 44bfdfda9..2a9a42136 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelAddDTO.java @@ -33,6 +33,9 @@ public class SemanticModelAddDTO { @NotNull(message = "智能体ID不能为空") private Long agentId; + @NotNull(message = "数据源ID不能为空") + private Integer datasourceId; + /** 关联的表名 */ @NotBlank(message = "表名不能为空") private String tableName; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java index 5d7d083aa..b96041ebc 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelBatchImportDTO.java @@ -37,6 +37,9 @@ public class SemanticModelBatchImportDTO { @NotNull(message = "智能体ID不能为空") private Long agentId; + @NotNull(message = "数据源ID不能为空") + private Integer datasourceId; + @NotEmpty(message = "导入数据不能为空") @Valid private List items; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java new file mode 100644 index 000000000..c2447f963 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/dto/schema/SemanticModelUpdateDTO.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.dto.schema; + +import jakarta.validation.constraints.NotBlank; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** 语义模型更新 DTO,仅允许编辑业务解释字段。 */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class SemanticModelUpdateDTO { + + @NotBlank(message = "业务名称不能为空") + private String businessName; + + private String synonyms; + + private String businessDescription; + + private String columnComment; + + @NotBlank(message = "数据类型不能为空") + private String dataType; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java index fdd7ae0b5..265417106 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasource.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonFormat; import java.time.LocalDateTime; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -50,6 +51,9 @@ public class AgentDatasource { // 当前数据源选中的表 private List selectTables; + // 当前数据源按表配置的字段白名单;未配置的表默认整表可见 + private Map> selectColumns; + public AgentDatasource(Long agentId, Integer datasourceId) { this.agentId = agentId; this.datasourceId = datasourceId; diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java new file mode 100644 index 000000000..3bf94bf55 --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/entity/AgentDatasourceColumn.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.entity; + +import com.fasterxml.jackson.annotation.JsonFormat; +import java.time.LocalDateTime; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.format.annotation.DateTimeFormat; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AgentDatasourceColumn { + + private Integer id; + + private Integer agentDatasourceId; + + private String tableName; + + private String columnName; + + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss") + @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss") + private LocalDateTime createTime; + + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss") + @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss") + private LocalDateTime updateTime; + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java new file mode 100644 index 000000000..4a75f44db --- /dev/null +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceColumnsMapper.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.dataagent.mapper; + +import com.alibaba.cloud.ai.dataagent.entity.AgentDatasourceColumn; +import java.util.List; +import org.apache.ibatis.annotations.Delete; +import org.apache.ibatis.annotations.Insert; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; + +@Mapper +public interface AgentDatasourceColumnsMapper { + + @Select("SELECT * FROM agent_datasource_columns WHERE agent_datasource_id = #{agentDatasourceId} ORDER BY table_name, column_name") + List getAgentDatasourceColumns(@Param("agentDatasourceId") int agentDatasourceId); + + @Delete("DELETE FROM agent_datasource_columns WHERE agent_datasource_id = #{agentDatasourceId}") + int removeAllColumns(@Param("agentDatasourceId") int agentDatasourceId); + + @Delete("") + int removeColumnsOutsideTables(@Param("agentDatasourceId") int agentDatasourceId, + @Param("tables") List tables); + + @Insert("") + int insertColumns(@Param("rows") List rows); + +} diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java index f6cdb0b41..c93e68967 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/AgentDatasourceMapper.java @@ -47,6 +47,10 @@ public interface AgentDatasourceMapper { AgentDatasource selectByAgentIdAndDatasourceId(@Param("agentId") Long agentId, @Param("datasourceId") Integer datasourceId); + /** Query associations by datasource ID */ + @Select("SELECT * FROM agent_datasource WHERE datasource_id = #{datasourceId} ORDER BY create_time DESC") + List selectByDatasourceId(@Param("datasourceId") Integer datasourceId); + /** Disable all data sources for an agent */ @Update("UPDATE agent_datasource SET is_active = 0 WHERE agent_id = #{agentId}") int disableAllByAgentId(@Param("agentId") Long agentId); @@ -59,6 +63,9 @@ AgentDatasource selectByAgentIdAndDatasourceId(@Param("agentId") Long agentId, int countActiveByAgentIdExcluding(@Param("agentId") Long agentId, @Param("excludeDatasourceId") Integer excludeDatasourceId); + @Select("SELECT COUNT(*) FROM agent_datasource WHERE agent_id = #{agentId} AND is_active = 1") + int countActiveByAgentId(@Param("agentId") Long agentId); + @Delete("DELETE FROM agent_datasource WHERE datasource_id = #{datasourceId}") int deleteAllByDatasourceId(@Param("datasourceId") Integer datasourceId); diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java index 263614542..e28d68fbf 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/ChatMessageMapper.java @@ -36,32 +36,43 @@ public interface ChatMessageMapper { @Select(""" SELECT * FROM chat_message WHERE session_id = #{sessionId} - AND LOWER(TRIM(COALESCE(message_type, ''))) <> 'memory-text' + AND LOWER(TRIM(COALESCE(message_type, ''))) NOT IN ('memory-text', 'answer-explain') + AND LOWER(TRIM(COALESCE(message_type, ''))) NOT LIKE 'agentscope-state:%' ORDER BY create_time ASC """) List selectVisibleBySessionId(@Param("sessionId") String sessionId); @Select(""" - SELECT * FROM ( - SELECT * FROM chat_message - WHERE session_id = #{sessionId} - AND content IS NOT NULL - AND TRIM(content) <> '' - AND ( - LOWER(TRIM(COALESCE(message_type, ''))) = 'memory-text' - OR ( - LOWER(TRIM(COALESCE(role, ''))) IN ('assistant', 'system', 'tool', 'user') - AND LOWER(TRIM(COALESCE(message_type, ''))) NOT IN ('html', 'html-report', 'markdown-report', 'result-set') - AND LOWER(TRIM(COALESCE(message_type, ''))) IN ('', 'text') - ) - ) - ORDER BY create_time DESC - LIMIT #{limit} - ) recent_memory - ORDER BY create_time ASC + SELECT * FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) + ORDER BY create_time DESC, id DESC + """) + List selectBySessionIdAndMessageType(@Param("sessionId") String sessionId, + @Param("messageType") String messageType); + + @Select(""" + SELECT * FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) + ORDER BY create_time ASC, id ASC + """) + List selectStateBySessionIdAndMessageType(@Param("sessionId") String sessionId, + @Param("messageType") String messageType); + + @Select(""" + SELECT DISTINCT session_id FROM chat_message + WHERE LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + ORDER BY session_id ASC + """) + List selectSessionIdsWithAgentScopeState(); + + @Select(""" + SELECT COUNT(*) FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' """) - List selectRecentMemoryEligibleBySessionId(@Param("sessionId") String sessionId, - @Param("limit") int limit); + int countAgentScopeStateBySessionId(@Param("sessionId") String sessionId); /** * Query by id @@ -105,4 +116,24 @@ INSERT INTO chat_message (session_id, role, content, message_type, metadata, cre """) int deleteById(@Param("id") Long id); + @Delete(""" + DELETE FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) = LOWER(#{messageType}) + """) + int deleteBySessionIdAndMessageType(@Param("sessionId") String sessionId, @Param("messageType") String messageType); + + @Delete(""" + DELETE FROM chat_message + WHERE session_id = #{sessionId} + AND LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + """) + int deleteAgentScopeStateBySessionId(@Param("sessionId") String sessionId); + + @Delete(""" + DELETE FROM chat_message + WHERE LOWER(TRIM(COALESCE(message_type, ''))) LIKE 'agentscope-state:%' + """) + int deleteAllAgentScopeStateMessages(); + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java index 4da15fb5f..6cb4e1407 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/LogicalRelationMapper.java @@ -32,6 +32,14 @@ public interface LogicalRelationMapper { @Select("SELECT * FROM logical_relation WHERE id = #{id} AND is_deleted = 0") LogicalRelation selectById(@Param("id") Integer id); + @Select(""" + SELECT * FROM logical_relation + WHERE id = #{id} + AND datasource_id = #{datasourceId} + AND is_deleted = 0 + """) + LogicalRelation selectByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + /** * 根据数据源ID查询逻辑外键列表(未删除的) */ @@ -71,12 +79,51 @@ public interface LogicalRelationMapper { """) int updateById(LogicalRelation logicalRelation); + @Update(""" + + """) + int updateByIdAndDatasourceId(@Param("datasourceId") Integer datasourceId, + @Param("logicalRelation") LogicalRelation logicalRelation); + /** * 逻辑删除外键 */ @Update("UPDATE logical_relation SET is_deleted = 1, updated_time = NOW() WHERE id = #{id}") int deleteById(@Param("id") Integer id); + @Update(""" + UPDATE logical_relation + SET is_deleted = 1, updated_time = NOW() + WHERE id = #{id} + AND datasource_id = #{datasourceId} + AND is_deleted = 0 + """) + int deleteByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + + @Delete("DELETE FROM logical_relation WHERE id = #{id}") + int hardDeleteById(@Param("id") Integer id); + + @Delete(""" + DELETE FROM logical_relation + WHERE id = #{id} + AND datasource_id = #{datasourceId} + """) + int hardDeleteByIdAndDatasourceId(@Param("id") Integer id, @Param("datasourceId") Integer datasourceId); + /** * 逻辑删除数据源下的所有逻辑外键 */ @@ -99,4 +146,33 @@ int checkExists(@Param("datasourceId") Integer datasourceId, @Param("sourceTable @Param("sourceColumnName") String sourceColumnName, @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName); + @Select(""" + SELECT COUNT(*) FROM logical_relation + WHERE datasource_id = #{datasourceId} + AND source_table_name = #{sourceTableName} + AND source_column_name = #{sourceColumnName} + AND target_table_name = #{targetTableName} + AND target_column_name = #{targetColumnName} + AND is_deleted = 0 + AND id != #{excludeId} + """) + int checkExistsExcludingId(@Param("datasourceId") Integer datasourceId, + @Param("sourceTableName") String sourceTableName, @Param("sourceColumnName") String sourceColumnName, + @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName, + @Param("excludeId") Integer excludeId); + + @Select(""" + SELECT * FROM logical_relation + WHERE datasource_id = #{datasourceId} + AND source_table_name = #{sourceTableName} + AND source_column_name = #{sourceColumnName} + AND target_table_name = #{targetTableName} + AND target_column_name = #{targetColumnName} + AND is_deleted = 1 + ORDER BY updated_time DESC, id DESC + """) + List selectDeletedByBusinessKey(@Param("datasourceId") Integer datasourceId, + @Param("sourceTableName") String sourceTableName, @Param("sourceColumnName") String sourceColumnName, + @Param("targetTableName") String targetTableName, @Param("targetColumnName") String targetColumnName); + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java index c2d8ed935..d2ee4a295 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/mapper/SemanticModelMapper.java @@ -131,16 +131,11 @@ List selectEnabledByAgentIdAndDatasourceId(@Param("agentId") Long