Skip to content

Commit a7556a7

Browse files
author
Pierre
committed
Update 15_text_to_sql.py
1 parent 77a2461 commit a7556a7

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/15_text_to_sql.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
It uses a sample e-commerce database schema and shows how to generate safe and efficient SQL queries.
44
55
Like example 14 (templated instructions), this example shows how to use variables in the agent's
6-
instructions. The template variables ({{ schema }} and {{ question }}) are automatically populated
6+
instructions. The template variables ({{ db_schema }} and {{ question }}) are automatically populated
77
from the input model's fields, allowing the instructions to adapt based on the input.
88
99
The example includes:
@@ -24,7 +24,7 @@
2424
class SQLGenerationInput(BaseModel):
2525
"""Input model for the SQL generation agent."""
2626

27-
schema: str = Field(
27+
db_schema: str = Field(
2828
description="The complete SQL schema with CREATE TABLE statements",
2929
)
3030
question: str = Field(
@@ -66,7 +66,7 @@ async def generate_sql(review_input: SQLGenerationInput) -> Run[SQLGenerationOut
6666
6. Include column names in GROUP BY rather than positions
6767
6868
Schema:
69-
{{ schema }}
69+
{{ db_schema }}
7070
7171
Question to convert to SQL:
7272
{{ question }}
@@ -124,7 +124,7 @@ async def main():
124124
print("-" * 50)
125125
run = await generate_sql(
126126
SQLGenerationInput(
127-
schema=schema,
127+
db_schema=schema,
128128
question="Show me all products that cost more than $100, ordered by price descending",
129129
),
130130
)
@@ -135,7 +135,7 @@ async def main():
135135
print("-" * 50)
136136
run = await generate_sql(
137137
SQLGenerationInput(
138-
schema=schema,
138+
db_schema=schema,
139139
question=(
140140
"List all customers with their total number of orders and total spend, "
141141
"only showing customers who have made at least 2 orders"
@@ -149,7 +149,7 @@ async def main():
149149
print("-" * 50)
150150
run = await generate_sql(
151151
SQLGenerationInput(
152-
schema=schema,
152+
db_schema=schema,
153153
question=(
154154
"What are the top 3 product categories by revenue in the last 30 days, "
155155
"including the number of unique customers who bought from each category?"

0 commit comments

Comments
 (0)