Skip to content

Commit 3686ff1

Browse files
henomisSimone Vellei
andauthored
add basic psql support (#234)
* add basic psql support * fix linter * fix linter --------- Co-authored-by: Simone Vellei <simone.vellei@cybus.io>
1 parent b0be2f1 commit 3686ff1

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

linglet/sql/psql.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package sql
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
)
7+
8+
//nolint:lll
9+
var psqlSystemPromptTemplate = `
10+
You are a Postgresql expert. Given an input question, create a syntactically correct psql query to run. Do not add any extra information to the query. The query must be usable as-is.
11+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {{.top_k}} results using the LIMIT clause as per Postgresql. You can order the results to return the most informative data in the database.
12+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
13+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
14+
Pay attention to use date('now') function to get the current date, if the question involves "today". Do not use markdown to format the query.`
15+
16+
// psqlSchema retrieves the schema information for all tables in a PostgreSQL database.
17+
//
18+
//nolint:funlen,gocognit
19+
func (s *SQL) psqlSchema() (*string, error) {
20+
rows, err := s.db.Query(`
21+
SELECT
22+
c.table_name,
23+
c.column_name,
24+
c.data_type,
25+
c.column_default,
26+
c.is_nullable,
27+
tc.constraint_type,
28+
kcu.constraint_name,
29+
ccu.table_name AS foreign_table_name,
30+
ccu.column_name AS foreign_column_name
31+
FROM
32+
information_schema.columns c
33+
LEFT JOIN
34+
information_schema.key_column_usage kcu ON c.table_name = kcu.table_name AND c.column_name = kcu.column_name
35+
LEFT JOIN
36+
information_schema.table_constraints tc ON kcu.constraint_name = tc.constraint_name
37+
LEFT JOIN
38+
information_schema.constraint_column_usage ccu ON tc.constraint_name = ccu.constraint_name
39+
WHERE
40+
c.table_schema = 'public'
41+
ORDER BY
42+
c.table_name, c.ordinal_position;
43+
`)
44+
if err != nil {
45+
return nil, fmt.Errorf("querying schema: %w", err)
46+
}
47+
defer rows.Close()
48+
49+
schema := ""
50+
currentTable := ""
51+
52+
for rows.Next() {
53+
//nolint:lll
54+
var tableName, columnName, dataType, columnDefault, isNullable, constraintType, constraintName, foreignTableName, foreignColumnName sql.NullString
55+
//nolint:lll
56+
if rowsErr := rows.Scan(&tableName, &columnName, &dataType, &columnDefault, &isNullable, &constraintType, &constraintName, &foreignTableName, &foreignColumnName); rowsErr != nil {
57+
return nil, fmt.Errorf("scanning row: %w", rowsErr)
58+
}
59+
60+
//nolin:nestif
61+
if tableName.Valid && tableName.String != currentTable {
62+
if currentTable != "" {
63+
schema += "\n" // Add a newline before a new table
64+
}
65+
schema += fmt.Sprintf("Table: %s\n", tableName.String)
66+
currentTable = tableName.String
67+
}
68+
69+
//nolint:nestif
70+
if columnName.Valid {
71+
schema += fmt.Sprintf(" Column: %s, Type: %s", columnName.String, dataType.String)
72+
73+
if columnDefault.Valid {
74+
schema += fmt.Sprintf(", Default: %s", columnDefault.String)
75+
}
76+
77+
if isNullable.Valid {
78+
schema += fmt.Sprintf(", Nullable: %s", isNullable.String)
79+
}
80+
81+
if constraintType.Valid {
82+
schema += fmt.Sprintf(", Constraint: %s (%s)", constraintType.String, constraintName.String)
83+
if foreignTableName.Valid && foreignColumnName.Valid {
84+
schema += fmt.Sprintf(", References: %s(%s)", foreignTableName.String, foreignColumnName.String)
85+
}
86+
}
87+
88+
schema += "\n"
89+
}
90+
}
91+
92+
if rowsErr := rows.Err(); rowsErr != nil {
93+
return nil, fmt.Errorf("rows error: %w", rowsErr)
94+
}
95+
96+
return &schema, nil
97+
}

linglet/sql/sql.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ func (s *SQL) schema() (*string, error) {
5555
driverType := fmt.Sprintf("%T", s.db.Driver())
5656
if strings.Contains(driverType, "sqlite") {
5757
return s.sqliteSchema()
58+
} else if strings.Contains(driverType, "pq.Driver") {
59+
return s.psqlSchema()
5860
}
5961

6062
return nil, fmt.Errorf("unsupported database driver %s", driverType)
@@ -64,6 +66,8 @@ func (s *SQL) systemPrompt() (*string, error) {
6466
driverType := fmt.Sprintf("%T", s.db.Driver())
6567
if strings.Contains(driverType, "sqlite") {
6668
return &sqliteSystemPromptTemplate, nil
69+
} else if strings.Contains(driverType, "pq.Driver") {
70+
return &psqlSystemPromptTemplate, nil
6771
}
6872

6973
return nil, fmt.Errorf("unsupported database driver %s", driverType)

linglet/sql/sqlite.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ You are a SQLite expert. Given an input question, create a syntactically correct
66
Unless the user specifies in the question a specific number of examples to obtain, query for at most {{.top_k}} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
77
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
88
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
9-
Pay attention to use date('now') function to get the current date, if the question involves "today".`
9+
Pay attention to use date('now') function to get the current date, if the question involves "today". Do not use markdown to format the query.`
1010

1111
func (s *SQL) sqliteSchema() (*string, error) {
1212
rows, err := s.db.Query("SELECT sql FROM sqlite_schema WHERE sql IS NOT NULL")

0 commit comments

Comments
 (0)