Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import pandas as pd
import requests
import sqlparse
import sqlglot
from sqlglot import exp
from langchain.chat_models.base import BaseChatModel
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, BaseMessageChunk
Expand Down Expand Up @@ -40,7 +42,7 @@
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
from apps.db.db import exec_sql, get_version, check_connection
from apps.db.db import exec_sql, get_version, check_connection, get_sqlglot_dialect
from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
from apps.system.crud.parameter_manage import get_groups
Expand All @@ -66,6 +68,23 @@
i18n = I18n()



def extract_tables_from_sql(sql: str, ds_type: str = None) -> set:
"""从 SQL 中提取表名(使用 sqlglot 解析,可信)"""
tables = set()
dialect = get_sqlglot_dialect(ds_type)
try:
statements = sqlglot.parse(sql, dialect=dialect)
for stmt in statements:
if stmt:
for table in stmt.find_all(exp.Table):
if table.name:
tables.add(table.name)
except Exception:
pass
return tables


class LLMService:
ds: CoreDatasource
chat_question: ChatQuestion
Expand Down Expand Up @@ -106,6 +125,9 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
self.chunk_list = []
self.current_user = current_user
self.current_assistant = current_assistant

self.table_name_list = []

chat_id = chat_question.chat_id
chat: Chat | None = session.get(Chat, chat_id)
if not chat:
Expand Down Expand Up @@ -222,7 +244,7 @@ def is_running(self, timeout=0.5):

def init_messages(self, session: Session):

self.choose_table_schema(session)
self.table_name_list = self.choose_table_schema(session)

last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len(
self.generate_sql_logs) > 0 else []
Expand Down Expand Up @@ -404,6 +426,7 @@ def choose_table_schema(self, _session: Session):
self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
full_message=self.chat_question.db_schema)
return tables

def generate_analysis(self, _session: Session):
fields = self.get_fields_from_chart(_session)
Expand Down Expand Up @@ -1266,6 +1289,22 @@ def run_task(self, in_chat: bool = True, stream: bool = True,

sql_operate = OperationEnum.GENERATE_SQL
sql, tables = self.check_sql(session=_session, res=full_sql_text, operate=sql_operate)

# 表名安全检查:用 sqlglot 解析真实 SQL,不信任 AI 返回的 tables
actual_tables = extract_tables_from_sql(sql, ds_type=self.ds.type)
if not actual_tables:
raise SingleMessageError(
"SQL parsing failed: unable to extract table names. "
"This may indicate an unsupported SQL syntax or a security issue."
)
allowed_tables = set(self.table_name_list)
unauthorized_tables = actual_tables - allowed_tables
if unauthorized_tables:
raise SingleMessageError(
f"SQL contains unauthorized tables: {', '.join(unauthorized_tables)}. "
f"Allowed tables: {', '.join(allowed_tables)}"
)

if ((not self.current_assistant or is_page_embedded) and is_normal_user(
self.current_user)) or use_dynamic_ds:
sql_result = None
Expand Down
107 changes: 93 additions & 14 deletions backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,9 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
while sql.endswith(';'):
sql = sql[:-1]
# check execute sql only contain read operations
if not check_sql_read(sql, ds):
raise ValueError(f"SQL can only contain read operations")
is_safe, error_reason = check_sql_read(sql, ds)
if not is_safe:
raise ValueError(f"SQL can only contain read operations: {error_reason}")

db = DB.get_db(ds.type)
if db.connect_type == ConnectType.sqlalchemy:
Expand Down Expand Up @@ -716,11 +717,78 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
raise ParseSQLResultError(str(ex))


def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
def get_sqlglot_dialect(ds_type: str) -> str:
"""根据数据源类型获取 sqlglot dialect"""
if equals_ignore_case(ds_type, 'mysql', 'doris', 'starrocks'):
return 'mysql'
elif equals_ignore_case(ds_type, 'sqlServer'):
return 'tsql'
elif equals_ignore_case(ds_type, 'hive'):
return 'hive'
return None


# 通用危险函数(适用于所有数据库)
COMMON_DANGEROUS_FUNCTIONS = {'version', 'current_user', 'user', 'database'}

# 特定数据库的危险函数
DS_SPECIFIC_DANGEROUS_FUNCTIONS = {
'mysql': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
'doris': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
'starrocks': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
'postgresql': {'pg_read_file', 'pg_write_file', 'lo_import', 'lo_export'},
'sqlserver': {'EXEC', 'xp_cmdshell', 'sp_executesql'},
'oracle': {'UTL_FILE', 'DBMS_PIPE', 'DBMS_LOCK'},
'hive': {'ADD FILE', 'ADD JAR'},
}

# 危险模式正则表达式(用于检查特殊语法)
import re
DANGEROUS_PATTERNS = [
r'\bINTO\s+OUTFILE\b',
r'\bINTO\s+DUMPFILE\b',
r'\bEXEC\s*\(',
r'\bCOPY\s+.*\bTO\s+PROGRAM\b',
]


def get_dangerous_functions(ds_type: str) -> set:
"""获取危险函数(通用 + 特定数据源)"""
functions = COMMON_DANGEROUS_FUNCTIONS.copy()
ds_key = ds_type.lower() if ds_type else ''
if ds_key in DS_SPECIFIC_DANGEROUS_FUNCTIONS:
functions.update(DS_SPECIFIC_DANGEROUS_FUNCTIONS[ds_key])
return functions


def check_dangerous_functions(statements: list, ds_type: str) -> bool:
"""检查是否使用了危险函数,返回 True 表示安全"""
dangerous_functions = get_dangerous_functions(ds_type)
dangerous_functions_upper = {f.upper() for f in dangerous_functions}

for stmt in statements:
if stmt:
for func in stmt.find_all(exp.Anonymous):
if func.name.upper() in dangerous_functions_upper:
return False
return True


def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema) -> tuple[bool, str]:
"""
检查 SQL 是否为安全的只读查询
返回: (是否安全, 错误原因)
"""
try:
normalized_sql = sql.strip().lstrip("(").strip()
first_keyword = normalized_sql.split(None, 1)[0].upper() if normalized_sql else ""
allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}

