Как создать LLM-агента для перевода текста в SQL-запросы

MLPops

MLPops / January 03, 2025

6 мин.

В этом руководстве мы разберём, как реализовать агента, который использует SQL с помощью библиотеки transformers.agents

🤔 Преимущества по сравнению со стандартным Text-to-SQL

Стандартные конвейеры Text-to-SQL часто ненадёжны: сгенерированный SQL-запрос может быть некорректным. Ещё хуже, если запрос возвращает неверные или бесполезные данные без явной ошибки.

👉 Агентная система способна критически анализировать результаты выполнения запросов и решать, нужно ли их изменить, что значительно повышает производительность.

Давайте создадим такого агента! 💪

🛠 Настройка SQL-таблиц

Сначала создадим базу данных SQLite и таблицу для хранения товаров:

from sqlalchemy import (
    create_engine, MetaData,
    Table, Column, String,
    Integer, Float, insert, inspect,text,
)
 
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
 
table_name = "catalog"
receipts = Table(
    table_name,
    metadata_obj,
    Column("product_id", Integer, primary_key=True),
    Column("product_name", String(64), primary_key=True),
    Column("price", Float),
    Column("brand", String(32)),
    Column("category", String(32)),
)
metadata_obj.create_all(engine)

Заполним нашу таблицу данными

rows = [
    {"product_id": 1, "product_name": "Apple iPhone 16 Pro", "price": 143351, "category": "phone", "brand":"Apple"},
    {"product_id": 2, "product_name": "Apple iPhone 12 eSIM+SIM", "price": 57317, "category": "phone", "brand":"Apple"},
    {"product_id": 3, "product_name": "Samsung A35 5G 8/256 Гб", "price": 32009, "category": "phone", "brand":"Samsung"},
    {"product_id": 4, "product_name": "Samsung A55 5G 5/256 Гб", "price": 38799, "category": "phone", "brand":"Samsung"},
    {"product_id": 5, "product_name": "Samsung Galaxy Z Fold 5 Global 12/256 Гб", "price": 121991, "category": "phone", "brand":"Samsung"},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

Проверим данные с помощью простого SQL-запроса:

with engine.connect() as con:
    rows = con.execute(text("""SELECT * from catalog"""))
    for row in rows:
        print(row)

Результ:

(1, 'Apple iPhone 16 Pro', 143351.0, 'Apple', 'phone')
(2, 'Apple iPhone 12 eSIM+SIM', 57317.0, 'Apple', 'phone')
(3, 'Samsung A35 5G 8/256 Гб', 32009.0, 'Samsung', 'phone')
(4, 'Samsung A55 5G 5/256 Гб', 38799.0, 'Samsung', 'phone')
(5, 'Samsung Galaxy Z Fold 5 Global 12/256 Гб', 121991.0, 'Samsung', 'phone')

🤖 Создание агента

Теперь сделаем нашу таблицу доступной для инструмента.

Описание таблицы

Сначала опишем структуру таблицы для использования в подсказке LLM:

inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("catalog")]
 
table_description = "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)

Результ:

Columns:
  - product_id: INTEGER
  - product_name: VARCHAR(64)
  - price: FLOAT
  - brand: VARCHAR(32)
  - category: VARCHAR(32)

Инструмент для выполнения SQL-запросов

Теперь создадим функцию-инструмент для выполнения запросов:

