LlamaIndex 自然语言SQL查询 Text2Sql
在大语言的众多应用场景中,对自然语言进行SQL转换,然后返回SQL查询结果(Text2Sql),是其中最常见的一种。官方文档已经提供了两篇文教程,一篇简单的,一篇复杂的,来实现这一过程。其中简单的一篇,地址是这个:Text-to-SQL Guide (Query Engine + Retriever)。本文将基于这篇教程,来实现这一过程。
Reading表结构
本文使用下面的表结构来进行示例,这个reading表,就是这个博客站点中 读书笔记 栏目的建表脚本,非常的简单:
CREATE TABLE `reading` ( `id` varchar(255) DEFAULT NULL, `code` varchar(255) NOT NULL COMMENT '编号', `title` varchar(255) NOT NULL COMMENT '标题', `cover` varchar(255) DEFAULT NULL COMMENT '封面地址', `category` varchar(255) NOT NULL COMMENT '类型', `summary` varchar(2000) NOT NULL COMMENT '简介', `recommend` int(255) NOT NULL COMMENT '推荐星级', `sort` int(255) NOT NULL, `create_date` datetime NOT NULL COMMENT '发表日期', PRIMARY KEY (`code`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
使用NLSQLTableQueryEngine
NLSQLTableQueryEngine(自然语言SQL表查询引擎)是一个高级的组件(相对于底层而言),它的主要作用就是将人类的查询语言问题,转换成可以执行的SQL查询语句,并返回执行的结果。
NLSQLTableQueryEngine 适用于下面的场景:
- 简单数据库结构:如日志分析库含5张表(用户、登录记录、设备),需快速查询“过去24小时失败登录次数”。
- 明确的查询范围:当你非常清楚答案就在某一张或某几张固定的表中时,用它最合适。
- 上下文限制,表结构信息需全部注入Prompt,易超LLM上下文窗口
- 对响应速度有要求:由于它是一步生成 SQL,没有复杂的中间过程,所以响应速度通常很快。
from sqlalchemy import DateTime, create_engine,MetaData,Table,Column,String,Integer,select,text from llama_index.core import SQLDatabase from llama_index.core.query_engine import NLSQLTableQueryEngine from llama_index.llms.openai import OpenAI from llama_index.core import Settings from dotenv import load_dotenv load_dotenv() llm = OpenAI(model="gpt-4.1-mini") engine = create_engine("mysql+pymysql://root:password@localhost/tracefact_cms?charset=utf8mb4") sql_database = SQLDatabase(engine, include_tables=["reading", "tech"]) metadata = MetaData() reading_table = Table( "reading", metadata, Column("code", String(255), primary_key=True), Column("title", String(255)), Column("category", String(255)), Column("summary", String(2000)), Column("recommend", Integer), Column("create_date", DateTime) ) query_engine = NLSQLTableQueryEngine( sql_database=sql_database, tables=[reading_table], llm=llm ) query_str = "最早发布的10篇文章的标题是什么?按照每行一个标题的格式回答" response = query_engine.query(query_str) print(response)
输出的结果如下:
程序员你伤不起 如何阅读一本书 身边的逻辑学 丑陋的中国人 时寒冰说:未来二十年,经济大趋势(现实篇) O2O:移动互联网时代的商业革命 编写高质量代码:改善C#程序的157个建议 将摄影还给大众:7天摄影入门 教训:互联网创业必须避免的八大误区 悟透JavaScript
使用 SQLTableRetrieverQueryEngine
在实际的中大型项目中,通常会有几十上百张表,如果将表结构全部发送给大模型,那么可能会超出模型的上下文限制,除此以外,还可能因为信息过载而导致模型“混淆”,从而无法生成准确的SQL。
SQLTableRetrieverQueryEngine 解决了这个问题。它不会一次性查看所有表,而是分两步走:
- 检索 (Retrieve):当用户提出问题时,它首先利用一个表结构索引(ObjectIndex),通过语义搜索(向量检索)来动态地找出与问题最相关的几张表。
- 查询 (Query):然后,它只将这几张最相关表的结构信息和用户的问题一起打包,发送给大语言模型生成SQL。
这个过程就像一个经验丰富的数据库管理员,他不会去翻阅整个数据库的文档,而是根据你的问题,迅速定位到可能包含答案的几张关键表,然后再进行详细查询。
from sqlalchemy import DateTime, create_engine,MetaData,Table,Column,String,Integer,select from llama_index.core import SQLDatabase from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine from llama_index.llms.openai import OpenAI from llama_index.core import Settings from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema from llama_index.core import VectorStoreIndex from llama_index.embeddings.huggingface import HuggingFaceEmbedding from dotenv import load_dotenv load_dotenv() llm = OpenAI(model="gpt-4.1-mini") engine = create_engine("mysql+pymysql://root:password@localhost/tracefact_cms?charset=utf8mb4") sql_database = SQLDatabase(engine, include_tables=["reading", "tech"]) # 配置嵌入模型 embed_model = HuggingFaceEmbedding( model_name = "../models/bge-base-zh-v1.5", normalize = True ) Settings.embed_model = embed_model # 为每张表添加一个SQLTableSchema对象 table_node_mapping = SQLTableNodeMapping(sql_database) table_schema_objs = [ (SQLTableSchema(table_name="reading", context_str="这里保存的是读书笔记")), (SQLTableSchema(table_name="tech", context_str="这里保存的是原创的技术文章")) ] obj_index = ObjectIndex.from_objects( table_schema_objs, table_node_mapping, VectorStoreIndex, embed_model=embed_model, ) query_engine = SQLTableRetrieverQueryEngine( sql_database, obj_index.as_retriever(similarity_top_k=1) ) query_str = "最新发布的10篇读书笔记是什么?(输出结果用中文,每个标题单独一行)" response = query_engine.query(query_str) print("-----------------RESPONSE-----------------") print(response) print("\n-----------------RETRIEVED NODES-----------------") for s_node in response.source_nodes: print(f"Score: {s_node.score}") print(f"Content: \n{s_node.node.get_content()}") print(f"Metadata: \n{s_node.node.metadata}")
上面代码中,最需要注意的就是 SQLTableNodeMapping 和 SQLTableSchema。它们的名字具有一定的误导性:
- SQLTableNodeMapping:包含了实际的表结构(Schema),例如:列、列类型、备注等信息,也就是来自数据库的原始内容。并且会基于这些信息,为每一张表构建Node对象。它最终的作用是:会基于表的这些元信息,生成SQL语句。
- SQLTableSchema:它的任务是让你能用自然语言为某张表添加一个“注释”或“说明”。这对于大模型(LLM)至关重要,因为仅有title, category 这样的列名,模型可能无法准确理解这张表的真实用途。你提供的上下文(context_str)能极大地帮助模型理解表的业务含义。它的主要作用是:在检索阶段选出适用的表。
因为SQLTableSchema中的context_str只是进行业务说明,从而达到检索增强的效果,那么如果将context_str中的详细信息,写在数据库表的Comment中,那么就不需要再创建SQLTableSchema,只需要使用SqlTableNodeMapping就可以了。
最终的输出类似下面这样:
-----------------RESPONSE----------------- 最新发布的10篇读书笔记是: Building_Data-Driven_Applications_with_LlamaIndex 简约至上:交互式设计四策略 接口测试方法论 关于工作的9大谎言 刘澜极简管理学 时间黑客:用数据分析做个明白人 卓有成效的管理者 权力:为什么只为某些人所拥有 互联网广告系统 职场跃迁的60个管理思维 -----------------RETRIEVED NODES----------------- Score: None Content: [('Building_Data-Driven_Applications_with_LlamaIndex',), ('简约至上:交互式设计四策略',), ('接口测试方法论',), ('关于工作的9大谎言',), ('刘澜极简管理学',), ('时间黑客:用数据分析做个明白人',), ('卓有成效的管理者',), ('权力:为什么只为某些人所拥有',), ('互联网广告系统',), ('职场跃迁的60个管理思维',)] Metadata: {'sql_query': 'SELECT title\nFROM reading\nORDER BY create_date DESC\nLIMIT 10;', 'result': [('Building_Data-Driven_Applications_with_LlamaIndex',), ('简约至上:交互式设计四策略',), ('接口测试方法论',), ('关于工作的9大谎言',), (' 刘澜极简管理学',), ('时间黑客:用数据分析做个明白人',), ('卓有成效的管理者',), ('权力:为什么只为某些人所拥有',), ('互联网广告系统',), ('职场跃迁的60个管理思维',)], 'col_keys': ['title']}
使用 行检索器(row_retriever)
在上面的例子中,如果将问题提问,改为:“关于个人成长的读书笔记有哪些?(输出结果用中文,每个标题单独一行)”,结果返回为空。这是因为生成的SQL语句类似这样:SELECT title FROM reading WHERE category = '个人成长'。但是,reading表的category中,并没有“个人成长”这个类型,但有一个同义词:“成长”。
为了解决这个问题,我们可以使用row_retriever,对每个表的所有行进行向量化,然后在检索时,row_retriever会找出和查询最贴近的行。然后将这些检索到的行作为上下文,提升文本到SQL生成的最终表现。
# 使用 rows_retriver(行检索器) from llama_index.core import query_engine from llama_index.core.base.embeddings.base import similarity from llama_index.core import SQLDatabase from llama_index.core.query_engine import NLSQLTableQueryEngine from llama_index.llms.openai import OpenAI from llama_index.core import Settings from llama_index.core.schema import TextNode from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema from llama_index.core import VectorStoreIndex from llama_index.embeddings.huggingface import HuggingFaceEmbedding from sqlalchemy import DateTime, create_engine,MetaData,Table,Column,String,Integer,select,text from dotenv import load_dotenv # **** 配置模型 **** load_dotenv() llm = OpenAI(model="gpt-4.1-mini") embed_model = HuggingFaceEmbedding( model_name = "../models/bge-base-zh-v1.5", normalize = True ) Settings.embed_model = embed_model # **** 配置数据库 **** engine = create_engine("mysql+pymysql://root:password@localhost/tracefact_cms?charset=utf8mb4") sql_database = SQLDatabase(engine, include_tables=["reading", "tech"]) # **** 构建行检索器 **** query = "select code, title, summary, category, create_date from reading order by create_date desc" with engine.connect() as connection: results = connection.execute(text(query)).fetchall() reading_nodes = [] for row in results: node = TextNode( text = f"标题:{row.title},摘要:{row.summary}, 分类: { row.category }, 创建时间: { row.create_date }", ) reading_nodes.append(node) reading_nodes_index = VectorStoreIndex( reading_nodes, embed_model=embed_model ) reading_nodes_retriever = reading_nodes_index.as_retriever(similarity_top_k=1) rows_retrievers = { "reading" : reading_nodes_retriever } # **** 创建对象索引 **** # 构建Table schema信息 table_node_mapping = SQLTableNodeMapping(sql_database) # 为每张表添加一个SQLTableSchema对象 table_schema_objs = [ (SQLTableSchema(table_name="reading", context_str="这里保存的是读书笔记")), (SQLTableSchema(table_name="tech", context_str="这里保存的是原创的技术文章")) ] obj_index = ObjectIndex.from_objects( table_schema_objs, table_node_mapping, VectorStoreIndex, embed_model=embed_model, ) # **** 构建QueryEngine **** query_engine = SQLTableRetrieverQueryEngine( sql_database, obj_index.as_retriever(similarity_top_k=1), rows_retrievers=rows_retrievers ) query_str = "关于个人成长的读书笔记有哪些?(输出结果用中文,每个标题单独一行)" response = query_engine.query(query_str) print("---------RESPONSE---------") print(response)
---------RESPONSE--------- 技巧:如何用一年时间获得十年的经验 见识:你最终能走多远,取决于见识 成事:冯唐品读曾国藩嘉言钞 好好学习:个人知识管理精进指南 把时间当作朋友
使用列检索器(col_retriever)
先回顾一下当前学习:
- 使用 SQLTableNodeMapping,相当于对表结构(Schema)进行索引;
- 使用 SQLTableSchema,相当于对表提供额外的业务说明,提升检索效果;
- 使用 行检索器(rows_retriever),相当于将表中的具体数据进行了检索。
使用行检索器,最大的问题,在于有些表中的数据高达百万,将全部数据进行检索不现实。行检索器,只适用于数据量相对较小的表。
除此以外,检索经常是针对某些列进行检索,而这些列中的取值,并不会太大。此时,就可以基于这些列的信息,构建列检索器:
# 使用 rows_retriver(列检索器) from llama_index.core import query_engine from llama_index.core.base.embeddings.base import similarity from llama_index.core import SQLDatabase from llama_index.core.query_engine import NLSQLTableQueryEngine from llama_index.llms.openai import OpenAI from llama_index.core import Settings from llama_index.core.schema import TextNode from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema from llama_index.core import VectorStoreIndex from llama_index.embeddings.huggingface import HuggingFaceEmbedding from sqlalchemy import DateTime, create_engine,MetaData,Table,Column,String,Integer,select,text from dotenv import load_dotenv # **** 配置模型 **** load_dotenv() llm = OpenAI(model="gpt-4.1-mini") embed_model = HuggingFaceEmbedding( model_name = "../models/bge-base-zh-v1.5", normalize = True ) Settings.embed_model = embed_model # **** 配置数据库 **** engine = create_engine("mysql+pymysql://root:password@localhost/tracefact_cms?charset=utf8mb4") sql_database = SQLDatabase(engine, include_tables=["reading", "tech"]) # **** 列内容信息 **** query = "select distinct category from reading" with engine.connect() as connection: results = connection.execute(text(query)).fetchall() categroy_cols_retrievers = {} for column_name in ["category"]: nodes = [TextNode(text=v[0]) for v in results] column_index = VectorStoreIndex( nodes, embed_model= embed_model ) column_retriever = column_index.as_retriever(similarity_top_k=1) categroy_cols_retrievers[column_name] = column_retriever cols_retrievers = { "category" : categroy_cols_retrievers } # **** schema信息 **** table_node_mapping = SQLTableNodeMapping(sql_database) # 为每张表添加一个SQLTableSchema对象 table_schema_objs = [ SQLTableSchema(table_name="reading", context_str="这里保存的是读书笔记"), SQLTableSchema(table_name="tech", context_str="这里保存的是原创的技术文章") ] obj_index = ObjectIndex.from_objects( table_schema_objs, table_node_mapping, VectorStoreIndex, embed_model=embed_model, ) # **** 构建QueryEngine **** query_engine = SQLTableRetrieverQueryEngine( sql_database, obj_index.as_retriever(similarity_top_k=1), cols_retrievers=cols_retrievers, ) query_str = "关于个人成长的读书笔记有哪些?(输出结果用中文,每个标题单独一行)" response = query_engine.query(query_str) print("---------RESPONSE---------") print(response)
Text-to-SQL 检索器
在上面的例子当中,我们要么使用的是 SQLTableRetrieverQueryEngine、要么是NLSQLTableQueryEngine,都是QueryEngine,也就是高级别的查询引擎。也可以直接使用第一级的检索器:NLSQLRetriever。
from llama_index.core.retrievers import NLSQLRetriever from sqlalchemy import DateTime, create_engine,MetaData,Table,Column,String,Integer,select from llama_index.core import SQLDatabase from llama_index.llms.openai import OpenAI from dotenv import load_dotenv from llama_index.core import Settings from llama_index.embeddings.huggingface import HuggingFaceEmbedding load_dotenv() llm = OpenAI(model="gpt-4.1-mini") embed_model = HuggingFaceEmbedding( model_name = "../models/bge-base-zh-v1.5", normalize = True ) Settings.embed_model = embed_model engine = create_engine("mysql+pymysql://root:password@localhost/tracefact_cms?charset=utf8mb4") sql_database = SQLDatabase(engine, include_tables=["reading", "tech"]) nl_sql_retriever = NLSQLRetriever( sql_database, tables=["reading"], llm=llm, return_raw=True ) results = nl_sql_retriever.retrieve( "最新发表的3篇关于文学的读书笔记有哪些?" ) print(f"检索到了 {len(results)} 条记录:") for node in results: print(f"Content: {node.get_content()}") print(f"Metadata: {node.metadata}")
注意这里的参数 return_raw=True,它们的区别如下:
- return_raw=True: 1.自然语言 -> SQL 2.执行 SQL 3.返回查询结果
- return_raw=Faslse:1.自然语言 -> SQL 2. 执行 SQL 3. 将查询结果+原始问题喂给LLM 4. 返回LLM生成的答案
感谢阅读,希望这篇文章能给你带来帮助!