# 根据配置决定是否允许元数据查询
if settings.SQLBOT_ALLOW_METADATA_QUERIES:
allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
else:
allowed_read_commands = {"SELECT", "WITH"}

denied_write_commands = {
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"TRUNCATE", "MERGE", "COPY", "REPLACE", "GRANT", "REVOKE",
Expand All @@ -730,21 +798,29 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
if not first_keyword:
raise ValueError("Parse SQL Error")
if first_keyword in denied_write_commands:
return False
return False, f"Write operation '{first_keyword}' is not allowed"

dialect = None
if equals_ignore_case(ds.type, 'mysql', 'doris', 'starrocks'):
dialect = 'mysql'
elif equals_ignore_case(ds.type, 'sqlServer'):
dialect = 'tsql'
elif equals_ignore_case(ds.type, 'hive'):
dialect = 'hive'
# 1. 使用正则检查特殊模式
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, sql, re.IGNORECASE):
return False, f"SQL contains dangerous pattern: {pattern}"

dialect = get_sqlglot_dialect(ds.type)
statements = sqlglot.parse(sql, dialect=dialect)

if not statements:
raise ValueError("Parse SQL Error")

# 2. 使用 sqlglot 检查函数调用
dangerous_functions = get_dangerous_functions(ds.type)
dangerous_functions_upper = {f.upper() for f in dangerous_functions}
for stmt in statements:
if stmt:
for func in stmt.find_all(exp.Anonymous):
if func.name.upper() in dangerous_functions_upper:
return False, f"SQL contains dangerous function: {func.name}"

# 3. 检查写操作类型
write_types = (
exp.Insert, exp.Update, exp.Delete,
exp.Create, exp.Drop, exp.Alter,
Expand All @@ -755,9 +831,12 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
if stmt is None:
continue
if isinstance(stmt, write_types):
return False
return False, f"SQL contains write operation: {type(stmt).__name__}"

if first_keyword not in allowed_read_commands:
return False, f"SQL command '{first_keyword}' is not allowed. Only SELECT and WITH are permitted"

return first_keyword in allowed_read_commands
return True, ""

except Exception as e:
raise ValueError(f"Parse SQL Error: {e}")
Expand Down
4 changes: 3 additions & 1 deletion backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
table_name_list = []
i = 0
for table in ds.tables:
# 如果传入了 table_list,则只处理在列表中的表
Expand All @@ -213,6 +214,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
schema_table += '\n]\n'
t_obj = {"id": i, "schema_table": schema_table}
tables.append(t_obj)
table_name_list.append(table.name)

# do table embedding
# if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
Expand All @@ -222,7 +224,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
for s in tables:
schema_str += s.get('schema_table')

return schema_str, []
return schema_str, table_name_list

def get_ds(self, ds_id: int, trans: Trans = None):
if self.ds_list:
Expand Down
4 changes: 4 additions & 0 deletions backend/common/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
GENERATE_SQL_QUERY_LIMIT_ENABLED: bool = True
GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT: int = 3

# 安全配置:是否允许元数据查询(SHOW/DESCRIBE/DESC/EXPLAIN)
# 默认关闭,防止通过元数据查询泄露数据库结构
SQLBOT_ALLOW_METADATA_QUERIES: bool = False

PARSE_REASONING_BLOCK_ENABLED: bool = True
DEFAULT_REASONING_CONTENT_START: str = '<think>'
DEFAULT_REASONING_CONTENT_END: str = '</think>'
Expand Down
Loading