From 3bb1e69a74ab8027e0db4a769bcc5ae331beeec3 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 8 Mar 2025 14:30:59 +0100 Subject: [PATCH 1/3] add basic psql support --- linglet/sql/psql.go | 92 +++++++++++++++++++++++++++++++++++++++++++ linglet/sql/sql.go | 4 ++ linglet/sql/sqlite.go | 2 +- 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 linglet/sql/psql.go diff --git a/linglet/sql/psql.go b/linglet/sql/psql.go new file mode 100644 index 00000000..4d92cb7e --- /dev/null +++ b/linglet/sql/psql.go @@ -0,0 +1,92 @@ +package sql + +import ( + "database/sql" + "fmt" +) + +//nolint:lll +var psqlSystemPromptTemplate = ` +You are a Postgresql expert. Given an input question, create a syntactically correct psql query to run. Do not add any extra information to the query. The query must be usable as-is. +Unless the user specifies in the question a specific number of examples to obtain, query for at most {{.top_k}} results using the LIMIT clause as per Postgresql. You can order the results to return the most informative data in the database. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use date('now') function to get the current date, if the question involves "today". Do not use markdown to format the query.` + +// psqlSchema retrieves the schema information for all tables in a PostgreSQL database. +func (s *SQL) psqlSchema() (*string, error) { + rows, err := s.db.Query(` + SELECT + c.table_name, + c.column_name, + c.data_type, + c.column_default, + c.is_nullable, + tc.constraint_type, + kcu.constraint_name, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name + FROM + information_schema.columns c + LEFT JOIN + information_schema.key_column_usage kcu ON c.table_name = kcu.table_name AND c.column_name = kcu.column_name + LEFT JOIN + information_schema.table_constraints tc ON kcu.constraint_name = tc.constraint_name + LEFT JOIN + information_schema.constraint_column_usage ccu ON tc.constraint_name = ccu.constraint_name + WHERE + c.table_schema = 'public' + ORDER BY + c.table_name, c.ordinal_position; +`) + if err != nil { + return nil, fmt.Errorf("querying schema: %w", err) + } + defer rows.Close() + + schema := "" + currentTable := "" + + for rows.Next() { + var tableName, columnName, dataType, columnDefault, isNullable, constraintType, constraintName, foreignTableName, foreignColumnName sql.NullString + if err := rows.Scan(&tableName, &columnName, &dataType, &columnDefault, &isNullable, &constraintType, &constraintName, &foreignTableName, &foreignColumnName); err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + if tableName.Valid && tableName.String != currentTable { + if currentTable != "" { + schema += "\n" // Add a newline before a new table + } + schema += fmt.Sprintf("Table: %s\n", tableName.String) + currentTable = tableName.String + } + + if columnName.Valid { + schema += fmt.Sprintf(" Column: %s, Type: %s", columnName.String, dataType.String) + + if columnDefault.Valid { + schema += fmt.Sprintf(", Default: %s", columnDefault.String) + } + + if isNullable.Valid { + schema += fmt.Sprintf(", Nullable: %s", isNullable.String) + } + + if constraintType.Valid { + schema += fmt.Sprintf(", Constraint: %s (%s)", constraintType.String, constraintName.String) + if foreignTableName.Valid && foreignColumnName.Valid { + schema += fmt.Sprintf(", References: %s(%s)", foreignTableName.String, foreignColumnName.String) + } + } + + schema += "\n" + } + + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("rows error: %w", err) + } + + return &schema, nil +} diff --git a/linglet/sql/sql.go b/linglet/sql/sql.go index 131720f8..cf7d3a75 100644 --- a/linglet/sql/sql.go +++ b/linglet/sql/sql.go @@ -55,6 +55,8 @@ func (s *SQL) schema() (*string, error) { driverType := fmt.Sprintf("%T", s.db.Driver()) if strings.Contains(driverType, "sqlite") { return s.sqliteSchema() + } else if strings.Contains(driverType, "pq.Driver") { + return s.psqlSchema() } return nil, fmt.Errorf("unsupported database driver %s", driverType) @@ -64,6 +66,8 @@ func (s *SQL) systemPrompt() (*string, error) { driverType := fmt.Sprintf("%T", s.db.Driver()) if strings.Contains(driverType, "sqlite") { return &sqliteSystemPromptTemplate, nil + } else if strings.Contains(driverType, "pq.Driver") { + return &psqlSystemPromptTemplate, nil } return nil, fmt.Errorf("unsupported database driver %s", driverType) diff --git a/linglet/sql/sqlite.go b/linglet/sql/sqlite.go index 5e3a9090..49b99e80 100644 --- a/linglet/sql/sqlite.go +++ b/linglet/sql/sqlite.go @@ -6,7 +6,7 @@ You are a SQLite expert. Given an input question, create a syntactically correct Unless the user specifies in the question a specific number of examples to obtain, query for at most {{.top_k}} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database. Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. -Pay attention to use date('now') function to get the current date, if the question involves "today".` +Pay attention to use date('now') function to get the current date, if the question involves "today". Do not use markdown to format the query.` func (s *SQL) sqliteSchema() (*string, error) { rows, err := s.db.Query("SELECT sql FROM sqlite_schema WHERE sql IS NOT NULL") From 79ce0d37c3eceede5cae3174044b28b1333904c3 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 8 Mar 2025 15:07:53 +0100 Subject: [PATCH 2/3] fix linter --- linglet/sql/psql.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/linglet/sql/psql.go b/linglet/sql/psql.go index 4d92cb7e..c04e00c0 100644 --- a/linglet/sql/psql.go +++ b/linglet/sql/psql.go @@ -14,6 +14,8 @@ Pay attention to use only the column names you can see in the tables below. Be c Pay attention to use date('now') function to get the current date, if the question involves "today". Do not use markdown to format the query.` // psqlSchema retrieves the schema information for all tables in a PostgreSQL database. +// +//nolint:funlen,gocognit func (s *SQL) psqlSchema() (*string, error) { rows, err := s.db.Query(` SELECT @@ -48,11 +50,14 @@ func (s *SQL) psqlSchema() (*string, error) { currentTable := "" for rows.Next() { + //nolint:lll var tableName, columnName, dataType, columnDefault, isNullable, constraintType, constraintName, foreignTableName, foreignColumnName sql.NullString + //nolint:lll if err := rows.Scan(&tableName, &columnName, &dataType, &columnDefault, &isNullable, &constraintType, &constraintName, &foreignTableName, &foreignColumnName); err != nil { return nil, fmt.Errorf("scanning row: %w", err) } + //nolin:nestif if tableName.Valid && tableName.String != currentTable { if currentTable != "" { schema += "\n" // Add a newline before a new table @@ -81,7 +86,6 @@ func (s *SQL) psqlSchema() (*string, error) { schema += "\n" } - } if err := rows.Err(); err != nil { From 2e1989f4bc11016a35558258d2ad8bdad24e654d Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 8 Mar 2025 15:11:10 +0100 Subject: [PATCH 3/3] fix linter --- linglet/sql/psql.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/linglet/sql/psql.go b/linglet/sql/psql.go index c04e00c0..8197e69b 100644 --- a/linglet/sql/psql.go +++ b/linglet/sql/psql.go @@ -53,8 +53,8 @@ func (s *SQL) psqlSchema() (*string, error) { //nolint:lll var tableName, columnName, dataType, columnDefault, isNullable, constraintType, constraintName, foreignTableName, foreignColumnName sql.NullString //nolint:lll - if err := rows.Scan(&tableName, &columnName, &dataType, &columnDefault, &isNullable, &constraintType, &constraintName, &foreignTableName, &foreignColumnName); err != nil { - return nil, fmt.Errorf("scanning row: %w", err) + if rowsErr := rows.Scan(&tableName, &columnName, &dataType, &columnDefault, &isNullable, &constraintType, &constraintName, &foreignTableName, &foreignColumnName); rowsErr != nil { + return nil, fmt.Errorf("scanning row: %w", rowsErr) } //nolin:nestif @@ -66,6 +66,7 @@ func (s *SQL) psqlSchema() (*string, error) { currentTable = tableName.String } + //nolint:nestif if columnName.Valid { schema += fmt.Sprintf(" Column: %s, Type: %s", columnName.String, dataType.String) @@ -88,8 +89,8 @@ func (s *SQL) psqlSchema() (*string, error) { } } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("rows error: %w", err) + if rowsErr := rows.Err(); rowsErr != nil { + return nil, fmt.Errorf("rows error: %w", rowsErr) } return &schema, nil