from transformers.agents import tool
 
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'catalog'. Its description is as follows:
        Columns:
        - product_id: INTEGER
        - product_name: VARCHAR(64)
        - price: FLOAT
        - brand: VARCHAR(32)
        - category: VARCHAR(32)
 
    Args:
        query: The query to perform. This should be correct SQL.
    """
    output = ""
    with engine.connect() as con:
        rows = con.execute(text(query))
        for row in rows:
            output += "\n" + str(row)
    return output

Реализация агента

Используем ReactCodeAgent из transformers.agents, чтобы создать агента:

from transformers.agents import ReactCodeAgent, HfApiEngine
 
agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-Coder-32B-Instruct", token="<ACCESS_TOKEN>"),
)

Вот тут сделаем малое отступление:

У transformers есть две реализации для llm_engine - HfApiEngine и TransformersEngine, рассмотрим их чуть подробнее.

HfApiEngine — это класс из библиотеки Hugging Face Transformers, который предоставляет интерфейс для взаимодействия с языковыми моделями (LLM) через Hugging Face Inference API. Этот движок используется для выполнения задач, требующих мощных моделей, которые могут быть недоступны локально.

TransformersEngine — это локальный движок для выполнения задач с использованием языковых моделей (LLM), предоставляемый библиотекой Hugging Face Transformers. В отличие от HfApiEngine, этот движок выполняет все вычисления локально.

Продолжим

Делаем первый запрос:

agent.run("Можешь найти самый дещевый iPhone?")

Результат:

======== New task ========
Можешь найти самый дещевый iPhone?
=== Agent thoughts:
Though: The task requires finding the cheapest iPhone in the catalog. I will use the `sql_engine` tool to query the catalog for iPhone products, then determine the minimum price.
>>> Agent is executing the code below:
iphone_query = sql_engine(query="SELECT product_name, price FROM catalog WHERE product_name LIKE '%iPhone%' ORDER BY price ASC LIMIT 1")
print(iphone_query)
====
Print outputs:
 
('Apple iPhone 12 eSIM+SIM', 57317.0)
 
=== Agent thoughts:
Thought: The query has returned the cheapest iPhone in the catalog, which is the "Apple iPhone 12 eSIM+SIM" with a price of 57317.0. I will now use the `final_answer` tool to provide the final answer.
>>> Agent is executing the code below:
final_answer("The cheapest iPhone in the catalog is the Apple iPhone 12 eSIM+SIM with a price of 57317.0.")
====
Print outputs:
 
Last output from code snippet:
The cheapest iPhone in the catalog is the Apple iPhone 12 eSIM+SIM with a price of 57317.0.
 

🌟 Усложняем задачу: объединение таблиц

Добавим вторую таблицу с комментариями и оценками пользователей:

table_name = "comments"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("text", String(16), primary_key=True),
    Column("rating", Integer),
    
)
metadata_obj.create_all(engine)
 
rows = [
    {"receipt_id": 1, "text": "Это лучший телефон, спасибо за даставку", "rating": 5},
    {"receipt_id": 1, "text": "Хоть и дорогой, но лучший", "rating": 5},
    {"receipt_id": 5, "text": "Самый не удобный телефон", "rating": 2},
    {"receipt_id": 5, "text": "Жутко тормоизит и бесит этим", "rating": 1},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

Обновим описание инструмента для работы с обеими таблицами:

updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""
 
inspector = inspect(engine)
for table in ["catalog", "comments"]:
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
 
    table_description = f"Table '{table}':\n"
 
    table_description += "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
    updated_description += "\n\n" + table_description
 
print(updated_description)

Теперь агент может выполнять более сложные запросы:

sql_engine.description = updated_description
 
agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-Coder-32B-Instruct", token="<ACCESS_TOKEN>"),
)
 
agent.run("Какой телефон имеет самые высокий рейтинг?")

Результат:

======== New task ========
Какой телефон имеет самые высокий рейтинг?
=== Agent thoughts:
Thought: To find the phone with the highest rating, I need to aggregate the ratings for each product, determine the average rating for each phone, and then find the phone with the highest average rating. I will first fetch all the comments from the 'comments' table, then group them by product, calculate the average rating for each product, and finally find the product with the highest average rating.
>>> Agent is executing the code below:
# Get all comments with ratings
query = "SELECT receipt_id, rating FROM comments"
comments_output = sql_engine(query=query)
print(comments_output)
====
Print outputs:
 
(1, 5)
(1, 5)
(5, 2)
(5, 1)
 
=== Agent thoughts:
Thought: I have the ratings for the products. Now, I need to process this data to calculate the average rating for each product and then find the product with the highest average rating.
>>> Agent is executing the code below:
import statistics
 
# Data from the comments table
comments_data = [(1, 5), (1, 5), (5, 2), (5, 1)]
...
('Apple iPhone 16 Pro',
Final answer:
 
('Apple iPhone 16 Pro',

✅ Итог

Теперь у вас есть мощная система Text-to-SQL! С её помощью можно легко обрабатывать запросы на естественном языке и преобразовывать их в корректные SQL-запросы. ✨