From 6d41d857958efb0b577acdbe01b588b2186fb9db Mon Sep 17 00:00:00 2001 From: ulleo Date: Thu, 11 Jun 2026 17:27:51 +0800 Subject: [PATCH] fix: Fix SQL injection / LLM Prompt Injection vulnerability causing unauthorized queries Security hardening: - Add SQLBOT_ALLOW_METADATA_QUERIES config option, disable SHOW/DESCRIBE/EXPLAIN by default - Add table whitelist check, use sqlglot to parse actual SQL table names and compare with authorized table list - Add dangerous function check, block LOAD_FILE, INTO OUTFILE, EXEC etc. by database type - Improve check_sql_read to return specific error reasons for better debugging --- backend/apps/chat/task/llm.py | 43 ++++++++++- backend/apps/db/db.py | 107 ++++++++++++++++++++++---- backend/apps/system/crud/assistant.py | 4 +- backend/common/core/config.py | 4 + 4 files changed, 141 insertions(+), 17 deletions(-) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 26af4845d..ca58c4647 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -13,6 +13,8 @@ import pandas as pd import requests import sqlparse +import sqlglot +from sqlglot import exp from langchain.chat_models.base import BaseChatModel from langchain_community.utilities import SQLDatabase from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, BaseMessageChunk @@ -40,7 +42,7 @@ from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource -from apps.db.db import exec_sql, get_version, check_connection +from apps.db.db import exec_sql, get_version, check_connection, get_sqlglot_dialect from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds from apps.system.crud.parameter_manage import get_groups @@ -66,6 +68,23 @@ i18n = I18n() + +def extract_tables_from_sql(sql: str, ds_type: str = None) -> set: + """从 SQL 中提取表名(使用 sqlglot 解析,可信)""" + tables = set() + dialect = get_sqlglot_dialect(ds_type) + try: + statements = sqlglot.parse(sql, dialect=dialect) + for stmt in statements: + if stmt: + for table in stmt.find_all(exp.Table): + if table.name: + tables.add(table.name) + except Exception: + pass + return tables + + class LLMService: ds: CoreDatasource chat_question: ChatQuestion @@ -106,6 +125,9 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C self.chunk_list = [] self.current_user = current_user self.current_assistant = current_assistant + + self.table_name_list = [] + chat_id = chat_question.chat_id chat: Chat | None = session.get(Chat, chat_id) if not chat: @@ -222,7 +244,7 @@ def is_running(self, timeout=0.5): def init_messages(self, session: Session): - self.choose_table_schema(session) + self.table_name_list = self.choose_table_schema(session) last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len( self.generate_sql_logs) > 0 else [] @@ -404,6 +426,7 @@ def choose_table_schema(self, _session: Session): self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session, log=self.current_logs[OperationEnum.CHOOSE_TABLE], full_message=self.chat_question.db_schema) + return tables def generate_analysis(self, _session: Session): fields = self.get_fields_from_chart(_session) @@ -1266,6 +1289,22 @@ def run_task(self, in_chat: bool = True, stream: bool = True, sql_operate = OperationEnum.GENERATE_SQL sql, tables = self.check_sql(session=_session, res=full_sql_text, operate=sql_operate) + + # 表名安全检查:用 sqlglot 解析真实 SQL,不信任 AI 返回的 tables + actual_tables = extract_tables_from_sql(sql, ds_type=self.ds.type) + if not actual_tables: + raise SingleMessageError( + "SQL parsing failed: unable to extract table names. " + "This may indicate an unsupported SQL syntax or a security issue." + ) + allowed_tables = set(self.table_name_list) + unauthorized_tables = actual_tables - allowed_tables + if unauthorized_tables: + raise SingleMessageError( + f"SQL contains unauthorized tables: {', '.join(unauthorized_tables)}. " + f"Allowed tables: {', '.join(allowed_tables)}" + ) + if ((not self.current_assistant or is_page_embedded) and is_normal_user( self.current_user)) or use_dynamic_ds: sql_result = None diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index b53bf2020..4e9aa4e1e 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -586,8 +586,9 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= while sql.endswith(';'): sql = sql[:-1] # check execute sql only contain read operations - if not check_sql_read(sql, ds): - raise ValueError(f"SQL can only contain read operations") + is_safe, error_reason = check_sql_read(sql, ds) + if not is_safe: + raise ValueError(f"SQL can only contain read operations: {error_reason}") db = DB.get_db(ds.type) if db.connect_type == ConnectType.sqlalchemy: @@ -716,11 +717,78 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= raise ParseSQLResultError(str(ex)) -def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): +def get_sqlglot_dialect(ds_type: str) -> str: + """根据数据源类型获取 sqlglot dialect""" + if equals_ignore_case(ds_type, 'mysql', 'doris', 'starrocks'): + return 'mysql' + elif equals_ignore_case(ds_type, 'sqlServer'): + return 'tsql' + elif equals_ignore_case(ds_type, 'hive'): + return 'hive' + return None + + +# 通用危险函数(适用于所有数据库) +COMMON_DANGEROUS_FUNCTIONS = {'version', 'current_user', 'user', 'database'} + +# 特定数据库的危险函数 +DS_SPECIFIC_DANGEROUS_FUNCTIONS = { + 'mysql': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'}, + 'doris': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'}, + 'starrocks': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'}, + 'postgresql': {'pg_read_file', 'pg_write_file', 'lo_import', 'lo_export'}, + 'sqlserver': {'EXEC', 'xp_cmdshell', 'sp_executesql'}, + 'oracle': {'UTL_FILE', 'DBMS_PIPE', 'DBMS_LOCK'}, + 'hive': {'ADD FILE', 'ADD JAR'}, +} + +# 危险模式正则表达式(用于检查特殊语法) +import re +DANGEROUS_PATTERNS = [ + r'\bINTO\s+OUTFILE\b', + r'\bINTO\s+DUMPFILE\b', + r'\bEXEC\s*\(', + r'\bCOPY\s+.*\bTO\s+PROGRAM\b', +] + + +def get_dangerous_functions(ds_type: str) -> set: + """获取危险函数(通用 + 特定数据源)""" + functions = COMMON_DANGEROUS_FUNCTIONS.copy() + ds_key = ds_type.lower() if ds_type else '' + if ds_key in DS_SPECIFIC_DANGEROUS_FUNCTIONS: + functions.update(DS_SPECIFIC_DANGEROUS_FUNCTIONS[ds_key]) + return functions + + +def check_dangerous_functions(statements: list, ds_type: str) -> bool: + """检查是否使用了危险函数,返回 True 表示安全""" + dangerous_functions = get_dangerous_functions(ds_type) + dangerous_functions_upper = {f.upper() for f in dangerous_functions} + + for stmt in statements: + if stmt: + for func in stmt.find_all(exp.Anonymous): + if func.name.upper() in dangerous_functions_upper: + return False + return True + + +def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema) -> tuple[bool, str]: + """ + 检查 SQL 是否为安全的只读查询 + 返回: (是否安全, 错误原因) + """ try: normalized_sql = sql.strip().lstrip("(").strip() first_keyword = normalized_sql.split(None, 1)[0].upper() if normalized_sql else "" - allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"} + + # 根据配置决定是否允许元数据查询 + if settings.SQLBOT_ALLOW_METADATA_QUERIES: + allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"} + else: + allowed_read_commands = {"SELECT", "WITH"} + denied_write_commands = { "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "MERGE", "COPY", "REPLACE", "GRANT", "REVOKE", @@ -730,21 +798,29 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): if not first_keyword: raise ValueError("Parse SQL Error") if first_keyword in denied_write_commands: - return False + return False, f"Write operation '{first_keyword}' is not allowed" - dialect = None - if equals_ignore_case(ds.type, 'mysql', 'doris', 'starrocks'): - dialect = 'mysql' - elif equals_ignore_case(ds.type, 'sqlServer'): - dialect = 'tsql' - elif equals_ignore_case(ds.type, 'hive'): - dialect = 'hive' + # 1. 使用正则检查特殊模式 + for pattern in DANGEROUS_PATTERNS: + if re.search(pattern, sql, re.IGNORECASE): + return False, f"SQL contains dangerous pattern: {pattern}" + dialect = get_sqlglot_dialect(ds.type) statements = sqlglot.parse(sql, dialect=dialect) if not statements: raise ValueError("Parse SQL Error") + # 2. 使用 sqlglot 检查函数调用 + dangerous_functions = get_dangerous_functions(ds.type) + dangerous_functions_upper = {f.upper() for f in dangerous_functions} + for stmt in statements: + if stmt: + for func in stmt.find_all(exp.Anonymous): + if func.name.upper() in dangerous_functions_upper: + return False, f"SQL contains dangerous function: {func.name}" + + # 3. 检查写操作类型 write_types = ( exp.Insert, exp.Update, exp.Delete, exp.Create, exp.Drop, exp.Alter, @@ -755,9 +831,12 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): if stmt is None: continue if isinstance(stmt, write_types): - return False + return False, f"SQL contains write operation: {type(stmt).__name__}" + + if first_keyword not in allowed_read_commands: + return False, f"SQL command '{first_keyword}' is not allowed. Only SELECT and WITH are permitted" - return first_keyword in allowed_read_commands + return True, "" except Exception as e: raise ValueError(f"Parse SQL Error: {e}") diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 1fa5eb27c..29b4b5229 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -187,6 +187,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True, db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" tables = [] + table_name_list = [] i = 0 for table in ds.tables: # 如果传入了 table_list,则只处理在列表中的表 @@ -213,6 +214,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True, schema_table += '\n]\n' t_obj = {"id": i, "schema_table": schema_table} tables.append(t_obj) + table_name_list.append(table.name) # do table embedding # if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: @@ -222,7 +224,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True, for s in tables: schema_str += s.get('schema_table') - return schema_str, [] + return schema_str, table_name_list def get_ds(self, ds_id: int, trans: Trans = None): if self.ds_list: diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 1b3cc24ef..6782c1609 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -111,6 +111,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: GENERATE_SQL_QUERY_LIMIT_ENABLED: bool = True GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT: int = 3 + # 安全配置:是否允许元数据查询(SHOW/DESCRIBE/DESC/EXPLAIN) + # 默认关闭,防止通过元数据查询泄露数据库结构 + SQLBOT_ALLOW_METADATA_QUERIES: bool = False + PARSE_REASONING_BLOCK_ENABLED: bool = True DEFAULT_REASONING_CONTENT_START: str = '' DEFAULT_REASONING_CONTENT_END: str = ''