From fe0dcb1876dc95775b62e0363eefe8532e3c06c1 Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Mon, 26 Aug 2024 12:46:21 +0530 Subject: [PATCH 01/15] Handling subqueries fixed --- src/dataneuron/core/sql_query_filter.py | 54 +++++++++++++------------ tests/core/test_sql_query_filter.py | 8 ++-- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 7fc1fe1..bf55519 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -32,16 +32,14 @@ def _apply_filter_recursive(self, parsed, client_id): if self._is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) - if isinstance(parsed, Token) and parsed.ttype is DML: - return self._apply_filter_to_single_query(str(parsed), client_id) - elif self._contains_set_operation(parsed): - return self._handle_set_operation(parsed, client_id) - elif self._contains_subquery(parsed): - return self._handle_subquery(parsed, client_id) - else: - filtered_query = self._apply_filter_to_single_query( - str(parsed), client_id) - return self._handle_where_subqueries(sqlparse.parse(filtered_query)[0], client_id) + for tokens in parsed.token: + if isinstance(tokens, Token) and tokens.ttype is DML: + if self._contains_set_operation(parsed): + return self._handle_set_operation(parsed, client_id) + elif self._contains_subquery(parsed): + return self._handle_subquery(parsed, client_id) + else: + return self._apply_filter_to_single_query(str(parsed), client_id) def _contains_set_operation(self, parsed): set_operations = ('UNION', 'INTERSECT', 'EXCEPT') @@ -363,25 +361,35 @@ def _cleanup_whitespace(self, query: str) -> str: def _handle_subquery(self, parsed, client_id): result = [] tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + mainquery = [] for token in tokens: if isinstance(token, Identifier) and token.has_alias(): if isinstance(token.tokens[0], Parenthesis): + mainquery.append(" PLACEHOLDER ") subquery = token.tokens[0].tokens[1:-1] subquery_str = ' '.join(str(t) for t in subquery) filtered_subquery = self._apply_filter_recursive( sqlparse.parse(subquery_str)[0], client_id) alias = token.get_alias() - result.append(f"({filtered_subquery}) AS {alias}") + AS_keyword = next((t for t in token.tokens if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'AS'), None) # Checks for existence of 'AS' keyword + + if AS_keyword: + result.append(f"({filtered_subquery}) AS {alias}") + else: + result.append(f"({filtered_subquery}) {alias}") else: - result.append(str(token)) + mainquery.append(str(token)) + elif isinstance(token, Parenthesis): + mainquery.append(" PLACEHOLDER ") subquery = token.tokens[1:-1] subquery_str = ' '.join(str(t) for t in subquery) filtered_subquery = self._apply_filter_recursive( sqlparse.parse(subquery_str)[0], client_id) result.append(f"({filtered_subquery})") - elif isinstance(token, Where): + + elif isinstance(token, Where) and 'IN' in str(parsed): try: filtered_where = self._handle_where_subqueries( token, client_id) @@ -389,19 +397,15 @@ def _handle_subquery(self, parsed, client_id): except Exception as e: result.append(str(token)) else: - # Preserve whitespace tokens - if token.is_whitespace: - result.append(str(token)) - else: - # Add space before and after non-whitespace tokens, except for punctuation - if result and not result[-1].endswith(' ') and not str(token).startswith((')', ',', '.')): - result.append(' ') - result.append(str(token)) - if not str(token).endswith(('(', ',')): - result.append(' ') + mainquery.append(str(token)) - final_result = ''.join(result).strip() - return final_result + mainquery = ''.join(mainquery).strip() + if ' IN ' in str(parsed): + return f"{mainquery} {result[0]}" + else: + filtered_mainquery = self._apply_filter_to_single_query(mainquery, client_id) + query = filtered_mainquery.replace("PLACEHOLDER", result[0]) + return query def _handle_where_subqueries(self, where_clause, client_id): if self._is_cte_query(where_clause): diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 80b7eb4..674d281 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -98,10 +98,10 @@ def test_subquery_in_from(self): expected = 'SELECT * FROM (SELECT * FROM orders WHERE "orders"."user_id" = 1) AS subq' self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - # def test_subquery_in_join(self): - # query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' - # expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' - # self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + def test_subquery_in_join(self): + query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' + expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) def test_nested_subqueries(self): query = 'SELECT * FROM (SELECT * FROM (SELECT * FROM orders) AS inner_subq) AS outer_subq' From 0e44648b39c0eeec76eaa77f4154deaba742a7a4 Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Mon, 26 Aug 2024 12:49:47 +0530 Subject: [PATCH 02/15] Handling subqueries fixed (typo error) --- src/dataneuron/core/sql_query_filter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index bf55519..b4d1521 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -32,8 +32,8 @@ def _apply_filter_recursive(self, parsed, client_id): if self._is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) - for tokens in parsed.token: - if isinstance(tokens, Token) and tokens.ttype is DML: + for token in parsed.tokens: + if isinstance(token, Token) and token.ttype is DML: if self._contains_set_operation(parsed): return self._handle_set_operation(parsed, client_id) elif self._contains_subquery(parsed): From cfc4ba717c8bd7366ad062761a8d4188d95d4383 Mon Sep 17 00:00:00 2001 From: Nikhil Date: Wed, 28 Aug 2024 12:07:03 +0530 Subject: [PATCH 03/15] cte_subquery_ and cte_recursive fixex --- .../core/nlp_helpers/cte_handler.py | 13 +- src/dataneuron/core/sql_query_filter.py | 57 ++++++++- tests/core/test_sql_query_filter.py | 113 +++++++++--------- 3 files changed, 119 insertions(+), 64 deletions(-) diff --git a/src/dataneuron/core/nlp_helpers/cte_handler.py b/src/dataneuron/core/nlp_helpers/cte_handler.py index f57d6c4..7494228 100644 --- a/src/dataneuron/core/nlp_helpers/cte_handler.py +++ b/src/dataneuron/core/nlp_helpers/cte_handler.py @@ -49,6 +49,12 @@ def extract_main_query(parsed): def filter_cte(cte_part, filter_function, client_id): filtered_ctes = [] + is_recursive = False + + for token in cte_part.tokens: + if token.ttype is Keyword and token.value.upper() == 'RECURSIVE': + is_recursive = True + def process_cte(token): if isinstance(token, sqlparse.sql.Identifier): cte_name = token.get_name() @@ -57,7 +63,7 @@ def process_cte(token): # Remove outer parentheses inner_query_str = str(inner_query)[1:-1] filtered_inner_query = filter_function( - sqlparse.parse(inner_query_str)[0], client_id) + sqlparse.parse(inner_query_str)[0], client_id, cte_name) filtered_ctes.append(f"{cte_name} AS ({filtered_inner_query})") for token in cte_part.tokens: @@ -68,7 +74,10 @@ def process_cte(token): process_cte(token) if filtered_ctes: - filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) + if is_recursive: + filtered_cte_str = "WITH RECURSIVE " + ",\n".join(filtered_ctes) + else: + filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) else: filtered_cte_str = "" return filtered_cte_str diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index b4d1521..2927e71 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -28,14 +28,14 @@ def apply_client_filter(self, sql_query: str, client_id: int) -> str: return self._cleanup_whitespace(str(result)) - def _apply_filter_recursive(self, parsed, client_id): + def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if self._is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) for token in parsed.tokens: if isinstance(token, Token) and token.ttype is DML: if self._contains_set_operation(parsed): - return self._handle_set_operation(parsed, client_id) + return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) elif self._contains_subquery(parsed): return self._handle_subquery(parsed, client_id) else: @@ -225,7 +225,7 @@ def _inject_where_clause(self, parsed, where_clause): return str(parsed) - def _handle_set_operation(self, parsed, client_id): + def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None): print("Handling set operation") # Split the query into individual SELECT statements statements = [] @@ -251,9 +251,14 @@ def _handle_set_operation(self, parsed, client_id): # Apply the filter to each SELECT statement filtered_statements = [] for stmt in statements: - filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) - filtered_statements.append(filtered_stmt) - print(f"Filtered statement: {filtered_stmt}") + if is_cte: + filtered_stmt = self._apply_filter_to_single_CTE_query(stmt, client_id, cte_name) + filtered_statements.append(filtered_stmt) + print(f"Filtered statement: {filtered_stmt}") + else: + filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) + filtered_statements.append(filtered_stmt) + print(f"Filtered statement: {filtered_stmt}") # Reconstruct the query result = f" {set_operation} ".join(filtered_statements) @@ -512,3 +517,43 @@ def _extract_main_table(self, where_clause): if isinstance(token, Identifier): return token.get_real_name() return None + + def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: + parts = sql_query.split(' GROUP BY ') + main_query = parts[0] + + group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" + parsed = sqlparse.parse(main_query)[0] + tables_info = self._extract_tables_info(parsed) + + filters = [] + _table_ = [] + + for table_info in tables_info: + if table_info['name'] != cte_name: + table_dict = { + "name": table_info['name'], + "alias": table_info['alias'], + "schema": table_info['schema'] + } + _table_.append(table_dict) + + matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema']) + + if matching_table: + client_id_column = self.client_tables[matching_table] + table_reference = _table_[0]['alias'] or _table_[0]['name'] + + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') + + if filters: + where_clause = " AND ".join(filters) + if 'WHERE' in main_query.upper(): + where_parts = main_query.split('WHERE', 1) + result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}" + else: + result = f"{main_query} WHERE {where_clause}" + else: + result = main_query + + return result + group_by \ No newline at end of file diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 674d281..4589a69 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -123,7 +123,8 @@ def setUp(self): 'products': 'company_id', 'inventory.items': 'organization_id', 'items': 'organization_id', - 'customers': 'customer_id' + 'customers': 'customer_id', + 'categories': 'company_id' } self.filter = SQLQueryFilter( self.client_tables, schemas=['main', 'inventory']) @@ -217,61 +218,61 @@ def test_multiple_ctes(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) - # def test_cte_with_subquery(self): - # query = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # expected = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id - # WHERE "p"."company_id" = 1 - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) - - # def test_recursive_cte(self): - # query = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # ) - # SELECT * FROM category_tree - # ''' - # expected = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL AND "categories"."company_id" = 1 - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # WHERE "c"."company_id" = 1 - # ) - # SELECT * FROM category_tree - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) +def test_cte_with_subquery(self): + query = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + expected = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id + WHERE "p"."company_id" = 1 + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) + + def test_recursive_cte(self): + query = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM category_tree + ''' + expected = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL AND "categories"."company_id" = 1 + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + WHERE "c"."company_id" = 1 + ) + SELECT * FROM category_tree + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) if __name__ == '__main__': From 1aabcec0416ab74f9f25d5706bfd24ba3024fc65 Mon Sep 17 00:00:00 2001 From: nikhil3303 <153094190+nikhil3303@users.noreply.github.com> Date: Thu, 29 Aug 2024 10:59:04 +0530 Subject: [PATCH 04/15] typo fixed in recursive CTE test case --- tests/core/test_sql_query_filter.py | 52 ++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 4589a69..7f5e29c 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -218,32 +218,32 @@ def test_multiple_ctes(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) -def test_cte_with_subquery(self): - query = ''' - WITH top_products AS ( - SELECT p.id, p.name, SUM(o.quantity) as total_sold - FROM products p - JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id - GROUP BY p.id, p.name - ORDER BY total_sold DESC - LIMIT 10 - ) - SELECT * FROM top_products - ''' - expected = ''' - WITH top_products AS ( - SELECT p.id, p.name, SUM(o.quantity) as total_sold - FROM products p - JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id - WHERE "p"."company_id" = 1 - GROUP BY p.id, p.name - ORDER BY total_sold DESC - LIMIT 10 - ) - SELECT * FROM top_products - ''' - self.assertSQLEqual( - self.filter.apply_client_filter(query, 1), expected) + def test_cte_with_subquery(self): + query = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + expected = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id + WHERE "p"."company_id" = 1 + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) def test_recursive_cte(self): query = ''' From 37c6b178d6c927489cd27a15160e9c173a465987 Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Fri, 30 Aug 2024 11:16:16 +0530 Subject: [PATCH 05/15] Initial commit message for added testcases --- tests/core/test_sql_query_filter.py | 279 +++++++++++++++++++++++++++- 1 file changed, 278 insertions(+), 1 deletion(-) diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 7f5e29c..1d13e1a 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -1,5 +1,5 @@ import re -from dataneuron.core.sql_query_filter import SQLQueryFilter +from sql_query_filter import SQLQueryFilter import unittest @@ -274,6 +274,283 @@ def test_recursive_cte(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) +class TestSQLQueryFilterAdditional(unittest.TestCase): + def setUp(self): + self.client_tables = { + 'main.orders': 'user_id', + 'orders': 'user_id', + 'main.products': 'company_id', + 'products': 'company_id', + 'inventory.items': 'organization_id', + 'items': 'organization_id', + 'customers': 'customer_id', + 'categories': 'company_id' + } + self.filter = SQLQueryFilter(self.client_tables, schemas=['main', 'inventory']) + + def test_multiple_joins(self): + query = 'SELECT o.id, p.name, c.email FROM orders o JOIN products p ON o.product_id = p.id JOIN customers c ON o.user_id = c.id' + expected = 'SELECT o.id, p.name, c.email FROM orders o JOIN products p ON o.product_id = p.id JOIN customers c ON o.user_id = c.id WHERE "o"."user_id" = 1 AND "p"."company_id" = 1 AND "c"."customer_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_case_statement(self): + query = 'SELECT id, CASE WHEN total_amount > 1000 THEN "High" ELSE "Low" END AS order_value FROM orders' + expected = 'SELECT id, CASE WHEN total_amount > 1000 THEN "High" ELSE "Low" END AS order_value FROM orders WHERE "orders"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_subquery_in_select(self): + query = 'SELECT o.id, (SELECT COUNT(*) FROM products p WHERE p.id = o.product_id) AS product_count FROM orders o' + expected = 'SELECT o.id, (SELECT COUNT(*) FROM products p WHERE p.id = o.product_id AND "p"."company_id" = 1) AS product_count FROM orders o WHERE "o"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_having_clause(self): + query = 'SELECT product_id, COUNT(*) FROM orders GROUP BY product_id HAVING COUNT(*) > 5' + expected = 'SELECT product_id, COUNT(*) FROM orders WHERE "orders"."user_id" = 1 GROUP BY product_id HAVING COUNT(*) > 5' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_order_by_with_limit(self): + query = 'SELECT * FROM orders ORDER BY total_amount DESC LIMIT 10' + expected = 'SELECT * FROM orders WHERE "orders"."user_id" = 1 ORDER BY total_amount DESC LIMIT 10' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_union_with_order_by(self): + query = 'SELECT id FROM orders UNION SELECT id FROM products ORDER BY id' + expected = 'SELECT id FROM orders WHERE "orders"."user_id" = 1 UNION SELECT id FROM products WHERE "products"."company_id" = 1 ORDER BY id' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_subquery_with_aggregate(self): + query = 'SELECT * FROM orders WHERE total_amount > (SELECT AVG(total_amount) FROM orders)' + expected = 'SELECT * FROM orders WHERE total_amount > (SELECT AVG(total_amount) FROM orders WHERE "orders"."user_id" = 1) AND "orders"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_complex_join_with_subquery(self): + query = ''' + SELECT o.id, p.name + FROM orders o + JOIN (SELECT id, name FROM products WHERE price > 100) p ON o.product_id = p.id + WHERE o.status = 'completed' + ''' + expected = ''' + SELECT o.id, p.name + FROM orders o + JOIN (SELECT id, name FROM products WHERE price > 100 AND "products"."company_id" = 1) p ON o.product_id = p.id + WHERE o.status = 'completed' AND "o"."user_id" = 1 + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_complex_nested_subqueries(self): + query = ''' + SELECT * + FROM orders o + WHERE o.product_id IN ( + SELECT id + FROM products + WHERE category_id IN ( + SELECT id + FROM categories + WHERE name LIKE 'Electronics%' + ) + ) AND o.user_id IN ( + SELECT user_id + FROM ( + SELECT user_id, AVG(total_amount) as avg_order + FROM orders + GROUP BY user_id + HAVING AVG(total_amount) > 1000 + ) high_value_customers + ) + ''' + expected = ''' + SELECT * + FROM orders o + WHERE o.product_id IN ( + SELECT id + FROM products + WHERE category_id IN ( + SELECT id + FROM categories + WHERE name LIKE 'Electronics%' + AND "categories"."company_id" = 1 + ) + AND "products"."company_id" = 1 + ) AND o.user_id IN ( + SELECT user_id + FROM ( + SELECT user_id, AVG(total_amount) as avg_order + FROM orders + WHERE "orders"."user_id" = 1 + GROUP BY user_id + HAVING AVG(total_amount) > 1000 + ) high_value_customers + ) + AND "o"."user_id" = 1 + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_group_by_having_order_by(self): + query = ''' + SELECT product_id, COUNT(*) as order_count, SUM(total_amount) as total_sales + FROM orders + GROUP BY product_id + HAVING COUNT(*) > 10 + ORDER BY total_sales DESC + LIMIT 5 + ''' + expected = ''' + SELECT product_id, COUNT(*) as order_count, SUM(total_amount) as total_sales + FROM orders + WHERE "orders"."user_id" = 1 + GROUP BY product_id + HAVING COUNT(*) > 10 + ORDER BY total_sales DESC + LIMIT 5 + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_different_data_types_in_where(self): + query = ''' + SELECT * + FROM orders + WHERE order_date > '2023-01-01' + AND total_amount > 100.50 + AND status IN ('completed', 'shipped') + AND is_priority = TRUE + ''' + expected = ''' + SELECT * + FROM orders + WHERE order_date > '2023-01-01' + AND total_amount > 100.50 + AND status IN ('completed', 'shipped') + AND is_priority = TRUE + AND "orders"."user_id" = 1 + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_multi_schema_query(self): + query = ''' + SELECT o.id, p.name, i.quantity + FROM main.orders o + JOIN main.products p ON o.product_id = p.id + JOIN inventory.items i ON p.id = i.product_id + ''' + expected = ''' + SELECT o.id, p.name, i.quantity + FROM main.orders o + JOIN main.products p ON o.product_id = p.id + JOIN inventory.items i ON p.id = i.product_id + WHERE "o"."user_id" = 1 AND "p"."company_id" = 1 AND "i"."organization_id" = 1 + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + +class TestSQLQueryFilterAdditionalCTE(unittest.TestCase): + def setUp(self): + self.client_tables = { + 'main.orders': 'user_id', + 'orders': 'user_id', + 'main.products': 'company_id', + 'products': 'company_id', + 'inventory.items': 'organization_id', + 'items': 'organization_id', + 'customers': 'customer_id', + 'categories': 'company_id' + } + self.filter = SQLQueryFilter( + self.client_tables, schemas=['main', 'inventory']) + + def assertSQLEqual(self, first, second, msg=None): + def normalize_sql(sql): + # Remove all whitespace + sql = re.sub(r'\s+', '', sql) + # Convert to lowercase + return sql.lower() + + normalized_first = normalize_sql(first) + normalized_second = normalize_sql(second) + self.assertEqual(normalized_first, normalized_second, msg) + + def test_cte_with_union(self): + query = ''' + WITH combined_data AS ( + SELECT id, 'order' AS type FROM orders + UNION ALL + SELECT id, 'product' AS type FROM products + ) + SELECT * FROM combined_data + ''' + expected = ''' + WITH combined_data AS ( + SELECT id, 'order' AS type FROM orders WHERE "orders"."user_id" = 1 + UNION ALL + SELECT id, 'product' AS type FROM products WHERE "products"."company_id" = 1 + ) + SELECT * FROM combined_data + ''' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_set_operations_with_cte(self): + query = ''' + WITH order_summary AS ( + SELECT user_id, COUNT(*) as order_count + FROM orders + GROUP BY user_id + ) + SELECT * FROM order_summary + UNION + SELECT company_id as user_id, COUNT(*) as product_count + FROM products + GROUP BY company_id + ''' + expected = ''' + WITH order_summary AS ( + SELECT user_id, COUNT(*) as order_count + FROM orders + WHERE "orders"."user_id" = 1 + GROUP BY user_id + ) + SELECT * FROM order_summary + UNION + SELECT company_id as user_id, COUNT(*) as product_count + FROM products + WHERE "products"."company_id" = 1 + GROUP BY company_id + ''' + self.assertSQLEqual(self.filter.apply_client_filter(query, 1), expected) + + def test_recursive_cte_with_join(self): + query = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT ct.*, p.name as product_name + FROM category_tree ct + LEFT JOIN products p ON ct.id = p.category_id + ''' + expected = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL AND "categories"."company_id" = 1 + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + WHERE "c"."company_id" = 1 + ) + SELECT ct.*, p.name as product_name + FROM category_tree ct + LEFT JOIN products p ON ct.id = p.category_id + WHERE "p"."company_id" = 1 + ''' + self.assertSQLEqual(self.filter.apply_client_filter(query, 1), expected) + if __name__ == '__main__': unittest.main() From 2bf245273c6f14477d8b3dc478e61790b8179e02 Mon Sep 17 00:00:00 2001 From: nikhil3303 <153094190+nikhil3303@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:20:14 +0530 Subject: [PATCH 06/15] Fixed wrong import statement in test_sql_query_filter.py Original - "from dataneuron.core.sql_query_filter import SQLQueryFilter" changed to "from sql_query_filter import SQLQueryFilter" for testing purposes but forgot to change it back before commit --- tests/core/test_sql_query_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 1d13e1a..4700066 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -1,5 +1,5 @@ import re -from sql_query_filter import SQLQueryFilter +from dataneuron.core.sql_query_filter import SQLQueryFilter import unittest From 1a994fbae0d3c126ed8183cb641e14f8747145fb Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Fri, 13 Sep 2024 11:07:09 +0530 Subject: [PATCH 07/15] Logic added for handing additional subquery cases - only the detection method is added not the handling method --- src/dataneuron/core/sql_query_filter.py | 98 +++++++++++++++---------- 1 file changed, 59 insertions(+), 39 deletions(-) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 2927e71..d8348ea 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -303,48 +303,68 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: def _contains_subquery(self, parsed): tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + i = 0 + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} + case_end_keywords = {'WHEN', 'THEN', 'ELSE'} + where_keywords = {'IN', 'EXISTS', 'ANY', 'ALL', 'NOT IN'} + + while i < len(tokens): + token = tokens[i] + + if token.ttype is DML and token.value.upper() == 'SELECT': + # Find the index of the FROM clause + k = i + 1 + while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + # Check for CASE expressions or inline subqueries between SELECT and FROM + for j in range(i + 1, k): + next_token = tokens[j] + if str(next_token).startswith("CASE"): + case_token_list = [t for t in TokenList(next_token).flatten() if t.ttype == Keyword] + if any(k.value.upper() in case_end_keywords for k in case_token_list): + return True + elif "(" in str(next_token) and ")" in str(next_token): + return True - for i, token in enumerate(tokens): - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - return True - elif isinstance(token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in token.tokens): - return True - # Recursively check inside parentheses - if self._contains_subquery(token): - return True - elif isinstance(token, Where): - in_found = False - for j, sub_token in enumerate(token.tokens): - if in_found: - if isinstance(sub_token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in sub_token.tokens): - return True - elif hasattr(sub_token, 'ttype') and not sub_token.is_whitespace: - # Check if the token is a parenthesis-like structure - if '(' in sub_token.value and ')' in sub_token.value: - if 'SELECT' in sub_token.value.upper(): - return True - # If we find a non-whitespace token that's not a parenthesis, reset in_found - in_found = False - elif hasattr(sub_token, 'ttype') and sub_token.ttype is Keyword and sub_token.value.upper() == 'IN': - in_found = True - elif isinstance(sub_token, Comparison): - for item in sub_token.tokens: - if isinstance(item, Parenthesis): - if self._contains_subquery(item): - return True - elif hasattr(token, 'ttype') and token.ttype is Keyword and token.value.upper() == 'IN': - next_token = tokens[i+1] if i+1 < len(tokens) else None - if next_token: - if isinstance(next_token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in next_token.tokens): + from_index = k + where_index = None + + # Find the WHERE clause if any + k = from_index + 1 + while k < len(tokens): + if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': + where_index = k + break + k += 1 + + end_index = where_index if where_index else len(tokens) + + # Check for set operations, joins, or inline subqueries after FROM + for j in range(from_index + 1, end_index): + next_token = tokens[j] + if "(" in str(next_token) and ")" in str(next_token): + if any(op in str(next_token).upper() for op in f"{set_operations}" or f"{set_operations} ALL"): # Set operations return True - elif hasattr(next_token, 'value') and '(' in next_token.value and ')' in next_token.value: - if 'SELECT' in next_token.value.upper(): + elif str(next_token).upper() in joins: #JOINs operations return True - + else: + return True #Inline + + # Process the WHERE clause if present + if where_index: + for j in range(where_index + 1, len(tokens)): + next_token = tokens[j] + if "(" in str(next_token) and ")" in str(next_token): #Inline + return True + if str(next_token).startswith("CASE"): # Case END block inside WHERE + case_token_list = [t for t in TokenList(next_token).flatten() if t.ttype == Keyword] + if any(k.value.upper() in case_end_keywords for k in case_token_list): + return True + if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: + return True # Where keywords - IN, EXISTS, ANY, etc + i += 1 return False def _cleanup_whitespace(self, query: str) -> str: From 51224507db9770ca3a8bd65b9e91162b26cba4c1 Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Fri, 13 Sep 2024 17:32:53 +0530 Subject: [PATCH 08/15] Logic for GROUPBY, HAVING, ORDERBY added (only to contains_subquery method and not the handle_subquery method --- src/dataneuron/core/sql_query_filter.py | 76 ++++++++++++++----------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index d8348ea..46c69dc 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -302,70 +302,82 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: return result + group_by def _contains_subquery(self, parsed): - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - i = 0 + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} case_end_keywords = {'WHEN', 'THEN', 'ELSE'} where_keywords = {'IN', 'EXISTS', 'ANY', 'ALL', 'NOT IN'} + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + i = 0 while i < len(tokens): token = tokens[i] - + if token.ttype is DML and token.value.upper() == 'SELECT': - # Find the index of the FROM clause - k = i + 1 - while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k = i + 1 + while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): # Find the index of the FROM clause k += 1 - # Check for CASE expressions or inline subqueries between SELECT and FROM - for j in range(i + 1, k): - next_token = tokens[j] - if str(next_token).startswith("CASE"): - case_token_list = [t for t in TokenList(next_token).flatten() if t.ttype == Keyword] - if any(k.value.upper() in case_end_keywords for k in case_token_list): - return True - elif "(" in str(next_token) and ")" in str(next_token): - return True - from_index = k where_index = None + k = from_index + 1 - # Find the WHERE clause if any - k = from_index + 1 while k < len(tokens): - if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': + if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': # Find the WHERE clause if any where_index = k break k += 1 - end_index = where_index if where_index else len(tokens) - # Check for set operations, joins, or inline subqueries after FROM - for j in range(from_index + 1, end_index): + for j in range(i + 1, k): # Between SELECT and FROM block next_token = tokens[j] if "(" in str(next_token) and ")" in str(next_token): - if any(op in str(next_token).upper() for op in f"{set_operations}" or f"{set_operations} ALL"): # Set operations + if re.search(r'\bCASE\b(\s+WHEN\b.*?\bTHEN\b.*?)+(\s+ELSE\b.*?)?(?=\s+END\b)', str(next_token), re.DOTALL): return True - elif str(next_token).upper() in joins: #JOINs operations + else: + return True + + for j in range(from_index + 1, end_index): # FROM block checking for subqueries inside + next_token = str(tokens[j]).upper() + if "(" in next_token and ")" in next_token: + if any(op in next_token for op in set_operations): # Set operations return True + elif any(join in next_token for join in joins): # Joins + return True # This condition verifies that at least one statement in JOINs is a subquery else: - return True #Inline + return True # Inline subquery - # Process the WHERE clause if present - if where_index: + if where_index: # Procced only if WHERE exists for j in range(where_index + 1, len(tokens)): next_token = tokens[j] - if "(" in str(next_token) and ")" in str(next_token): #Inline + token_str = str(next_token).upper() + if "(" in token_str and ")" in token_str: # Inline subquery return True - if str(next_token).startswith("CASE"): # Case END block inside WHERE + if str(next_token).startswith("CASE"): # CASE END block inside WHERE case_token_list = [t for t in TokenList(next_token).flatten() if t.ttype == Keyword] if any(k.value.upper() in case_end_keywords for k in case_token_list): return True - if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: - return True # Where keywords - IN, EXISTS, ANY, etc + if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: # WHERE keywords + return True + + for j in range(end_index, len(tokens)): # WHERE block checking for subqueries + next_token = tokens[j] + token_str = str(next_token).upper() + if next_token.ttype == Keyword and next_token.value.upper() in end_keywords: + for m in range(j + 1, len(tokens)): + after_end_keyword_token = tokens[m] + after_end_token_str = str(after_end_keyword_token).upper() + if "(" in after_end_token_str and ")" in after_end_token_str: # Inline subquery + return True + if after_end_keyword_token.ttype == Keyword and after_end_keyword_token.value.upper() in where_keywords: # Keywords like IN, EXISTS, etc. + return True + if str(after_end_keyword_token).startswith("CASE"): # CASE END block after GROUP BY, etc. + case_token_list = [t for t in TokenList(after_end_keyword_token).flatten() if t.ttype == Keyword] + if any(k.value.upper() in case_end_keywords for k in case_token_list): + return True i += 1 - return False + return None def _cleanup_whitespace(self, query: str) -> str: # Split the query into lines From 7c3d38d719494f1691ac269b22808e07eaed4a2a Mon Sep 17 00:00:00 2001 From: Nikhil Kameshwaran Date: Tue, 17 Sep 2024 12:59:30 +0530 Subject: [PATCH 09/15] _handle_set_operations method is now able to handle cases where the SELECT statements are enclosed in parenthesis --- src/dataneuron/core/sql_query_filter.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 46c69dc..696aa24 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -256,9 +256,16 @@ def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_nam filtered_statements.append(filtered_stmt) print(f"Filtered statement: {filtered_stmt}") else: - filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) - filtered_statements.append(filtered_stmt) - print(f"Filtered statement: {filtered_stmt}") + match = re.search(r'\(([^()]*)\)', stmt) + if match: + extracted_part = match.group(1) + filtered_stmt = stmt.replace(extracted_part, self._apply_filter_to_single_query(extracted_part, client_id)) + filtered_statements.append(filtered_stmt) + #print(f"Filtered statement: {filtered_stmt}") + else: + filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) + filtered_statements.append(filtered_stmt) + #print(f"Filtered statement: {filtered_stmt}") # Reconstruct the query result = f" {set_operation} ".join(filtered_statements) From ee2585f3b800f1b938ca42c6970a3090a5125944 Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Wed, 25 Sep 2024 01:15:16 +0530 Subject: [PATCH 10/15] Previous test cases work without throwing errors, CTE now compatible with the new logic, Added separate python files for detecting and handling subqueries - Detection logic not working for FROM subqueries(Detection logic to be modified) - end_keywords subquery detection and handling removed for now, Removed dead functions --- .../core/nlp_helpers/is_subquery.py | 73 +++ .../core/nlp_helpers/subquery_handler.py | 218 +++++++ src/dataneuron/core/sql_query_filter.py | 598 ++++-------------- 3 files changed, 412 insertions(+), 477 deletions(-) create mode 100644 src/dataneuron/core/nlp_helpers/is_subquery.py create mode 100644 src/dataneuron/core/nlp_helpers/subquery_handler.py diff --git a/src/dataneuron/core/nlp_helpers/is_subquery.py b/src/dataneuron/core/nlp_helpers/is_subquery.py new file mode 100644 index 0000000..83158f6 --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/is_subquery.py @@ -0,0 +1,73 @@ +from sqlparse.tokens import DML, Keyword, Whitespace +from sqlparse.sql import TokenList, Parenthesis, Function +import re + +def _contains_subquery(parsed): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} + case_end_keywords = {'WHEN', 'THEN', 'ELSE'} + where_keywords = {'IN', 'EXISTS', 'ANY', 'ALL', 'NOT IN'} + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + + def is_subquery(token): + if isinstance(token, Parenthesis): + inner_tokens = token.tokens + for inner_token in inner_tokens: + if inner_token.ttype is not Whitespace: + return inner_token.ttype is DML and inner_token.value.upper() == 'SELECT' + return False + + i = 0 + while i < len(tokens): + token = tokens[i] + + if token.ttype is DML and token.value.upper() == 'SELECT': + k = i + 1 + while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + where_index = None + k = from_index + 1 + + while k < len(tokens): + if 'WHERE' in str(tokens[k]): + where_index = k + break + k += 1 + end_index = where_index if where_index else len(tokens) + + for j in range(i + 1, from_index): # Between SELECT and FROM block + next_token = tokens[j] + if is_subquery(next_token): + return True + if isinstance(next_token, Function): + if re.search(r'\bCASE\b(\s+WHEN\b.?\bTHEN\b.?)+(\s+ELSE\b.*?)?(?=\s+END\b)', str(next_token), re.DOTALL): + # Check for subquery within CASE statement + if any(is_subquery(t) for t in next_token.tokens): + return True + + for j in range(from_index + 1, end_index): # FROM block checking for subqueries inside + next_token = tokens[j] + if is_subquery(next_token): + return True + if isinstance(next_token, Function): + if any(op in next_token.value.upper() for op in set_operations): # Set operations + return True + if any(join in str(next_token).upper() for join in joins): # Joins + if j+1 < len(tokens) and is_subquery(tokens[j+1]): + return True + + if where_index: # Proceed only if WHERE exists + for j in range(where_index + 1, len(tokens)): + next_token = tokens[j] + if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: + if j+1 < len(tokens) and is_subquery(tokens[j+1]): + return True + elif is_subquery(next_token): + return True + + i += 1 + return False \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py new file mode 100644 index 0000000..66ef6ec --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -0,0 +1,218 @@ +import sqlparse +from sqlparse.sql import TokenList +from sqlparse.tokens import Keyword, DML +from sql_query_filter import SQLQueryFilter +import re + +def _handle_subquery(parsed, client_id): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + i = 0 + while i < len(tokens): + token = tokens[i] + + if token.ttype is DML and token.value.upper() == 'SELECT': + select_start = i + 1 + k = select_start + + while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): # Find the index of the FROM clause + k += 1 + + from_index = k + where_index = None + k = from_index + 1 + + while k < len(tokens): # Find the WHERE clause if any + if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': + where_index = k + break + k += 1 + + end_index = len(tokens) + + SELECT_block = TokenList(tokens[select_start:from_index]) + FROM_block = TokenList(tokens[from_index + 1:where_index]) if where_index else TokenList(tokens[from_index + 1:end_index]) + WHERE_block = TokenList(tokens[where_index + 1:end_index]) if where_index else None + + i = end_index # Move the index to the end of the processed part + else: + i += 1 + + SELECT_dict = SELECT_subquery(parsed, client_id) + FROM_dict = FROM_subquery(parsed, client_id) + WHERE_dict = WHERE_subquery(parsed, client_id) + + subquery_dict = { + "subqueries": SELECT_dict['subquery_list'] + FROM_dict['subquery_list'] + WHERE_dict['subquery_list'], + "filtered subqueries": SELECT_dict['filtered_subquery'] + FROM_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'], + "placeholder names": SELECT_dict['placeholder_value'] + FROM_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + } + + for i in range(len(subquery_dict['filtered subqueries'])): + mainquery_str = str(parsed).replace(f"({subquery_dict['subqueries'][i]})", subquery_dict['placeholder names'][i]) if i == 0 else mainquery_str.replace(f"({subquery_dict['subqueries'][i]})", subquery_dict['placeholder names'][i]) + if len(subquery_dict['subqueries']) == 1: + filtered_mainquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(mainquery_str)[0], client_id) # Handle the case where there is only one subquery + + elif i == len(subquery_dict['subqueries']) - 1: + filtered_mainquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(mainquery_str)[0], client_id) # Apply filtering to the main query for the last iteration in case of multiple subqueries + + elif i == 0: + filtered_mainquery = mainquery_str # For the first iteration, just keep the mainquery_str as it is + + for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): + filtered_mainquery = filtered_mainquery.replace(placeholder, f"({str(filtered_subquery)})") + + return filtered_mainquery + + +def SELECT_subquery(SELECT_block, client_id): + subqueries = re.findall(r'\(([^()]+(?:\([^()]\))[^()]*)\)', str(SELECT_block)) + filtered_dict = { + 'subquery_list': subqueries, + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for i, subquery in enumerate(filtered_dict['subquery_list']): # Apply filters to extracted subqueries + placeholder = f"" + filtered_subquery = SQLQueryFilter._apply_filter_to_single_query(subquery, client_id) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + +def FROM_subquery(FROM_block, client_id): + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} + subquery_dict = { + "inline subquery": [], + "join subquery": [], + "set operations": [], + } + + FROM_block_tokens = [token for token in FROM_block.tokens if not token.is_whitespace] + + for i, token in enumerate(FROM_block_tokens): + if token.ttype is Keyword and token.value.upper() in joins: # JOINs operations + if i > 0 and "(" in str(FROM_block_tokens[i-1]) and ")" in str(FROM_block_tokens[i-1]): # Select only subqueries + subquery = FROM_block_tokens[i-1].value + if subquery not in subquery_dict['join subquery']: + subquery_dict['join subquery'].append(subquery) + + if i < len(FROM_block_tokens) - 1 and "(" in str(FROM_block_tokens[i+1]) and ")" in str(FROM_block_tokens[i+1]): + subquery = FROM_block_tokens[i+1].value + if subquery not in subquery_dict['join subquery']: + subquery_dict['join subquery'].append(subquery) + i += 1 + + i = 0 + while i < len(FROM_block.tokens): # SET operation + if FROM_block.tokens[i].ttype is Keyword and FROM_block.tokens[i].value.upper() in set_operations: + subquery_dict['set operations'].append(str(FROM_block)) + break + i += 1 + + filtered_dict = { + 'subquery_list': subquery_dict['inline subquery'] + subquery_dict['join subquery'], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for i in range(len(filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = SQLQueryFilter._apply_filter_recursive( + sqlparse.parse(filtered_dict['subquery_list'][i])[0], client_id + ) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + for j in range(len(subquery_dict['set operations'])): + placeholder = f"" + filtered_dict['subquery_list'].append(subquery_dict['set operations'][j]) + filtered_dict['placeholder_value'].append(placeholder) + + filtered_subquery = SQLQueryFilter._handle_set_operation( + sqlparse.parse(subquery_dict['set operations'][j])[0], client_id + ) + filtered_dict['filtered_subquery'].append(filtered_subquery) + return filtered_dict + + +def WHERE_subquery(parsed, client_id): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + i = 0 + + subquery_dict = { + "in_subquery": [], + "not_in_subquery": [], + "exists_subquery": [], + "not_exists_subquery": [], + "any_subquery": [], + "all_subquery": [], + "inline subquery": [], + } + + def subquery_extractor(next_token): + for t in next_token[1]: + if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'IN': + next_token_in = next_token[1].token_next(next_token[1].token_index(t)) + if "(" in str(next_token_in) and ")" in str(next_token_in): + subquery_dict['in_subquery'].append(str(TokenList(next_token_in[0][1:-1]))) + + elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'NOT' and next_token[1].token_next(next_token[1].token_index(t))[1].value.upper() == 'IN': + next_token_not_in = next_token[1].token_next(next_token[1].token_index(t) + 1) + if "(" in str(next_token_not_in) and ")" in str(next_token_not_in): + subquery_dict['not_in_subquery'].append(str(TokenList(next_token_not_in[0][1:-1]))) + + elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'EXISTS': + next_token_exists = next_token[1].token_next(next_token[1].token_index(t)) + if "(" in str(next_token_exists) and ")" in str(next_token_exists): + subquery_dict['exists_subquery'].append(str(TokenList(next_token_exists[0][1:-1]))) + + elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'NOT' and next_token[1].token_next(next_token[1].token_index(t))[1].value.upper() == 'EXISTS': + next_token_not_exists = next_token[1].token_next(next_token[1].token_index(t) + 1) + if "(" in str(next_token_not_exists) and ")" in str(next_token_not_exists): + subquery_dict['not_exists_subquery'].append(str(TokenList(next_token_not_exists[0][1:-1]))) + + elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'ANY': + next_token_any = next_token[1].token_next(next_token[1].token_index(t)) + if "(" in str(next_token_any) and ")" in str(next_token_any): + subquery_dict['any_subquery'].append(str(TokenList(next_token_any[0][1:-1]))) + + elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'ALL': + next_token_all = next_token[1].token_next(next_token[1].token_index(t)) + if "(" in str(next_token_all) and ")" in str(next_token_all): + subquery_dict['all_subquery'].append(str(TokenList(next_token_all[0][1:-1]))) + + elif "(" in str(t) and ")" in str(t): + subquery_dict['inline subquery'].append(str(TokenList(t[0][1:-1]))) + + else: + SQLQueryFilter._apply_filter_to_single_query(str(parsed), client_id) + + while i < len(tokens): + token = parsed.tokens[i] + if token.ttype == sqlparse.tokens.Keyword and token.value.upper() == 'WHERE': + next_token = parsed.token_next(i) + subquery_extractor(next_token) + i += 1 + + filtered_dict = { + 'subquery_list': subquery_dict['in_subquery'] + subquery_dict['not_in_subquery'] + + subquery_dict['exists_subquery'] + subquery_dict['not_exists_subquery'] + + subquery_dict['any_subquery'] + subquery_dict['all_subquery'] + + subquery_dict['inline subquery'], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for i in range(len(filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(filtered_dict['subquery'][i])[0], client_id) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + + \ No newline at end of file diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 696aa24..d4d4087 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -5,6 +5,8 @@ from typing import List, Dict, Optional from .nlp_helpers.cte_handler import handle_cte_query from .nlp_helpers.is_cte import is_cte_query +from .nlp_helpers.is_subquery import _contains_subquery +from .nlp_helpers.subquery_handler import _handle_subquery class SQLQueryFilter: @@ -18,221 +20,42 @@ def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], def apply_client_filter(self, sql_query: str, client_id: int) -> str: self.filtered_tables = set() parsed = sqlparse.parse(sql_query)[0] - is_cte = self._is_cte_query(parsed) if is_cte: return handle_cte_query(parsed, self._apply_filter_recursive, client_id) else: result = self._apply_filter_recursive(parsed, client_id) - return self._cleanup_whitespace(str(result)) - + def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if self._is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) - - for token in parsed.tokens: - if isinstance(token, Token) and token.ttype is DML: - if self._contains_set_operation(parsed): - return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) - elif self._contains_subquery(parsed): - return self._handle_subquery(parsed, client_id) - else: - return self._apply_filter_to_single_query(str(parsed), client_id) - - def _contains_set_operation(self, parsed): - set_operations = ('UNION', 'INTERSECT', 'EXCEPT') - - # Check if parsed is a TokenList (has tokens attribute) - if hasattr(parsed, 'tokens'): - tokens = parsed.tokens else: - # If it's a single Token, wrap it in a list - tokens = [parsed] - - for i, token in enumerate(tokens): - if token.ttype is Keyword: - # Check for 'UNION ALL' as a single token - if token.value.upper() == 'UNION ALL': - print("Set operation found: UNION ALL") - return True - # Check for 'UNION', 'INTERSECT', 'EXCEPT' followed by 'ALL' - if token.value.upper() in set_operations: - next_token = parsed.token_next(i) if hasattr( - parsed, 'token_next') else None - if next_token and next_token[1].value.upper() == 'ALL': - print(f"Set operation found: {token.value} ALL") - return True + for token in parsed.tokens: + if isinstance(token, Token) and token.ttype is DML: + if self._contains_set_operation(parsed) and not _contains_subquery(parsed): + return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) + elif _contains_subquery(parsed): + return _handle_subquery(parsed, client_id) else: - print(f"Set operation found: {token.value}") - return True - return False - - def _extract_from_clause_tables(self, parsed, tables_info): - from_seen = False - for token in parsed.tokens: - if from_seen: - if isinstance(token, Identifier): - tables_info.append(self._parse_table_identifier(token)) - elif isinstance(token, IdentifierList): - for identifier in token.get_identifiers(): - if isinstance(identifier, Identifier): - tables_info.append( - self._parse_table_identifier(identifier)) - elif token.ttype is Keyword and token.value.upper() in ('WHERE', 'GROUP', 'ORDER', 'LIMIT'): - break - elif token.ttype is Keyword and token.value.upper() == 'FROM': - from_seen = True - elif token.ttype is Keyword and token.value.upper() == 'JOIN': - tables_info.append(self._parse_table_identifier( - parsed.token_next(token)[1])) - - def _extract_where_clause_tables(self, parsed, tables_info): - where_clause = next( - (token for token in parsed.tokens if isinstance(token, Where)), None) - if where_clause: - for token in where_clause.tokens: - if isinstance(token, Comparison): - for item in token.tokens: - if isinstance(item, Identifier): - if '.' in item.value: - schema, name = item.value.split('.', 1) - tables_info.append( - {'name': name, 'schema': schema, 'alias': None}) - elif isinstance(item, Parenthesis): - subquery = ' '.join(str(t) - for t in item.tokens[1:-1]) - subquery_parsed = sqlparse.parse(subquery)[0] - self._extract_from_clause_tables( - subquery_parsed, tables_info) - - def _extract_cte_tables(self, parsed, tables_info): - cte_start = next((i for i, token in enumerate( - parsed.tokens) if token.ttype is Keyword and token.value.upper() == 'WITH'), None) - if cte_start is not None: - for token in parsed.tokens[cte_start:]: - if isinstance(token, sqlparse.sql.Identifier) and token.has_alias(): - cte_name = token.get_alias() - tables_info.append( - {'name': cte_name, 'schema': None, 'alias': None}) - cte_query = token.tokens[-1] - if isinstance(cte_query, sqlparse.sql.Parenthesis): - # Remove outer parentheses and parse the CTE query - cte_parsed = sqlparse.parse(str(cte_query)[1:-1])[0] - # Recursively extract tables from the CTE query - self._extract_tables_info(cte_parsed, tables_info) - elif token.ttype is DML and token.value.upper() == 'SELECT': - break - - def _extract_tables_info(self, parsed, tables_info=None): - if tables_info is None: - tables_info = [] - - self._extract_from_clause_tables(parsed, tables_info) - self._extract_where_clause_tables(parsed, tables_info) - self._extract_cte_tables(parsed, tables_info) - - return tables_info - - def _extract_nested_subqueries(self, parsed, tables_info): + return self._apply_filter_to_single_query(str(parsed), client_id) + + def _contains_set_operation(self, parsed): + set_operations = ('UNION', 'INTERSECT', 'EXCEPT') + for token in parsed.tokens: - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - subquery = token.tokens[0].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - subquery_parsed = sqlparse.parse(subquery_str)[0] - self._extract_from_clause_tables( - subquery_parsed, tables_info) - self._extract_where_clause_tables( - subquery_parsed, tables_info) - self._extract_nested_subqueries( - subquery_parsed, tables_info) - - def _parse_table_identifier(self, identifier): - schema = None - alias = None - name = self._strip_quotes(str(identifier)) - - if identifier.has_alias(): - alias = self._strip_quotes(identifier.get_alias()) - name = self._strip_quotes(identifier.get_real_name()) - - if '.' in name: - parts = name.split('.') - if len(parts) == 2: - schema, name = parts - name = f"{schema}.{name}" if schema else name - - return {'name': name, 'schema': schema, 'alias': alias} - - def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: - possible_names = [ - f"{schema}.{table_name}" if schema else table_name, - table_name, - ] + [f"{s}.{table_name}" for s in self.schemas] - - for name in possible_names: - if self._case_insensitive_get(self.client_tables, name) is not None: - return name - return None - - def _case_insensitive_get(self, dict_obj: Dict[str, str], key: str) -> Optional[str]: - if self.case_sensitive: - return dict_obj.get(key) - return next((v for k, v in dict_obj.items() if k.lower() == key.lower()), None) - - def _strip_quotes(self, identifier: str) -> str: - return identifier.strip('"').strip("'").strip('`') - - def _quote_identifier(self, identifier: str) -> str: - return f'"{identifier}"' - - def _inject_where_clause(self, parsed, where_clause): - - where_index = next((i for i, token in enumerate(parsed.tokens) - if token.ttype is Keyword and token.value.upper() == 'WHERE'), None) - - if where_index is not None: - # Find the end of the existing WHERE clause - end_where_index = len(parsed.tokens) - 1 - for i in range(where_index + 1, len(parsed.tokens)): - token = parsed.tokens[i] - if token.ttype is Keyword and token.value.upper() in ('GROUP', 'ORDER', 'LIMIT'): - end_where_index = i - 1 - break - - # Insert our condition at the end of the existing WHERE clause - parsed.tokens.insert(end_where_index + 1, Token(Whitespace, ' ')) - parsed.tokens.insert(end_where_index + 2, Token(Keyword, 'AND')) - parsed.tokens.insert(end_where_index + 3, Token(Whitespace, ' ')) - parsed.tokens.insert(end_where_index + 4, - Token(Name, where_clause)) - else: - # Find the position to insert the WHERE clause - insert_position = len(parsed.tokens) - for i, token in enumerate(parsed.tokens): - if token.ttype is Keyword and token.value.upper() in ('GROUP', 'ORDER', 'LIMIT'): - insert_position = i - break - - # Insert the new WHERE clause - parsed.tokens.insert(insert_position, Token(Whitespace, ' ')) - parsed.tokens.insert(insert_position + 1, Token(Keyword, 'WHERE')) - parsed.tokens.insert(insert_position + 2, Token(Whitespace, ' ')) - parsed.tokens.insert(insert_position + 3, - Token(Name, where_clause)) - - return str(parsed) - + if token.ttype is Keyword and (token.value.upper() in set_operations or token.value.upper() in {op + ' ALL' for op in set_operations}): + return True + return False + def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None): - print("Handling set operation") - # Split the query into individual SELECT statements + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} statements = [] current_statement = [] set_operation = None for token in parsed.tokens: - if token.ttype is Keyword and token.value.upper() in ('UNION', 'INTERSECT', 'EXCEPT', 'UNION ALL'): + if token.ttype is Keyword and (token.value.upper() in set_operations or token.value.upper() in {op + ' ALL' for op in set_operations}): if current_statement: statements.append(''.join(str(t) for t in current_statement).strip()) @@ -245,10 +68,6 @@ def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_nam statements.append(''.join(str(t) for t in current_statement).strip()) - print(f"Split statements: {statements}") - print(f"Set operation: {set_operation}") - - # Apply the filter to each SELECT statement filtered_statements = [] for stmt in statements: if is_cte: @@ -261,25 +80,19 @@ def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_nam extracted_part = match.group(1) filtered_stmt = stmt.replace(extracted_part, self._apply_filter_to_single_query(extracted_part, client_id)) filtered_statements.append(filtered_stmt) - #print(f"Filtered statement: {filtered_stmt}") else: - filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) + filtered_stmt = self._apply_filter_to_single_query(str(stmt), client_id) filtered_statements.append(filtered_stmt) - #print(f"Filtered statement: {filtered_stmt}") - # Reconstruct the query result = f" {set_operation} ".join(filtered_statements) - print(f"Final result: {result}") return result - + def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: - parts = sql_query.split(' GROUP BY ') main_query = parts[0] group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" - parsed = sqlparse.parse(main_query)[0] - tables_info = self._extract_tables_info(parsed) + tables_info = self._extract_tables_info(sqlparse.parse(main_query)[0]) filters = [] for table_info in tables_info: @@ -292,8 +105,7 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: if matching_table and matching_table not in self.filtered_tables: client_id_column = self.client_tables[matching_table] table_reference = table_alias or table_name - filters.append( - f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') self.filtered_tables.add(matching_table) if filters: @@ -307,84 +119,111 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: result = main_query return result + group_by + + def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: + possible_names = [ + f"{schema}.{table_name}" if schema else table_name, + table_name, + ] + [f"{s}.{table_name}" for s in self.schemas] - def _contains_subquery(self, parsed): - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - - set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} - joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} - case_end_keywords = {'WHEN', 'THEN', 'ELSE'} - where_keywords = {'IN', 'EXISTS', 'ANY', 'ALL', 'NOT IN'} - end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + for name in possible_names: + if self._case_insensitive_get(self.client_tables, name) is not None: + return name + return None + + def _quote_identifier(self, identifier: str) -> str: + return f'"{identifier}"' + + def _strip_quotes(self, identifier: str) -> str: + return identifier.strip('"').strip("'").strip('`') + + def _case_insensitive_get(self, dict_obj: Dict[str, str], key: str) -> Optional[str]: + if self.case_sensitive: + return dict_obj.get(key) + return next((v for k, v in dict_obj.items() if k.lower() == key.lower()), None) + + def _parse_table_identifier(self, identifier): + schema = None + alias = None + name = self._strip_quotes(str(identifier)) - i = 0 - while i < len(tokens): - token = tokens[i] + if identifier.has_alias(): + alias = self._strip_quotes(identifier.get_alias()) + name = self._strip_quotes(identifier.get_real_name()) - if token.ttype is DML and token.value.upper() == 'SELECT': - k = i + 1 - while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): # Find the index of the FROM clause - k += 1 + if '.' in name: + parts = name.split('.') + if len(parts) == 2: + schema, name = parts + name = f"{schema}.{name}" if schema else name - from_index = k - where_index = None - k = from_index + 1 + return {'name': name, 'schema': schema, 'alias': alias} - while k < len(tokens): - if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': # Find the WHERE clause if any - where_index = k - break - k += 1 - end_index = where_index if where_index else len(tokens) + def _extract_tables_info(self, parsed, tables_info=None): + if tables_info is None: + tables_info = [] - for j in range(i + 1, k): # Between SELECT and FROM block - next_token = tokens[j] - if "(" in str(next_token) and ")" in str(next_token): - if re.search(r'\bCASE\b(\s+WHEN\b.*?\bTHEN\b.*?)+(\s+ELSE\b.*?)?(?=\s+END\b)', str(next_token), re.DOTALL): - return True - else: - return True + self._extract_from_clause_tables(parsed, tables_info) + self._extract_where_clause_tables(parsed, tables_info) + self._extract_cte_tables(parsed, tables_info) - for j in range(from_index + 1, end_index): # FROM block checking for subqueries inside - next_token = str(tokens[j]).upper() - if "(" in next_token and ")" in next_token: - if any(op in next_token for op in set_operations): # Set operations - return True - elif any(join in next_token for join in joins): # Joins - return True # This condition verifies that at least one statement in JOINs is a subquery - else: - return True # Inline subquery + return tables_info - if where_index: # Procced only if WHERE exists - for j in range(where_index + 1, len(tokens)): - next_token = tokens[j] - token_str = str(next_token).upper() - if "(" in token_str and ")" in token_str: # Inline subquery - return True - if str(next_token).startswith("CASE"): # CASE END block inside WHERE - case_token_list = [t for t in TokenList(next_token).flatten() if t.ttype == Keyword] - if any(k.value.upper() in case_end_keywords for k in case_token_list): - return True - if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: # WHERE keywords - return True + def _extract_from_clause_tables(self, parsed, tables_info): + from_seen = False + for token in parsed.tokens: + if from_seen: + if isinstance(token, Identifier): + tables_info.append(self._parse_table_identifier(token)) + elif isinstance(token, IdentifierList): + for identifier in token.get_identifiers(): + if isinstance(identifier, Identifier): + tables_info.append( + self._parse_table_identifier(identifier)) + elif token.ttype is Keyword and token.value.upper() in ('WHERE', 'GROUP', 'ORDER', 'LIMIT'): + break + elif token.ttype is Keyword and token.value.upper() == 'FROM': + from_seen = True + elif token.ttype is Keyword and token.value.upper() == 'JOIN': + tables_info.append(self._parse_table_identifier( + parsed.token_next(token)[1])) - for j in range(end_index, len(tokens)): # WHERE block checking for subqueries - next_token = tokens[j] - token_str = str(next_token).upper() - if next_token.ttype == Keyword and next_token.value.upper() in end_keywords: - for m in range(j + 1, len(tokens)): - after_end_keyword_token = tokens[m] - after_end_token_str = str(after_end_keyword_token).upper() - if "(" in after_end_token_str and ")" in after_end_token_str: # Inline subquery - return True - if after_end_keyword_token.ttype == Keyword and after_end_keyword_token.value.upper() in where_keywords: # Keywords like IN, EXISTS, etc. - return True - if str(after_end_keyword_token).startswith("CASE"): # CASE END block after GROUP BY, etc. - case_token_list = [t for t in TokenList(after_end_keyword_token).flatten() if t.ttype == Keyword] - if any(k.value.upper() in case_end_keywords for k in case_token_list): - return True - i += 1 - return None + def _extract_where_clause_tables(self, parsed, tables_info): + where_clause = next( + (token for token in parsed.tokens if isinstance(token, Where)), None) + if where_clause: + for token in where_clause.tokens: + if isinstance(token, Comparison): + for item in token.tokens: + if isinstance(item, Identifier): + if '.' in item.value: + schema, name = item.value.split('.', 1) + tables_info.append( + {'name': name, 'schema': schema, 'alias': None}) + elif isinstance(item, Parenthesis): + subquery = ' '.join(str(t) + for t in item.tokens[1:-1]) + subquery_parsed = sqlparse.parse(subquery)[0] + self._extract_from_clause_tables( + subquery_parsed, tables_info) + + def _extract_cte_tables(self, parsed, tables_info): + cte_start = next((i for i, token in enumerate( + parsed.tokens) if token.ttype is Keyword and token.value.upper() == 'WITH'), None) + if cte_start is not None: + for token in parsed.tokens[cte_start:]: + if isinstance(token, sqlparse.sql.Identifier) and token.has_alias(): + cte_name = token.get_alias() + tables_info.append( + {'name': cte_name, 'schema': None, 'alias': None}) + cte_query = token.tokens[-1] + if isinstance(cte_query, sqlparse.sql.Parenthesis): + # Remove outer parentheses and parse the CTE query + cte_parsed = sqlparse.parse(str(cte_query)[1:-1])[0] + # Recursively extract tables from the CTE query + self._extract_tables_info(cte_parsed, tables_info) + elif token.ttype is DML and token.value.upper() == 'SELECT': + break def _cleanup_whitespace(self, query: str) -> str: # Split the query into lines @@ -400,199 +239,4 @@ def _cleanup_whitespace(self, query: str) -> str: r'\s*,\s*(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ', ', line) cleaned_lines.append(line) # Join the lines back together - return '\n'.join(cleaned_lines) - - def _handle_subquery(self, parsed, client_id): - result = [] - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - mainquery = [] - - for token in tokens: - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - mainquery.append(" PLACEHOLDER ") - subquery = token.tokens[0].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - alias = token.get_alias() - AS_keyword = next((t for t in token.tokens if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'AS'), None) # Checks for existence of 'AS' keyword - - if AS_keyword: - result.append(f"({filtered_subquery}) AS {alias}") - else: - result.append(f"({filtered_subquery}) {alias}") - else: - mainquery.append(str(token)) - - elif isinstance(token, Parenthesis): - mainquery.append(" PLACEHOLDER ") - subquery = token.tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - result.append(f"({filtered_subquery})") - - elif isinstance(token, Where) and 'IN' in str(parsed): - try: - filtered_where = self._handle_where_subqueries( - token, client_id) - result.append(str(filtered_where)) - except Exception as e: - result.append(str(token)) - else: - mainquery.append(str(token)) - - mainquery = ''.join(mainquery).strip() - if ' IN ' in str(parsed): - return f"{mainquery} {result[0]}" - else: - filtered_mainquery = self._apply_filter_to_single_query(mainquery, client_id) - query = filtered_mainquery.replace("PLACEHOLDER", result[0]) - return query - - def _handle_where_subqueries(self, where_clause, client_id): - if self._is_cte_query(where_clause): - cte_part = self._extract_cte_definition(where_clause) - main_query = self._extract_main_query(where_clause) - - filtered_cte = self._apply_filter_recursive(cte_part, client_id) - - if 'WHERE' not in str(main_query).upper(): - main_query = self._add_where_clause_to_main_query( - main_query, client_id) - - return f"{filtered_cte} {main_query}" - else: - new_where_tokens = [] - i = 0 - while i < len(where_clause.tokens): - token = where_clause.tokens[i] - if token.ttype is Keyword and token.value.upper() == 'IN': - next_token = where_clause.token_next(i) - if next_token and isinstance(next_token[1], Parenthesis): - subquery = next_token[1].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - filtered_subquery_str = str(filtered_subquery) - try: - new_subquery_tokens = [ - Token(Whitespace, ' '), - Token(Punctuation, '(') - ] + sqlparse.parse(filtered_subquery_str)[0].tokens + [Token(Punctuation, ')')] - new_where_tokens.extend( - [token] + new_subquery_tokens) - except Exception as e: - # Fallback to original subquery with space - new_where_tokens.extend( - [token, Token(Whitespace, ' '), next_token[1]]) - i += 2 # Skip the next token as we've handled it - else: - new_where_tokens.append(token) - elif isinstance(token, Parenthesis): - subquery = token.tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - if self._contains_subquery(sqlparse.parse(subquery_str)[0]): - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - filtered_subquery_str = str(filtered_subquery) - try: - new_subquery_tokens = sqlparse.parse( - f"({filtered_subquery_str})")[0].tokens - new_where_tokens.extend(new_subquery_tokens) - except Exception as e: - # Fallback to original subquery - new_where_tokens.append(token) - else: - new_where_tokens.append(token) - else: - new_where_tokens.append(token) - i += 1 - - # Add the client filter for the main table - try: - main_table = self._extract_main_table(where_clause) - if main_table: - main_table_filter = self._generate_client_filter( - main_table, client_id) - if main_table_filter: - filter_tokens = [ - Token(Whitespace, ' '), - Token(Keyword, 'AND'), - Token(Whitespace, ' ') - ] + sqlparse.parse(main_table_filter)[0].tokens - new_where_tokens.extend(filter_tokens) - except Exception as e: - print(f"error: {e}") - - where_clause.tokens = new_where_tokens - return where_clause - - def _generate_client_filter(self, table_name, client_id): - matching_table = self._find_matching_table(table_name) - if matching_table: - client_id_column = self.client_tables[matching_table] - return f'{self._quote_identifier(table_name)}.{self._quote_identifier(client_id_column)} = {client_id}' - return None - - def _extract_main_query(self, parsed): - main_query_tokens = [] - main_query_started = False - - for token in parsed.tokens: - if main_query_started: - main_query_tokens.append(token) - elif token.ttype is DML and token.value.upper() == 'SELECT': - main_query_started = True - main_query_tokens.append(token) - - return TokenList(main_query_tokens) - - def _extract_main_table(self, where_clause): - if where_clause.parent is None: - return None - for token in where_clause.parent.tokens: - if isinstance(token, Identifier): - return token.get_real_name() - return None - - def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: - parts = sql_query.split(' GROUP BY ') - main_query = parts[0] - - group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" - parsed = sqlparse.parse(main_query)[0] - tables_info = self._extract_tables_info(parsed) - - filters = [] - _table_ = [] - - for table_info in tables_info: - if table_info['name'] != cte_name: - table_dict = { - "name": table_info['name'], - "alias": table_info['alias'], - "schema": table_info['schema'] - } - _table_.append(table_dict) - - matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema']) - - if matching_table: - client_id_column = self.client_tables[matching_table] - table_reference = _table_[0]['alias'] or _table_[0]['name'] - - filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') - - if filters: - where_clause = " AND ".join(filters) - if 'WHERE' in main_query.upper(): - where_parts = main_query.split('WHERE', 1) - result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}" - else: - result = f"{main_query} WHERE {where_clause}" - else: - result = main_query - - return result + group_by \ No newline at end of file + return '\n'.join(cleaned_lines) \ No newline at end of file From 11edff7adaf6cf53eb1a7f91766a5101b4b55773 Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Wed, 25 Sep 2024 01:24:59 +0530 Subject: [PATCH 11/15] _apply_filter_to_single_CTE_query method was accidently removed in last commit - added back --- src/dataneuron/core/sql_query_filter.py | 40 +++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index d4d4087..2432bc8 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -225,6 +225,46 @@ def _extract_cte_tables(self, parsed, tables_info): elif token.ttype is DML and token.value.upper() == 'SELECT': break + def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: + parts = sql_query.split(' GROUP BY ') + main_query = parts[0] + + group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" + parsed = sqlparse.parse(main_query)[0] + tables_info = self._extract_tables_info(parsed) + + filters = [] + _table_ = [] + + for table_info in tables_info: + if table_info['name'] != cte_name: + table_dict = { + "name": table_info['name'], + "alias": table_info['alias'], + "schema": table_info['schema'] + } + _table_.append(table_dict) + + matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema']) + + if matching_table: + client_id_column = self.client_tables[matching_table] + table_reference = _table_[0]['alias'] or _table_[0]['name'] + + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') + + if filters: + where_clause = " AND ".join(filters) + if 'WHERE' in main_query.upper(): + where_parts = main_query.split('WHERE', 1) + result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}" + else: + result = f"{main_query} WHERE {where_clause}" + else: + result = main_query + + return result + group_by + def _cleanup_whitespace(self, query: str) -> str: # Split the query into lines lines = query.split('\n') From b38e278499fdc2a9b56eb3d0c49ba5aaae002922 Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Fri, 27 Sep 2024 01:40:22 +0530 Subject: [PATCH 12/15] Final logic for detecting and handling subqueries, fixed circular import errors with importlib, _cleanup_whitespace method is now in a separate module --- .../core/nlp_helpers/is_subquery.py | 193 +++++++----- .../core/nlp_helpers/query_cleanup.py | 17 ++ .../core/nlp_helpers/subquery_handler.py | 282 +++++++++--------- src/dataneuron/core/sql_query_filter.py | 35 +-- 4 files changed, 293 insertions(+), 234 deletions(-) create mode 100644 src/dataneuron/core/nlp_helpers/query_cleanup.py diff --git a/src/dataneuron/core/nlp_helpers/is_subquery.py b/src/dataneuron/core/nlp_helpers/is_subquery.py index 83158f6..267da08 100644 --- a/src/dataneuron/core/nlp_helpers/is_subquery.py +++ b/src/dataneuron/core/nlp_helpers/is_subquery.py @@ -1,73 +1,132 @@ -from sqlparse.tokens import DML, Keyword, Whitespace -from sqlparse.sql import TokenList, Parenthesis, Function import re +from sqlparse.sql import Token +from sqlparse.tokens import DML, Keyword, Whitespace, Newline +from query_cleanup import _cleanup_whitespace def _contains_subquery(parsed): tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} - case_end_keywords = {'WHEN', 'THEN', 'ELSE'} - where_keywords = {'IN', 'EXISTS', 'ANY', 'ALL', 'NOT IN'} - end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} - - def is_subquery(token): - if isinstance(token, Parenthesis): - inner_tokens = token.tokens - for inner_token in inner_tokens: - if inner_token.ttype is not Whitespace: - return inner_token.ttype is DML and inner_token.value.upper() == 'SELECT' - return False - - i = 0 - while i < len(tokens): - token = tokens[i] - - if token.ttype is DML and token.value.upper() == 'SELECT': - k = i + 1 - while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): - k += 1 - - from_index = k - where_index = None - k = from_index + 1 - - while k < len(tokens): - if 'WHERE' in str(tokens[k]): - where_index = k - break - k += 1 - end_index = where_index if where_index else len(tokens) - - for j in range(i + 1, from_index): # Between SELECT and FROM block - next_token = tokens[j] - if is_subquery(next_token): - return True - if isinstance(next_token, Function): - if re.search(r'\bCASE\b(\s+WHEN\b.?\bTHEN\b.?)+(\s+ELSE\b.*?)?(?=\s+END\b)', str(next_token), re.DOTALL): - # Check for subquery within CASE statement - if any(is_subquery(t) for t in next_token.tokens): - return True - - for j in range(from_index + 1, end_index): # FROM block checking for subqueries inside - next_token = tokens[j] - if is_subquery(next_token): - return True - if isinstance(next_token, Function): - if any(op in next_token.value.upper() for op in set_operations): # Set operations - return True - if any(join in str(next_token).upper() for join in joins): # Joins - if j+1 < len(tokens) and is_subquery(tokens[j+1]): - return True - - if where_index: # Proceed only if WHERE exists - for j in range(where_index + 1, len(tokens)): - next_token = tokens[j] - if next_token.ttype == Keyword and next_token.value.upper() in where_keywords: - if j+1 < len(tokens) and is_subquery(tokens[j+1]): - return True - elif is_subquery(next_token): - return True - - i += 1 - return False \ No newline at end of file + + where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} + where_keyword_pattern = '|'.join(where_keywords) + + select_index = None + from_index = None + where_index = None + end_index = None + + + select_block = [] + from_block = [] + join_statement = [] + join_found = False + where_block = [] + results = [] + + def keyword_index(tokens): + nonlocal select_index, from_index, where_index + i = 0 + while i < len(tokens): + token = tokens[i] + + if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': + select_index = i + k = i + 1 + + while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + k = from_index + 1 + + while k < len(tokens): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]): + where_index = k + break + k += 1 + + i += 1 + + keyword_index(tokens) + from_end = where_index if where_index is not None else len(tokens) + + for j in range(select_index + 1, from_index): # Between SELECT and FROM block + select_block.append(_cleanup_whitespace(str(tokens[j]))) + + select_elements = ' '.join(select_block).strip().split(',') # Split by commas to handle multiple elements in the SELECT block + + for element in select_elements: + element = element.replace('\n', ' ').strip() # Clean up any extra whitespace + + if re.search(r'\bCASE\b((\s+WHEN\b.*?\bTHEN\b.*?)+)(\s+ELSE\b.*)?(?=\s+END\b)', element, re.DOTALL): + results.append("CASE Block exists in SELECT block - Checking if it has subquery in any of WHEN, THEN or ELSE blocks") + + for match in re.findall(r'\bWHEN\b.*?\bTHEN\b.*?\bELSE\b.*?(?=\bWHEN\b|\bELSE\b|\bEND\b)', element, re.DOTALL): #Split them into WHEN, THEN and ELSE blocks: # Check for subquery inside WHEN THEN + if re.search(r'\(.*?\bSELECT\b.*?\)', match, re.DOTALL): + results.append("Subquery exists inside CASE WHEN THEN ELSE block") + + elif '(' in element and ')' in element: # Find if any element has parenthesis + results.append("Inline element has parenthesis inside SELECT block - Checking if it has subquery") + if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): + results.append("Inline Subquery exists inside SELECT block") + + for j in range(from_index + 1, from_end): # Between FROM and WHERE (or) end of query + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + from_block.append(tokens[j]) + + for i, element in enumerate(from_block): + if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: # JOINs + join_found = True + if i == 1: + join_statement.append(str(from_block[i - 1])) + join_statement.append(str(from_block[i + 1])) + elif i > 1: + join_statement.append(str(from_block[i + 1])) + + elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): + results.append("Outer parentheses found inside FROM block - Checking if is an inline or contains set operation") + + if re.match(r'\(\s*SELECT.*UNION.*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + results.append("(SELECT ... UNION .. SELECT...) - Contains set operation - Not a subquery inside FROM block") + elif re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + results.append("( (SELECT ...) UNION .. (SELECT...) ) - Contains set operation - Subquery found inside FROM block") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + results.append("Inline subquery inside FROM block") + + if join_found: + results.append("JOIN operation found inside FROM - Checking if has subquery") + for stmt in join_statement: + join_statement_str = _cleanup_whitespace(str(stmt)) + if "(" in join_statement_str and ")" in join_statement_str: + results.append("Parenthesis found inside JOIN - Checking if is a subquery") + + if re.match(r'\(\s*SELECT.*UNION.*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("(SELECT ... UNION .. SELECT...) - Not a subquery inside JOIN") + elif re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("( (SELECT ...) UNION .. (SELECT...) ) - Subquery inside JOIN") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("Inline subquery inside JOIN") + + if where_index: # Between WHERE and end of query (or) End_keywords + for j in range(where_index, len(tokens)): + where_block.append(_cleanup_whitespace(str(tokens[j]).strip('WHERE '))) + + for i in where_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): # Splits into multiple statements if AND exists, else selects the single statement + clause = clause.strip() + + # Check for the presence of any special keyword like IN, NOT IN, EXISTS, ALL, ANY + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + found_keyword = re.search(fr'\b({where_keyword_pattern})\b', clause).group() + results.append(f"Subquery with special keyword found in WHERE block: {found_keyword} \n") + + # Check for subquery using a SELECT statement in parentheses + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + results.append("Inline subquery found in WHERE block \n") + + if len(results) > 1: + return True + else: + return False \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/query_cleanup.py b/src/dataneuron/core/nlp_helpers/query_cleanup.py new file mode 100644 index 0000000..dedc2dd --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/query_cleanup.py @@ -0,0 +1,17 @@ +import re + +def _cleanup_whitespace(query: str) -> str: + # Split the query into lines + lines = query.split('\n') + cleaned_lines = [] + for line in lines: + # Remove leading/trailing whitespace from each line + line = line.strip() + # Replace multiple spaces with a single space, but not in quoted strings + line = re.sub(r'\s+(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ' ', line) + # Ensure single space after commas, but not in quoted strings + line = re.sub( + r'\s*,\s*(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ', ', line) + cleaned_lines.append(line) + # Join the lines back together + return '\n'.join(cleaned_lines) \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py index 66ef6ec..81f9a5a 100644 --- a/src/dataneuron/core/nlp_helpers/subquery_handler.py +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -1,45 +1,63 @@ import sqlparse -from sqlparse.sql import TokenList -from sqlparse.tokens import Keyword, DML -from sql_query_filter import SQLQueryFilter +from sqlparse.sql import Token +from sqlparse.tokens import Keyword, DML, Whitespace, Newline import re +from sql_query_filter import SQLQueryFilter +from query_cleanup import _cleanup_whitespace def _handle_subquery(parsed, client_id): tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - i = 0 - while i < len(tokens): - token = tokens[i] - - if token.ttype is DML and token.value.upper() == 'SELECT': - select_start = i + 1 - k = select_start - - while k < len(tokens) and not (tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): # Find the index of the FROM clause - k += 1 - - from_index = k - where_index = None - k = from_index + 1 - - while k < len(tokens): # Find the WHERE clause if any - if tokens[k].ttype == Keyword and tokens[k].value.upper() == 'WHERE': - where_index = k - break - k += 1 - - end_index = len(tokens) - - SELECT_block = TokenList(tokens[select_start:from_index]) - FROM_block = TokenList(tokens[from_index + 1:where_index]) if where_index else TokenList(tokens[from_index + 1:end_index]) - WHERE_block = TokenList(tokens[where_index + 1:end_index]) if where_index else None - - i = end_index # Move the index to the end of the processed part - else: + + select_index = None + from_index = None + where_index = None + end_index = None + + select_block = [] + from_block = [] + where_block = [] + + def keyword_index(tokens): + nonlocal select_index, from_index, where_index + i = 0 + while i < len(tokens): + token = tokens[i] + + if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': + select_index = i + k = i + 1 + + while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + k = from_index + 1 + + while k < len(tokens): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]): + where_index = k + break + k += 1 + i += 1 - SELECT_dict = SELECT_subquery(parsed, client_id) - FROM_dict = FROM_subquery(parsed, client_id) - WHERE_dict = WHERE_subquery(parsed, client_id) + keyword_index(tokens) + from_end = where_index if where_index is not None else len(tokens) + + for j in range(select_index + 1, from_index): # Between SELECT and FROM block + select_block.append(_cleanup_whitespace(str(tokens[j]))) + + for j in range(from_index + 1, from_end): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + from_block.append(tokens[j]) + + if where_index: + for j in range(where_index, len(tokens)): + where_block.append(_cleanup_whitespace(str(tokens[j]).strip('WHERE '))) + WHERE_dict = WHERE_subquery(parsed, client_id) + + SELECT_dict = SELECT_subquery(select_block, client_id) + FROM_dict = FROM_subquery(from_index, client_id) subquery_dict = { "subqueries": SELECT_dict['subquery_list'] + FROM_dict['subquery_list'] + WHERE_dict['subquery_list'], @@ -61,17 +79,30 @@ def _handle_subquery(parsed, client_id): for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): filtered_mainquery = filtered_mainquery.replace(placeholder, f"({str(filtered_subquery)})") - return filtered_mainquery + print(filtered_mainquery) def SELECT_subquery(SELECT_block, client_id): - subqueries = re.findall(r'\(([^()]+(?:\([^()]\))[^()]*)\)', str(SELECT_block)) + + select_elements = ' '.join(SELECT_block).strip().split(',') # Split by commas to handle multiple elements in the SELECT block filtered_dict = { - 'subquery_list': subqueries, + 'subquery_list': [], 'filtered_subquery': [], 'placeholder_value': [] } + for i, element in enumerate(select_elements): + element = element.replace('\n', ' ').strip() # Clean up any extra whitespace + + if re.search(r'\bCASE\b((\s+WHEN\b.*?\bTHEN\b.*?)+)(\s+ELSE\b.*)?(?=\s+END\b)', element, re.DOTALL): + for match in re.findall(r'\bWHEN\b.*?\bTHEN\b.*?\bELSE\b.*?(?=\bWHEN\b|\bELSE\b|\bEND\b)', element, re.DOTALL): #Split them into WHEN, THEN and ELSE blocks: # Check for subquery inside WHEN THEN + if re.search(r'\(.*?\bSELECT\b.*?\)', match, re.DOTALL): + filtered_dict['subquery_list'].append(match) + + elif '(' in element and ')' in element: # Find if any element has parenthesis + if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): + filtered_dict['subquery_list'].append(element) + for i, subquery in enumerate(filtered_dict['subquery_list']): # Apply filters to extracted subqueries placeholder = f"" filtered_subquery = SQLQueryFilter._apply_filter_to_single_query(subquery, client_id) @@ -80,8 +111,9 @@ def SELECT_subquery(SELECT_block, client_id): return filtered_dict - def FROM_subquery(FROM_block, client_id): + join_found = False + join_statement = [] joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} subquery_dict = { @@ -90,125 +122,89 @@ def FROM_subquery(FROM_block, client_id): "set operations": [], } - FROM_block_tokens = [token for token in FROM_block.tokens if not token.is_whitespace] - - for i, token in enumerate(FROM_block_tokens): - if token.ttype is Keyword and token.value.upper() in joins: # JOINs operations - if i > 0 and "(" in str(FROM_block_tokens[i-1]) and ")" in str(FROM_block_tokens[i-1]): # Select only subqueries - subquery = FROM_block_tokens[i-1].value - if subquery not in subquery_dict['join subquery']: - subquery_dict['join subquery'].append(subquery) - - if i < len(FROM_block_tokens) - 1 and "(" in str(FROM_block_tokens[i+1]) and ")" in str(FROM_block_tokens[i+1]): - subquery = FROM_block_tokens[i+1].value - if subquery not in subquery_dict['join subquery']: - subquery_dict['join subquery'].append(subquery) - i += 1 - - i = 0 - while i < len(FROM_block.tokens): # SET operation - if FROM_block.tokens[i].ttype is Keyword and FROM_block.tokens[i].value.upper() in set_operations: - subquery_dict['set operations'].append(str(FROM_block)) - break - i += 1 - - filtered_dict = { + for i, element in enumerate(FROM_block): + if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: # JOINs + join_found = True + if i == 1: + join_statement.append(str(FROM_block[i - 1])) + join_statement.append(str(FROM_block[i + 1])) + elif i > 1: + join_statement.append(str(FROM_block[i + 1])) + + elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): + if re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + subquery_dict['set operations'].append(str(element)) + elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + subquery_dict['inline subquery'].append(str(element)) + if join_found: + for stmt in join_statement: + join_statement_str = _cleanup_whitespace(str(stmt)) + if "(" in join_statement_str and ")" in join_statement_str: + if re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_dict['set operations'].append(join_statement_str) + elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_dict['join subquery'].append(join_statement_str) + + non_setop_filtered_dict = { 'subquery_list': subquery_dict['inline subquery'] + subquery_dict['join subquery'], 'filtered_subquery': [], 'placeholder_value': [] } + setop_filtered_dict = { + 'subquery_list': subquery_dict['set operations'], + 'filtered_subquery': [], + 'placeholder_value': [] + } - for i in range(len(filtered_dict['subquery_list'])): - placeholder = f"" + for nsod in range(len(non_setop_filtered_dict['subquery_list'])): + placeholder = f"" filtered_subquery = SQLQueryFilter._apply_filter_recursive( - sqlparse.parse(filtered_dict['subquery_list'][i])[0], client_id - ) - filtered_dict['placeholder_value'].append(placeholder) - filtered_dict['filtered_subquery'].append(filtered_subquery) + sqlparse.parse(non_setop_filtered_dict['subquery_list'][nsod])[0], client_id) + non_setop_filtered_dict['placeholder_value'].append(placeholder) + non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) + + for sod in range(len(setop_filtered_dict['set operations'])): + placeholder = f"" + non_setop_filtered_dict['subquery_list'].append(subquery_dict['set operations'][sod]) + non_setop_filtered_dict['placeholder_value'].append(placeholder) + filtered_subquery = SQLQueryFilter._handle_set_operation( + sqlparse.parse(subquery_dict['set operations'][sod])[0], client_id) + non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) - for j in range(len(subquery_dict['set operations'])): - placeholder = f"" - filtered_dict['subquery_list'].append(subquery_dict['set operations'][j]) - filtered_dict['placeholder_value'].append(placeholder) + filtered_dict = { + 'subquery_list': non_setop_filtered_dict['subquery_list'] + setop_filtered_dict['subquery_list'], + 'filtered_subquery': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'], + 'placeholder_value': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'] + } - filtered_subquery = SQLQueryFilter._handle_set_operation( - sqlparse.parse(subquery_dict['set operations'][j])[0], client_id - ) - filtered_dict['filtered_subquery'].append(filtered_subquery) return filtered_dict -def WHERE_subquery(parsed, client_id): - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - i = 0 - - subquery_dict = { - "in_subquery": [], - "not_in_subquery": [], - "exists_subquery": [], - "not_exists_subquery": [], - "any_subquery": [], - "all_subquery": [], - "inline subquery": [], - } +def WHERE_subquery(WHERE_block, client_id): - def subquery_extractor(next_token): - for t in next_token[1]: - if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'IN': - next_token_in = next_token[1].token_next(next_token[1].token_index(t)) - if "(" in str(next_token_in) and ")" in str(next_token_in): - subquery_dict['in_subquery'].append(str(TokenList(next_token_in[0][1:-1]))) - - elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'NOT' and next_token[1].token_next(next_token[1].token_index(t))[1].value.upper() == 'IN': - next_token_not_in = next_token[1].token_next(next_token[1].token_index(t) + 1) - if "(" in str(next_token_not_in) and ")" in str(next_token_not_in): - subquery_dict['not_in_subquery'].append(str(TokenList(next_token_not_in[0][1:-1]))) - - elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'EXISTS': - next_token_exists = next_token[1].token_next(next_token[1].token_index(t)) - if "(" in str(next_token_exists) and ")" in str(next_token_exists): - subquery_dict['exists_subquery'].append(str(TokenList(next_token_exists[0][1:-1]))) - - elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'NOT' and next_token[1].token_next(next_token[1].token_index(t))[1].value.upper() == 'EXISTS': - next_token_not_exists = next_token[1].token_next(next_token[1].token_index(t) + 1) - if "(" in str(next_token_not_exists) and ")" in str(next_token_not_exists): - subquery_dict['not_exists_subquery'].append(str(TokenList(next_token_not_exists[0][1:-1]))) - - elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'ANY': - next_token_any = next_token[1].token_next(next_token[1].token_index(t)) - if "(" in str(next_token_any) and ")" in str(next_token_any): - subquery_dict['any_subquery'].append(str(TokenList(next_token_any[0][1:-1]))) - - elif t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'ALL': - next_token_all = next_token[1].token_next(next_token[1].token_index(t)) - if "(" in str(next_token_all) and ")" in str(next_token_all): - subquery_dict['all_subquery'].append(str(TokenList(next_token_all[0][1:-1]))) - - elif "(" in str(t) and ")" in str(t): - subquery_dict['inline subquery'].append(str(TokenList(t[0][1:-1]))) - - else: - SQLQueryFilter._apply_filter_to_single_query(str(parsed), client_id) - - while i < len(tokens): - token = parsed.tokens[i] - if token.ttype == sqlparse.tokens.Keyword and token.value.upper() == 'WHERE': - next_token = parsed.token_next(i) - subquery_extractor(next_token) - i += 1 - + where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} + where_keyword_pattern = '|'.join(where_keywords) filtered_dict = { - 'subquery_list': subquery_dict['in_subquery'] + subquery_dict['not_in_subquery'] + - subquery_dict['exists_subquery'] + subquery_dict['not_exists_subquery'] + - subquery_dict['any_subquery'] + subquery_dict['all_subquery'] + - subquery_dict['inline subquery'], - 'filtered_subquery': [], + 'subquery_list': [], + 'filtered_subquery': [], 'placeholder_value': [] - } + } + + for i in WHERE_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): # Splits into multiple statements if AND exists, else selects the single statement + clause = clause.strip() + + # Check for the presence of any special keyword like IN, NOT IN, EXISTS, ALL, ANY + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) + + # Check for subquery using a SELECT statement in parentheses + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) - for i in range(len(filtered_dict['subquery_list'])): - placeholder = f"" - filtered_subquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(filtered_dict['subquery'][i])[0], client_id) + for j in range(len(filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(filtered_dict['subquery_list'][j])[0], client_id) filtered_dict['placeholder_value'].append(placeholder) filtered_dict['filtered_subquery'].append(filtered_subquery) diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 2432bc8..b9c1812 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -1,13 +1,13 @@ import re import sqlparse from sqlparse.sql import IdentifierList, Identifier, Token, TokenList, Parenthesis, Where, Comparison -from sqlparse.tokens import Keyword, DML, Name, Whitespace, Punctuation +from sqlparse.tokens import Keyword, DML from typing import List, Dict, Optional +from .nlp_helpers.query_cleanup import _cleanup_whitespace from .nlp_helpers.cte_handler import handle_cte_query from .nlp_helpers.is_cte import is_cte_query from .nlp_helpers.is_subquery import _contains_subquery -from .nlp_helpers.subquery_handler import _handle_subquery - +import importlib class SQLQueryFilter: def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], case_sensitive: bool = False): @@ -15,12 +15,14 @@ def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], self.schemas = schemas self.case_sensitive = case_sensitive self.filtered_tables = set() - self._is_cte_query = is_cte_query + self._cleanup_whitespace = _cleanup_whitespace + + self._handle_subquery = importlib.import_module('subquery_handler') # Fixing circular import error def apply_client_filter(self, sql_query: str, client_id: int) -> str: self.filtered_tables = set() parsed = sqlparse.parse(sql_query)[0] - is_cte = self._is_cte_query(parsed) + is_cte = is_cte_query(parsed) if is_cte: return handle_cte_query(parsed, self._apply_filter_recursive, client_id) @@ -29,7 +31,8 @@ def apply_client_filter(self, sql_query: str, client_id: int) -> str: return self._cleanup_whitespace(str(result)) def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): - if self._is_cte_query(parsed): + + if is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) else: for token in parsed.tokens: @@ -37,7 +40,7 @@ def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if self._contains_set_operation(parsed) and not _contains_subquery(parsed): return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) elif _contains_subquery(parsed): - return _handle_subquery(parsed, client_id) + return self._handle_subquery(parsed, client_id) else: return self._apply_filter_to_single_query(str(parsed), client_id) @@ -263,20 +266,4 @@ def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_ else: result = main_query - return result + group_by - - def _cleanup_whitespace(self, query: str) -> str: - # Split the query into lines - lines = query.split('\n') - cleaned_lines = [] - for line in lines: - # Remove leading/trailing whitespace from each line - line = line.strip() - # Replace multiple spaces with a single space, but not in quoted strings - line = re.sub(r'\s+(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ' ', line) - # Ensure single space after commas, but not in quoted strings - line = re.sub( - r'\s*,\s*(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ', ', line) - cleaned_lines.append(line) - # Join the lines back together - return '\n'.join(cleaned_lines) \ No newline at end of file + return result + group_by \ No newline at end of file From afafbff8ee967f38f4b6f4908f6ded2987656631 Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Mon, 30 Sep 2024 23:37:29 +0530 Subject: [PATCH 13/15] Added end_keyword subquery detection and handling --- .../core/nlp_helpers/is_subquery.py | 150 +++--- .../core/nlp_helpers/subquery_handler.py | 459 +++++++++++------- src/dataneuron/core/sql_query_filter.py | 32 +- tests/core/test_sql_query_filter.py | 278 ----------- 4 files changed, 378 insertions(+), 541 deletions(-) diff --git a/src/dataneuron/core/nlp_helpers/is_subquery.py b/src/dataneuron/core/nlp_helpers/is_subquery.py index 267da08..1653cd4 100644 --- a/src/dataneuron/core/nlp_helpers/is_subquery.py +++ b/src/dataneuron/core/nlp_helpers/is_subquery.py @@ -5,10 +5,8 @@ def _contains_subquery(parsed): tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - - set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} - where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} where_keyword_pattern = '|'.join(where_keywords) @@ -17,116 +15,134 @@ def _contains_subquery(parsed): where_index = None end_index = None - select_block = [] from_block = [] - join_statement = [] - join_found = False where_block = [] + end_keywords_block = [] results = [] + join_statement = [] + join_found = False - def keyword_index(tokens): - nonlocal select_index, from_index, where_index - i = 0 - while i < len(tokens): - token = tokens[i] - - if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': - select_index = i - k = i + 1 - - while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): - k += 1 - - from_index = k - k = from_index + 1 - - while k < len(tokens): - if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]): - where_index = k - break - k += 1 - - i += 1 - - keyword_index(tokens) - from_end = where_index if where_index is not None else len(tokens) + i = 0 + while i < len(tokens): + token = tokens[i] + + if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': + select_index = i + k = i + 1 + while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + k = from_index + 1 + while k < len(tokens): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ + re.match(r'\(\s*SELECT.*?\bWHERE\b.*?\)', str(tokens[k])): + where_index = k + elif isinstance(tokens[k], Token) and str(tokens[k]) in end_keywords: + end_index = k + break + + k += 1 + i += 1 + + where_end = end_index if end_index else len(tokens) + from_end = min( + index for index in [where_index, end_index] if index is not None) if any([where_index, end_index]) \ + else len(tokens) for j in range(select_index + 1, from_index): # Between SELECT and FROM block select_block.append(_cleanup_whitespace(str(tokens[j]))) select_elements = ' '.join(select_block).strip().split(',') # Split by commas to handle multiple elements in the SELECT block - for element in select_elements: element = element.replace('\n', ' ').strip() # Clean up any extra whitespace if re.search(r'\bCASE\b((\s+WHEN\b.*?\bTHEN\b.*?)+)(\s+ELSE\b.*)?(?=\s+END\b)', element, re.DOTALL): - results.append("CASE Block exists in SELECT block - Checking if it has subquery in any of WHEN, THEN or ELSE blocks") for match in re.findall(r'\bWHEN\b.*?\bTHEN\b.*?\bELSE\b.*?(?=\bWHEN\b|\bELSE\b|\bEND\b)', element, re.DOTALL): #Split them into WHEN, THEN and ELSE blocks: # Check for subquery inside WHEN THEN if re.search(r'\(.*?\bSELECT\b.*?\)', match, re.DOTALL): results.append("Subquery exists inside CASE WHEN THEN ELSE block") elif '(' in element and ')' in element: # Find if any element has parenthesis - results.append("Inline element has parenthesis inside SELECT block - Checking if it has subquery") if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): results.append("Inline Subquery exists inside SELECT block") - for j in range(from_index + 1, from_end): # Between FROM and WHERE (or) end of query + + for j in range(from_index + 1, from_end): if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: from_block.append(tokens[j]) for i, element in enumerate(from_block): - if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: # JOINs + if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: join_found = True + if i == 1: join_statement.append(str(from_block[i - 1])) join_statement.append(str(from_block[i + 1])) elif i > 1: join_statement.append(str(from_block[i + 1])) - - elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): - results.append("Outer parentheses found inside FROM block - Checking if is an inline or contains set operation") - if re.match(r'\(\s*SELECT.*UNION.*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - results.append("(SELECT ... UNION .. SELECT...) - Contains set operation - Not a subquery inside FROM block") - elif re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - results.append("( (SELECT ...) UNION .. (SELECT...) ) - Contains set operation - Subquery found inside FROM block") + elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', str(element), re.IGNORECASE | re.DOTALL): + results.append("Contains set operation - Subquery found inside FROM block") elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - results.append("Inline subquery inside FROM block") - + results.append("Inline subquery inside FROM block") + if join_found: - results.append("JOIN operation found inside FROM - Checking if has subquery") for stmt in join_statement: join_statement_str = _cleanup_whitespace(str(stmt)) - if "(" in join_statement_str and ")" in join_statement_str: - results.append("Parenthesis found inside JOIN - Checking if is a subquery") - - if re.match(r'\(\s*SELECT.*UNION.*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): - results.append("(SELECT ... UNION .. SELECT...) - Not a subquery inside JOIN") - elif re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): - results.append("( (SELECT ...) UNION .. (SELECT...) ) - Subquery inside JOIN") + if re.findall(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("Set operation - Subquery inside JOIN") elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): results.append("Inline subquery inside JOIN") - if where_index: # Between WHERE and end of query (or) End_keywords - for j in range(where_index, len(tokens)): + if where_index: + for j in range(where_index, where_end): where_block.append(_cleanup_whitespace(str(tokens[j]).strip('WHERE '))) - for i in where_block: - for clause in re.split(r'\bAND\b(?![^()]*\))', i): # Splits into multiple statements if AND exists, else selects the single statement - clause = clause.strip() + for i in where_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): + clause = clause.strip() - # Check for the presence of any special keyword like IN, NOT IN, EXISTS, ALL, ANY - if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): - found_keyword = re.search(fr'\b({where_keyword_pattern})\b', clause).group() - results.append(f"Subquery with special keyword found in WHERE block: {found_keyword} \n") + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + found_keyword = re.search(fr'\b({where_keyword_pattern})\b', clause).group() + results.append(f"Subquery with special keyword found in WHERE block: {found_keyword} \n") + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + results.append("Inline subquery found in WHERE block \n") - # Check for subquery using a SELECT statement in parentheses - elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): - results.append("Inline subquery found in WHERE block \n") + if end_index: + for j in range(end_index, len(tokens)): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + end_keywords_block.append(_cleanup_whitespace(str(tokens[j]))) - if len(results) > 1: + endsubquery_block = [] + count = 0 + indices = [] + + for index, token in enumerate(end_keywords_block): + if str(token).upper() in end_keywords: + count += 1 + indices.append(index) + + if count >= 1: # If there is at least one end keyword + for i in range(len(indices)): + start_idx = indices[i] # Start and end indices of each block + if i < len(indices) - 1: + end_idx = indices[i + 1] # Until the next keyword + else: + end_idx = len(end_keywords_block) # Until the end of the block + + # Extract the block between start_idx and end_idx + endsubquery_block = end_keywords_block[start_idx:end_idx] + endsubquery_block_str = ' '.join(endsubquery_block) + + if re.search(r'\((SELECT [\s\S]*?)\)', str(endsubquery_block_str), re.IGNORECASE): + if re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', str(endsubquery_block_str), re.IGNORECASE).group(1): + results.append("Subquery in END keywords") + + if len(results) >= 1: return True else: return False \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py index 81f9a5a..27e6fdf 100644 --- a/src/dataneuron/core/nlp_helpers/subquery_handler.py +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -2,23 +2,223 @@ from sqlparse.sql import Token from sqlparse.tokens import Keyword, DML, Whitespace, Newline import re -from sql_query_filter import SQLQueryFilter from query_cleanup import _cleanup_whitespace -def _handle_subquery(parsed, client_id): - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - select_index = None - from_index = None - where_index = None - end_index = None +class SubqueryHandler: + def __init__(self, query_filter=None, setop_query_filter=None): + self.SQLQueryFilter = query_filter + self.SetOP_QueryFilter = setop_query_filter + self._cleanup_whitespace = _cleanup_whitespace + self.client_id = 1 + + + def SELECT_subquery(self, SELECT_block): + select_elements = ' '.join(SELECT_block).strip().split(',') + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for element in select_elements: + element = element.replace('\n', ' ').strip() + + # Detect CASE WHEN THEN ELSE + case_match = re.search(r'\bCASE\b(.*?\bEND\b)', element, re.DOTALL) + if case_match: + case_block = case_match.group(1) + when_then_else_blocks = re.findall(r'\bWHEN\b(.*?)\bTHEN\b(.*?)(?=\bWHEN\b|\bELSE\b|\bEND\b)', case_block, re.DOTALL) + else_clause = re.search(r'\bELSE\b(.*?)(?=\bEND\b)', case_block, re.DOTALL) + + # Process WHEN-THEN pairs + for when, then in when_then_else_blocks: + if re.search(r'\(.*?\bSELECT\b.*?\)', when, re.DOTALL): # Check if WHEN has a subquery + filtered_dict['subquery_list'].append(when) + if re.search(r'\(.*?\bSELECT\b.*?\)', then, re.DOTALL): # Check if THEN has a subquery + filtered_dict['subquery_list'].append(then) + + # Process ELSE clause if exists + if else_clause and re.search(r'\(.*?\bSELECT\b.*?\)', else_clause.group(1), re.DOTALL): + filtered_dict['subquery_list'].append(else_clause.group(1)) + + # Handle simple subqueries outside CASE block + elif '(' in element and ')' in element: + if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): + filtered_dict['subquery_list'].append(element) + + # Create placeholders and filter subqueries + for i, subquery in enumerate(filtered_dict['subquery_list']): + placeholder = f"" + filtered_subquery = self.SQLQueryFilter( + sqlparse.parse( + re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', subquery).group(1) + )[0], + self.client_id + ) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + + def FROM_subquery(self, FROM_block): + join_statement = [] + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} + + subquery_dict = { + "inline subquery": [], + "join subquery": [], + "set operations": [], + } + + join_found = False + for element in FROM_block: # Separate block to find at least one occurence of JOIN + if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: + join_found = True + break + + for i, element in enumerate(FROM_block): + if join_found: + if i == 1 and isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: + join_statement.append(str(FROM_block[i - 1])) + join_statement.append(str(FROM_block[i + 1])) + elif i > 1 and isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: + join_statement.append(str(FROM_block[i + 1])) + + elif not join_found: + if re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', str(element), re.IGNORECASE | re.DOTALL): + subquery_dict["set operations"].append(f"({str(element)})") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + subquery_dict['inline subquery'].append(str(element)) + + for stmt in join_statement: + join_statement_str = self._cleanup_whitespace(str(stmt)) + if re.findall(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_dict["set operations"].append(f"({join_statement_str})") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_dict['join subquery'].append(join_statement_str) + + + non_setop_filtered_dict = { + 'subquery_list': subquery_dict['inline subquery'] + subquery_dict['join subquery'], + 'filtered_subquery': [], + 'placeholder_value': [] + } + setop_filtered_dict = { + 'subquery_list': subquery_dict['set operations'], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for nsod in range(len(non_setop_filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = self.SQLQueryFilter( sqlparse.parse(re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', non_setop_filtered_dict['subquery_list'][nsod]).group(1) )[0], self.client_id ) + + non_setop_filtered_dict['placeholder_value'].append(placeholder) + non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) + + for sod in range(len(setop_filtered_dict["subquery_list"])): + placeholder = f"" + non_setop_filtered_dict['subquery_list'].append(subquery_dict['set operations'][sod]) + non_setop_filtered_dict['placeholder_value'].append(placeholder) + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(re.search(r'^\((.*)\)(\s+AS\s+\w+)?;?$', subquery_dict['set operations'][sod]).group(1))[0], self.client_id) + non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) + + filtered_dict = { + 'subquery_list': non_setop_filtered_dict['subquery_list'] + setop_filtered_dict['subquery_list'], + 'filtered_subquery': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'], + 'placeholder_value': non_setop_filtered_dict['placeholder_value'] + setop_filtered_dict['placeholder_value'] + } + + return filtered_dict + + + def WHERE_subquery(self, WHERE_block): + where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} + where_keyword_pattern = '|'.join(where_keywords) + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for i in WHERE_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): + clause = clause.strip() + + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) + + for j in range(len(filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = self.SQLQueryFilter( sqlparse.parse( re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', (filtered_dict['subquery_list'][j])).group(1) )[0], self.client_id ) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + def END_subqueries(self, end_keywords_block): + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + + # Dictionary to hold the result + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + endsubquery_block = [] + count = 0 + indices = [] + + for index, token in enumerate(end_keywords_block): + if str(token).upper() in end_keywords: + count += 1 + indices.append(index) + + if count >= 1: # If there is at least one end keyword + for i in range(len(indices)): + start_idx = indices[i] # Start and end indices of each block + if i < len(indices) - 1: + end_idx = indices[i + 1] # Until the next keyword + else: + end_idx = len(end_keywords_block) # Until the end of the block + + # Extract the block between start_idx and end_idx + endsubquery_block = end_keywords_block[start_idx:end_idx] + endsubquery_block_str = ' '.join(endsubquery_block) + + if re.search(r'\((SELECT [\s\S]*?)\)', str(endsubquery_block_str), re.IGNORECASE): + subquery_match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', str(endsubquery_block_str), re.IGNORECASE).group(1) + print(subquery_match) + filtered_dict['subquery_list'].append(subquery_match) + placeholder = f"" + filtered_dict['filtered_subquery'].append(self.SQLQueryFilter(sqlparse.parse(subquery_match)[0], self.client_id)) + filtered_dict['placeholder_value'].append(placeholder) + + return filtered_dict + - select_block = [] - from_block = [] - where_block = [] + def handle_subquery(self, parsed): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + + select_index = None + from_index = None + where_index = None + end_index = None + + select_block = [] + from_block = [] + where_block = [] + end_keywords_block = [] - def keyword_index(tokens): - nonlocal select_index, from_index, where_index i = 0 while i < len(tokens): token = tokens[i] @@ -26,189 +226,72 @@ def keyword_index(tokens): if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': select_index = i k = i + 1 - while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): k += 1 from_index = k k = from_index + 1 - while k < len(tokens): - if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ + re.match(r'\(\s*SELECT.*?\bWHERE\b.*?\)', str(tokens[k])): where_index = k + elif isinstance(tokens[k], Token) and str(tokens[k]) in end_keywords: + end_index = k break - k += 1 - - i += 1 - - keyword_index(tokens) - from_end = where_index if where_index is not None else len(tokens) - - for j in range(select_index + 1, from_index): # Between SELECT and FROM block - select_block.append(_cleanup_whitespace(str(tokens[j]))) - - for j in range(from_index + 1, from_end): - if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: - from_block.append(tokens[j]) - if where_index: - for j in range(where_index, len(tokens)): - where_block.append(_cleanup_whitespace(str(tokens[j]).strip('WHERE '))) - WHERE_dict = WHERE_subquery(parsed, client_id) - - SELECT_dict = SELECT_subquery(select_block, client_id) - FROM_dict = FROM_subquery(from_index, client_id) - - subquery_dict = { - "subqueries": SELECT_dict['subquery_list'] + FROM_dict['subquery_list'] + WHERE_dict['subquery_list'], - "filtered subqueries": SELECT_dict['filtered_subquery'] + FROM_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'], - "placeholder names": SELECT_dict['placeholder_value'] + FROM_dict['placeholder_value'] + WHERE_dict['placeholder_value'] - } - - for i in range(len(subquery_dict['filtered subqueries'])): - mainquery_str = str(parsed).replace(f"({subquery_dict['subqueries'][i]})", subquery_dict['placeholder names'][i]) if i == 0 else mainquery_str.replace(f"({subquery_dict['subqueries'][i]})", subquery_dict['placeholder names'][i]) - if len(subquery_dict['subqueries']) == 1: - filtered_mainquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(mainquery_str)[0], client_id) # Handle the case where there is only one subquery - - elif i == len(subquery_dict['subqueries']) - 1: - filtered_mainquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(mainquery_str)[0], client_id) # Apply filtering to the main query for the last iteration in case of multiple subqueries - - elif i == 0: - filtered_mainquery = mainquery_str # For the first iteration, just keep the mainquery_str as it is - - for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): - filtered_mainquery = filtered_mainquery.replace(placeholder, f"({str(filtered_subquery)})") - - print(filtered_mainquery) - - -def SELECT_subquery(SELECT_block, client_id): - - select_elements = ' '.join(SELECT_block).strip().split(',') # Split by commas to handle multiple elements in the SELECT block - filtered_dict = { - 'subquery_list': [], - 'filtered_subquery': [], - 'placeholder_value': [] - } - - for i, element in enumerate(select_elements): - element = element.replace('\n', ' ').strip() # Clean up any extra whitespace - - if re.search(r'\bCASE\b((\s+WHEN\b.*?\bTHEN\b.*?)+)(\s+ELSE\b.*)?(?=\s+END\b)', element, re.DOTALL): - for match in re.findall(r'\bWHEN\b.*?\bTHEN\b.*?\bELSE\b.*?(?=\bWHEN\b|\bELSE\b|\bEND\b)', element, re.DOTALL): #Split them into WHEN, THEN and ELSE blocks: # Check for subquery inside WHEN THEN - if re.search(r'\(.*?\bSELECT\b.*?\)', match, re.DOTALL): - filtered_dict['subquery_list'].append(match) - - elif '(' in element and ')' in element: # Find if any element has parenthesis - if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): - filtered_dict['subquery_list'].append(element) - - for i, subquery in enumerate(filtered_dict['subquery_list']): # Apply filters to extracted subqueries - placeholder = f"" - filtered_subquery = SQLQueryFilter._apply_filter_to_single_query(subquery, client_id) - filtered_dict['placeholder_value'].append(placeholder) - filtered_dict['filtered_subquery'].append(filtered_subquery) - - return filtered_dict - -def FROM_subquery(FROM_block, client_id): - join_found = False - join_statement = [] - joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} - set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} - subquery_dict = { - "inline subquery": [], - "join subquery": [], - "set operations": [], - } - - for i, element in enumerate(FROM_block): - if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: # JOINs - join_found = True - if i == 1: - join_statement.append(str(FROM_block[i - 1])) - join_statement.append(str(FROM_block[i + 1])) - elif i > 1: - join_statement.append(str(FROM_block[i + 1])) + k += 1 + i += 1 - elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): - if re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - subquery_dict['set operations'].append(str(element)) - elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - subquery_dict['inline subquery'].append(str(element)) - if join_found: - for stmt in join_statement: - join_statement_str = _cleanup_whitespace(str(stmt)) - if "(" in join_statement_str and ")" in join_statement_str: - if re.match(r'\(\s*\(\s*SELECT.*?FROM.*?\)\s*UNION\s*\(SELECT.*?FROM.*?(AS \w+)?\)\s*\)\s+AS\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): - subquery_dict['set operations'].append(join_statement_str) - elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): - subquery_dict['join subquery'].append(join_statement_str) - - non_setop_filtered_dict = { - 'subquery_list': subquery_dict['inline subquery'] + subquery_dict['join subquery'], - 'filtered_subquery': [], - 'placeholder_value': [] - } - setop_filtered_dict = { - 'subquery_list': subquery_dict['set operations'], - 'filtered_subquery': [], - 'placeholder_value': [] - } - - for nsod in range(len(non_setop_filtered_dict['subquery_list'])): - placeholder = f"" - filtered_subquery = SQLQueryFilter._apply_filter_recursive( - sqlparse.parse(non_setop_filtered_dict['subquery_list'][nsod])[0], client_id) - non_setop_filtered_dict['placeholder_value'].append(placeholder) - non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) - - for sod in range(len(setop_filtered_dict['set operations'])): - placeholder = f"" - non_setop_filtered_dict['subquery_list'].append(subquery_dict['set operations'][sod]) - non_setop_filtered_dict['placeholder_value'].append(placeholder) - filtered_subquery = SQLQueryFilter._handle_set_operation( - sqlparse.parse(subquery_dict['set operations'][sod])[0], client_id) - non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) - - filtered_dict = { - 'subquery_list': non_setop_filtered_dict['subquery_list'] + setop_filtered_dict['subquery_list'], - 'filtered_subquery': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'], - 'placeholder_value': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'] - } - - return filtered_dict - - -def WHERE_subquery(WHERE_block, client_id): - - where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} - where_keyword_pattern = '|'.join(where_keywords) - filtered_dict = { - 'subquery_list': [], - 'filtered_subquery': [], - 'placeholder_value': [] - } - - for i in WHERE_block: - for clause in re.split(r'\bAND\b(?![^()]*\))', i): # Splits into multiple statements if AND exists, else selects the single statement - clause = clause.strip() - - # Check for the presence of any special keyword like IN, NOT IN, EXISTS, ALL, ANY - if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): - filtered_dict['subquery_list'].append(clause) - - # Check for subquery using a SELECT statement in parentheses - elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): - filtered_dict['subquery_list'].append(clause) - - for j in range(len(filtered_dict['subquery_list'])): - placeholder = f"" - filtered_subquery = SQLQueryFilter._apply_filter_recursive(sqlparse.parse(filtered_dict['subquery_list'][j])[0], client_id) - filtered_dict['placeholder_value'].append(placeholder) - filtered_dict['filtered_subquery'].append(filtered_subquery) - - return filtered_dict - - - \ No newline at end of file + where_end = end_index if end_index else len(tokens) + from_end = min( + index for index in [where_index, end_index] if index is not None) if any([where_index, end_index]) \ + else len(tokens) + + for j in range(select_index + 1, from_index): + select_block.append(self._cleanup_whitespace(str(tokens[j]))) + + for j in range(from_index + 1, from_end): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + from_block.append(tokens[j]) + + WHERE_dict = {'subquery_list': [], 'filtered_subquery': [], 'placeholder_value': []} # For cases where WHERE_dict is empty and leads to [UnboundLocalError: cannot access local variable 'WHERE_dict' where it is not associated with a value] + if where_index: + for j in range(where_index, where_end): + where_block.append(self._cleanup_whitespace(str(tokens[j]).strip('WHERE '))) + WHERE_dict = self.WHERE_subquery(where_block) + + END_dict = {'subquery_list': [], 'filtered_subquery': [], 'placeholder_value': []} + if end_index: + for j in range(end_index, len(tokens)): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + end_keywords_block.append(self._cleanup_whitespace(str(tokens[j]))) + END_dict = self.END_subqueries(end_keywords_block) + + SELECT_dict = self.SELECT_subquery(select_block) + FROM_dict = self.FROM_subquery(from_block) + subquery_dict = { + "subqueries": SELECT_dict['subquery_list'] + FROM_dict['subquery_list'] + WHERE_dict['subquery_list'] + END_dict['subquery_list'], + "filtered subqueries": SELECT_dict['filtered_subquery'] + FROM_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'] + END_dict['filtered_subquery'], + "placeholder names": SELECT_dict['placeholder_value'] + FROM_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + END_dict['placeholder_value'] + } + + for i in range(len(subquery_dict['filtered subqueries'])): + pattern = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', subquery_dict['subqueries'][i], re.IGNORECASE) + if pattern: + subquery_with_alias = pattern.group(1) + + mainquery_str = str(parsed).replace(subquery_with_alias, subquery_dict["placeholder names"][i]) if i == 0 \ + else mainquery_str.replace(subquery_with_alias, subquery_dict["placeholder names"][i]) + + if len(subquery_dict['subqueries']) == 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + else: + if i == 0: + filtered_mainquery = mainquery_str + elif i == len(subquery_dict['subqueries']) - 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + + for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): + filtered_mainquery = filtered_mainquery.replace(placeholder, filtered_subquery) + + return filtered_mainquery \ No newline at end of file diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index b9c1812..d294264 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -1,13 +1,15 @@ import re import sqlparse -from sqlparse.sql import IdentifierList, Identifier, Token, TokenList, Parenthesis, Where, Comparison +from sqlparse.sql import IdentifierList, Identifier, Token, Parenthesis, Where, Comparison from sqlparse.tokens import Keyword, DML from typing import List, Dict, Optional + from .nlp_helpers.query_cleanup import _cleanup_whitespace from .nlp_helpers.cte_handler import handle_cte_query from .nlp_helpers.is_cte import is_cte_query from .nlp_helpers.is_subquery import _contains_subquery -import importlib +from .nlp_helpers.subquery_handler import SubqueryHandler + class SQLQueryFilter: def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], case_sensitive: bool = False): @@ -16,8 +18,8 @@ def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], self.case_sensitive = case_sensitive self.filtered_tables = set() self._cleanup_whitespace = _cleanup_whitespace + self.subquery_handler = SubqueryHandler(self._apply_filter_recursive, self._handle_set_operation) - self._handle_subquery = importlib.import_module('subquery_handler') # Fixing circular import error def apply_client_filter(self, sql_query: str, client_id: int) -> str: self.filtered_tables = set() @@ -30,6 +32,7 @@ def apply_client_filter(self, sql_query: str, client_id: int) -> str: result = self._apply_filter_recursive(parsed, client_id) return self._cleanup_whitespace(str(result)) + def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if is_cte_query(parsed): @@ -40,10 +43,11 @@ def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if self._contains_set_operation(parsed) and not _contains_subquery(parsed): return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) elif _contains_subquery(parsed): - return self._handle_subquery(parsed, client_id) + return self.subquery_handler.handle_subquery(parsed) else: return self._apply_filter_to_single_query(str(parsed), client_id) - + + def _contains_set_operation(self, parsed): set_operations = ('UNION', 'INTERSECT', 'EXCEPT') @@ -52,6 +56,7 @@ def _contains_set_operation(self, parsed): return True return False + def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None): set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} statements = [] @@ -89,7 +94,8 @@ def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_nam result = f" {set_operation} ".join(filtered_statements) return result - + + def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: parts = sql_query.split(' GROUP BY ') main_query = parts[0] @@ -104,8 +110,8 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: schema = table_info['schema'] matching_table = self._find_matching_table(table_name, schema) - - if matching_table and matching_table not in self.filtered_tables: + + if matching_table: client_id_column = self.client_tables[matching_table] table_reference = table_alias or table_name filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') @@ -123,6 +129,7 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: return result + group_by + def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: possible_names = [ f"{schema}.{table_name}" if schema else table_name, @@ -134,17 +141,21 @@ def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> return name return None + def _quote_identifier(self, identifier: str) -> str: return f'"{identifier}"' + def _strip_quotes(self, identifier: str) -> str: return identifier.strip('"').strip("'").strip('`') + def _case_insensitive_get(self, dict_obj: Dict[str, str], key: str) -> Optional[str]: if self.case_sensitive: return dict_obj.get(key) return next((v for k, v in dict_obj.items() if k.lower() == key.lower()), None) + def _parse_table_identifier(self, identifier): schema = None alias = None @@ -162,6 +173,7 @@ def _parse_table_identifier(self, identifier): return {'name': name, 'schema': schema, 'alias': alias} + def _extract_tables_info(self, parsed, tables_info=None): if tables_info is None: tables_info = [] @@ -172,6 +184,7 @@ def _extract_tables_info(self, parsed, tables_info=None): return tables_info + def _extract_from_clause_tables(self, parsed, tables_info): from_seen = False for token in parsed.tokens: @@ -191,6 +204,7 @@ def _extract_from_clause_tables(self, parsed, tables_info): tables_info.append(self._parse_table_identifier( parsed.token_next(token)[1])) + def _extract_where_clause_tables(self, parsed, tables_info): where_clause = next( (token for token in parsed.tokens if isinstance(token, Where)), None) @@ -210,6 +224,7 @@ def _extract_where_clause_tables(self, parsed, tables_info): self._extract_from_clause_tables( subquery_parsed, tables_info) + def _extract_cte_tables(self, parsed, tables_info): cte_start = next((i for i, token in enumerate( parsed.tokens) if token.ttype is Keyword and token.value.upper() == 'WITH'), None) @@ -228,6 +243,7 @@ def _extract_cte_tables(self, parsed, tables_info): elif token.ttype is DML and token.value.upper() == 'SELECT': break + def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: parts = sql_query.split(' GROUP BY ') main_query = parts[0] diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 4700066..27bdd3a 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -274,283 +274,5 @@ def test_recursive_cte(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) -class TestSQLQueryFilterAdditional(unittest.TestCase): - def setUp(self): - self.client_tables = { - 'main.orders': 'user_id', - 'orders': 'user_id', - 'main.products': 'company_id', - 'products': 'company_id', - 'inventory.items': 'organization_id', - 'items': 'organization_id', - 'customers': 'customer_id', - 'categories': 'company_id' - } - self.filter = SQLQueryFilter(self.client_tables, schemas=['main', 'inventory']) - - def test_multiple_joins(self): - query = 'SELECT o.id, p.name, c.email FROM orders o JOIN products p ON o.product_id = p.id JOIN customers c ON o.user_id = c.id' - expected = 'SELECT o.id, p.name, c.email FROM orders o JOIN products p ON o.product_id = p.id JOIN customers c ON o.user_id = c.id WHERE "o"."user_id" = 1 AND "p"."company_id" = 1 AND "c"."customer_id" = 1' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_case_statement(self): - query = 'SELECT id, CASE WHEN total_amount > 1000 THEN "High" ELSE "Low" END AS order_value FROM orders' - expected = 'SELECT id, CASE WHEN total_amount > 1000 THEN "High" ELSE "Low" END AS order_value FROM orders WHERE "orders"."user_id" = 1' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_subquery_in_select(self): - query = 'SELECT o.id, (SELECT COUNT(*) FROM products p WHERE p.id = o.product_id) AS product_count FROM orders o' - expected = 'SELECT o.id, (SELECT COUNT(*) FROM products p WHERE p.id = o.product_id AND "p"."company_id" = 1) AS product_count FROM orders o WHERE "o"."user_id" = 1' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_having_clause(self): - query = 'SELECT product_id, COUNT(*) FROM orders GROUP BY product_id HAVING COUNT(*) > 5' - expected = 'SELECT product_id, COUNT(*) FROM orders WHERE "orders"."user_id" = 1 GROUP BY product_id HAVING COUNT(*) > 5' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_order_by_with_limit(self): - query = 'SELECT * FROM orders ORDER BY total_amount DESC LIMIT 10' - expected = 'SELECT * FROM orders WHERE "orders"."user_id" = 1 ORDER BY total_amount DESC LIMIT 10' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_union_with_order_by(self): - query = 'SELECT id FROM orders UNION SELECT id FROM products ORDER BY id' - expected = 'SELECT id FROM orders WHERE "orders"."user_id" = 1 UNION SELECT id FROM products WHERE "products"."company_id" = 1 ORDER BY id' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_subquery_with_aggregate(self): - query = 'SELECT * FROM orders WHERE total_amount > (SELECT AVG(total_amount) FROM orders)' - expected = 'SELECT * FROM orders WHERE total_amount > (SELECT AVG(total_amount) FROM orders WHERE "orders"."user_id" = 1) AND "orders"."user_id" = 1' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_complex_join_with_subquery(self): - query = ''' - SELECT o.id, p.name - FROM orders o - JOIN (SELECT id, name FROM products WHERE price > 100) p ON o.product_id = p.id - WHERE o.status = 'completed' - ''' - expected = ''' - SELECT o.id, p.name - FROM orders o - JOIN (SELECT id, name FROM products WHERE price > 100 AND "products"."company_id" = 1) p ON o.product_id = p.id - WHERE o.status = 'completed' AND "o"."user_id" = 1 - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_complex_nested_subqueries(self): - query = ''' - SELECT * - FROM orders o - WHERE o.product_id IN ( - SELECT id - FROM products - WHERE category_id IN ( - SELECT id - FROM categories - WHERE name LIKE 'Electronics%' - ) - ) AND o.user_id IN ( - SELECT user_id - FROM ( - SELECT user_id, AVG(total_amount) as avg_order - FROM orders - GROUP BY user_id - HAVING AVG(total_amount) > 1000 - ) high_value_customers - ) - ''' - expected = ''' - SELECT * - FROM orders o - WHERE o.product_id IN ( - SELECT id - FROM products - WHERE category_id IN ( - SELECT id - FROM categories - WHERE name LIKE 'Electronics%' - AND "categories"."company_id" = 1 - ) - AND "products"."company_id" = 1 - ) AND o.user_id IN ( - SELECT user_id - FROM ( - SELECT user_id, AVG(total_amount) as avg_order - FROM orders - WHERE "orders"."user_id" = 1 - GROUP BY user_id - HAVING AVG(total_amount) > 1000 - ) high_value_customers - ) - AND "o"."user_id" = 1 - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_group_by_having_order_by(self): - query = ''' - SELECT product_id, COUNT(*) as order_count, SUM(total_amount) as total_sales - FROM orders - GROUP BY product_id - HAVING COUNT(*) > 10 - ORDER BY total_sales DESC - LIMIT 5 - ''' - expected = ''' - SELECT product_id, COUNT(*) as order_count, SUM(total_amount) as total_sales - FROM orders - WHERE "orders"."user_id" = 1 - GROUP BY product_id - HAVING COUNT(*) > 10 - ORDER BY total_sales DESC - LIMIT 5 - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_different_data_types_in_where(self): - query = ''' - SELECT * - FROM orders - WHERE order_date > '2023-01-01' - AND total_amount > 100.50 - AND status IN ('completed', 'shipped') - AND is_priority = TRUE - ''' - expected = ''' - SELECT * - FROM orders - WHERE order_date > '2023-01-01' - AND total_amount > 100.50 - AND status IN ('completed', 'shipped') - AND is_priority = TRUE - AND "orders"."user_id" = 1 - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_multi_schema_query(self): - query = ''' - SELECT o.id, p.name, i.quantity - FROM main.orders o - JOIN main.products p ON o.product_id = p.id - JOIN inventory.items i ON p.id = i.product_id - ''' - expected = ''' - SELECT o.id, p.name, i.quantity - FROM main.orders o - JOIN main.products p ON o.product_id = p.id - JOIN inventory.items i ON p.id = i.product_id - WHERE "o"."user_id" = 1 AND "p"."company_id" = 1 AND "i"."organization_id" = 1 - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - -class TestSQLQueryFilterAdditionalCTE(unittest.TestCase): - def setUp(self): - self.client_tables = { - 'main.orders': 'user_id', - 'orders': 'user_id', - 'main.products': 'company_id', - 'products': 'company_id', - 'inventory.items': 'organization_id', - 'items': 'organization_id', - 'customers': 'customer_id', - 'categories': 'company_id' - } - self.filter = SQLQueryFilter( - self.client_tables, schemas=['main', 'inventory']) - - def assertSQLEqual(self, first, second, msg=None): - def normalize_sql(sql): - # Remove all whitespace - sql = re.sub(r'\s+', '', sql) - # Convert to lowercase - return sql.lower() - - normalized_first = normalize_sql(first) - normalized_second = normalize_sql(second) - self.assertEqual(normalized_first, normalized_second, msg) - - def test_cte_with_union(self): - query = ''' - WITH combined_data AS ( - SELECT id, 'order' AS type FROM orders - UNION ALL - SELECT id, 'product' AS type FROM products - ) - SELECT * FROM combined_data - ''' - expected = ''' - WITH combined_data AS ( - SELECT id, 'order' AS type FROM orders WHERE "orders"."user_id" = 1 - UNION ALL - SELECT id, 'product' AS type FROM products WHERE "products"."company_id" = 1 - ) - SELECT * FROM combined_data - ''' - self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_set_operations_with_cte(self): - query = ''' - WITH order_summary AS ( - SELECT user_id, COUNT(*) as order_count - FROM orders - GROUP BY user_id - ) - SELECT * FROM order_summary - UNION - SELECT company_id as user_id, COUNT(*) as product_count - FROM products - GROUP BY company_id - ''' - expected = ''' - WITH order_summary AS ( - SELECT user_id, COUNT(*) as order_count - FROM orders - WHERE "orders"."user_id" = 1 - GROUP BY user_id - ) - SELECT * FROM order_summary - UNION - SELECT company_id as user_id, COUNT(*) as product_count - FROM products - WHERE "products"."company_id" = 1 - GROUP BY company_id - ''' - self.assertSQLEqual(self.filter.apply_client_filter(query, 1), expected) - - def test_recursive_cte_with_join(self): - query = ''' - WITH RECURSIVE category_tree AS ( - SELECT id, name, parent_id, 0 AS level - FROM categories - WHERE parent_id IS NULL - UNION ALL - SELECT c.id, c.name, c.parent_id, ct.level + 1 - FROM categories c - JOIN category_tree ct ON c.parent_id = ct.id - ) - SELECT ct.*, p.name as product_name - FROM category_tree ct - LEFT JOIN products p ON ct.id = p.category_id - ''' - expected = ''' - WITH RECURSIVE category_tree AS ( - SELECT id, name, parent_id, 0 AS level - FROM categories - WHERE parent_id IS NULL AND "categories"."company_id" = 1 - UNION ALL - SELECT c.id, c.name, c.parent_id, ct.level + 1 - FROM categories c - JOIN category_tree ct ON c.parent_id = ct.id - WHERE "c"."company_id" = 1 - ) - SELECT ct.*, p.name as product_name - FROM category_tree ct - LEFT JOIN products p ON ct.id = p.category_id - WHERE "p"."company_id" = 1 - ''' - self.assertSQLEqual(self.filter.apply_client_filter(query, 1), expected) - - if __name__ == '__main__': unittest.main() From 79ce0f542b523a90f393407b55541767ef7e4494 Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Mon, 7 Oct 2024 10:47:06 +0530 Subject: [PATCH 14/15] Conversion to subquery when filtering single tables has been added --- .../core/nlp_helpers/subquery_handler.py | 224 +++++++++++------- src/dataneuron/core/sql_query_filter.py | 2 +- 2 files changed, 137 insertions(+), 89 deletions(-) diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py index 27e6fdf..a57af07 100644 --- a/src/dataneuron/core/nlp_helpers/subquery_handler.py +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -4,14 +4,14 @@ import re from query_cleanup import _cleanup_whitespace - class SubqueryHandler: - def __init__(self, query_filter=None, setop_query_filter=None): + def __init__(self, query_filter=None, setop_query_filter=None, matching_table_finder=None): self.SQLQueryFilter = query_filter self.SetOP_QueryFilter = setop_query_filter + self._find_matching_table = matching_table_finder self._cleanup_whitespace = _cleanup_whitespace self.client_id = 1 - + self.schemas=['main', 'inventory'] def SELECT_subquery(self, SELECT_block): select_elements = ' '.join(SELECT_block).strip().split(',') @@ -63,78 +63,113 @@ def SELECT_subquery(self, SELECT_block): def FROM_subquery(self, FROM_block): - join_statement = [] - joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} - - subquery_dict = { - "inline subquery": [], - "join subquery": [], - "set operations": [], - } - + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN'} join_found = False - for element in FROM_block: # Separate block to find at least one occurence of JOIN - if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: - join_found = True - break - - for i, element in enumerate(FROM_block): - if join_found: - if i == 1 and isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: - join_statement.append(str(FROM_block[i - 1])) - join_statement.append(str(FROM_block[i + 1])) - elif i > 1 and isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: - join_statement.append(str(FROM_block[i + 1])) - - elif not join_found: - if re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): - if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', str(element), re.IGNORECASE | re.DOTALL): - subquery_dict["set operations"].append(f"({str(element)})") - elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): - subquery_dict['inline subquery'].append(str(element)) - - for stmt in join_statement: - join_statement_str = self._cleanup_whitespace(str(stmt)) - if re.findall(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): - if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): - subquery_dict["set operations"].append(f"({join_statement_str})") - elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): - subquery_dict['join subquery'].append(join_statement_str) - - - non_setop_filtered_dict = { - 'subquery_list': subquery_dict['inline subquery'] + subquery_dict['join subquery'], - 'filtered_subquery': [], - 'placeholder_value': [] - } - setop_filtered_dict = { - 'subquery_list': subquery_dict['set operations'], - 'filtered_subquery': [], - 'placeholder_value': [] - } - - for nsod in range(len(non_setop_filtered_dict['subquery_list'])): - placeholder = f"" - filtered_subquery = self.SQLQueryFilter( sqlparse.parse(re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', non_setop_filtered_dict['subquery_list'][nsod]).group(1) )[0], self.client_id ) - - non_setop_filtered_dict['placeholder_value'].append(placeholder) - non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) - - for sod in range(len(setop_filtered_dict["subquery_list"])): - placeholder = f"" - non_setop_filtered_dict['subquery_list'].append(subquery_dict['set operations'][sod]) - non_setop_filtered_dict['placeholder_value'].append(placeholder) - - filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(re.search(r'^\((.*)\)(\s+AS\s+\w+)?;?$', subquery_dict['set operations'][sod]).group(1))[0], self.client_id) - non_setop_filtered_dict['filtered_subquery'].append(filtered_subquery) - - filtered_dict = { - 'subquery_list': non_setop_filtered_dict['subquery_list'] + setop_filtered_dict['subquery_list'], - 'filtered_subquery': non_setop_filtered_dict['filtered_subquery'] + setop_filtered_dict['filtered_subquery'], - 'placeholder_value': non_setop_filtered_dict['placeholder_value'] + setop_filtered_dict['placeholder_value'] - } + join_statements = [] + exit_early = False + + join_dict = { + "matching_table": [], + "filtered_matching_table": [], + "alias": [] + } + + def _handle_joins(): + for i, token in enumerate(FROM_block): + if join_found and isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: + previous_token = FROM_block[i - 1] if i > 0 else None + next_token = FROM_block[i + 1] if i + 1 < len(FROM_block) else None + if previous_token: + join_statements.append(previous_token.value.strip()) + if next_token: + join_statements.append(next_token.value.strip()) + + for statement in join_statements: + join_statement_str = _cleanup_whitespace(statement) + if self._find_matching_table(join_statement_str, self.schemas): + + filtered_table = self.SQLQueryFilter( + sqlparse.parse(f'SELECT * FROM {join_statement_str}')[0], self.client_id) + + join_dict['filtered_matching_table'].append(f'({filtered_table})') + join_dict['alias'].append(f"AS {join_statement_str}") + join_dict['matching_table'].append(join_statement_str) - return filtered_dict + else: + if re.match(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): + match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', join_statement_str) + inner_parentheses = match.group(1) + start, end = match.span() + alias = join_statement_str[end + 1:] # +1 for WHITESPACEEEE + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['matching_table'].append(join_statement_str) + join_dict['alias'].append(alias) + + elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL) + inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', join_statement_str).group(1) + alias = subquery_match.group(1) + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['matching_table'].append(join_statement_str) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['alias'].append(f"AS {alias}") + + def _not_handle_joins(): + nonlocal exit_early + for token in FROM_block: + FROM_block_str = _cleanup_whitespace(str(token)) + if re.match(r'\(\s*([\s\S]*?)\s*\)', FROM_block_str) and re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', FROM_block_str, re.IGNORECASE | re.DOTALL): + match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', FROM_block_str) + inner_parentheses = match.group(1) + start, end = match.span() + alias = FROM_block_str[end + 1:] # +1 for WHITESPACEEEE + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['matching_table'].append(FROM_block_str) + join_dict['alias'].append(alias) + + elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL): + subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL) + inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', FROM_block_str).group(1) + alias = subquery_match.group(1) + + filtered_subquery = self.SQLQueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['matching_table'].append(FROM_block_str) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['alias'].append(f"AS {alias}") + + elif self._find_matching_table(str(token), self.schemas): + exit_early = True + + for token in FROM_block: + if isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: + join_found = True + break + if join_found: + _handle_joins() + else: + _not_handle_joins() + + if exit_early: + return 0 + else: + reconstructed_from_clause = [] + for token in FROM_block: + if isinstance(token, Token) and token.value.strip() in join_dict["matching_table"]: + table_index = join_dict["matching_table"].index(token.value.strip()) + filtered_table = join_dict["filtered_matching_table"][table_index] + added_alias = join_dict["alias"][table_index] + reconstructed_from_clause.append(f"{filtered_table} {added_alias}") + else: + reconstructed_from_clause.append(token.value.strip()) + + reconstructed_query = " ".join(reconstructed_from_clause) + return reconstructed_query def WHERE_subquery(self, WHERE_block): @@ -230,7 +265,8 @@ def handle_subquery(self, parsed): k += 1 from_index = k - k = from_index + 1 + + k += 1 while k < len(tokens): if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ re.match(r'\(\s*SELECT.*?\bWHERE\b.*?\)', str(tokens[k])): @@ -266,32 +302,44 @@ def handle_subquery(self, parsed): if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: end_keywords_block.append(self._cleanup_whitespace(str(tokens[j]))) END_dict = self.END_subqueries(end_keywords_block) - + SELECT_dict = self.SELECT_subquery(select_block) - FROM_dict = self.FROM_subquery(from_block) subquery_dict = { - "subqueries": SELECT_dict['subquery_list'] + FROM_dict['subquery_list'] + WHERE_dict['subquery_list'] + END_dict['subquery_list'], - "filtered subqueries": SELECT_dict['filtered_subquery'] + FROM_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'] + END_dict['filtered_subquery'], - "placeholder names": SELECT_dict['placeholder_value'] + FROM_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + END_dict['placeholder_value'] + "subqueries": SELECT_dict['subquery_list'] + WHERE_dict['subquery_list'] + END_dict['subquery_list'], + "filtered subqueries": SELECT_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'] + END_dict['filtered_subquery'], + "placeholder names": SELECT_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + END_dict['placeholder_value'] } + FROM_filtering = self.FROM_subquery(from_block) + for i in range(len(subquery_dict['filtered subqueries'])): pattern = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', subquery_dict['subqueries'][i], re.IGNORECASE) if pattern: subquery_with_alias = pattern.group(1) - mainquery_str = str(parsed).replace(subquery_with_alias, subquery_dict["placeholder names"][i]) if i == 0 \ else mainquery_str.replace(subquery_with_alias, subquery_dict["placeholder names"][i]) - - if len(subquery_dict['subqueries']) == 1: - filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + + + if FROM_filtering == 0: + if len(subquery_dict['subqueries']) == 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + else: + if i == 0: + filtered_mainquery = mainquery_str + elif i == len(subquery_dict['subqueries']) - 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) else: - if i == 0: - filtered_mainquery = mainquery_str - elif i == len(subquery_dict['subqueries']) - 1: - filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + from_start = mainquery_str.upper().find('FROM') + where_start = mainquery_str.upper().find('WHERE') + + if where_start == -1: # If there's no WHERE clause + next_clause_starts = [mainquery_str.upper().find(clause) for clause in ['GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT'] if mainquery_str.upper().find(clause) != -1] + where_start = min(next_clause_starts) if next_clause_starts else len(mainquery_str) + + part_to_replace = mainquery_str[from_start:where_start].strip() + filtered_mainquery = mainquery_str.replace(part_to_replace, f"FROM {FROM_filtering}") for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): filtered_mainquery = filtered_mainquery.replace(placeholder, filtered_subquery) - + return filtered_mainquery \ No newline at end of file diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index d294264..aebf140 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -18,7 +18,7 @@ def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], self.case_sensitive = case_sensitive self.filtered_tables = set() self._cleanup_whitespace = _cleanup_whitespace - self.subquery_handler = SubqueryHandler(self._apply_filter_recursive, self._handle_set_operation) + self.subquery_handler = SubqueryHandler(self._apply_filter_recursive, self._handle_set_operation, self._find_matching_table) def apply_client_filter(self, sql_query: str, client_id: int) -> str: From 8b7a40ec2bdee6b97d117e4a7ddaf3783aea1ddf Mon Sep 17 00:00:00 2001 From: nikhil3303 Date: Mon, 7 Oct 2024 22:42:23 +0530 Subject: [PATCH 15/15] Remembered to add the same logic to one more similar block --- .../core/nlp_helpers/subquery_handler.py | 126 ++++++++++-------- 1 file changed, 71 insertions(+), 55 deletions(-) diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py index a57af07..fda8d36 100644 --- a/src/dataneuron/core/nlp_helpers/subquery_handler.py +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -1,9 +1,10 @@ import sqlparse -from sqlparse.sql import Token +from sqlparse.sql import Token, Identifier from sqlparse.tokens import Keyword, DML, Whitespace, Newline import re from query_cleanup import _cleanup_whitespace + class SubqueryHandler: def __init__(self, query_filter=None, setop_query_filter=None, matching_table_finder=None): self.SQLQueryFilter = query_filter @@ -13,6 +14,7 @@ def __init__(self, query_filter=None, setop_query_filter=None, matching_table_fi self.client_id = 1 self.schemas=['main', 'inventory'] + def SELECT_subquery(self, SELECT_block): select_elements = ' '.join(SELECT_block).strip().split(',') filtered_dict = { @@ -33,29 +35,25 @@ def SELECT_subquery(self, SELECT_block): # Process WHEN-THEN pairs for when, then in when_then_else_blocks: - if re.search(r'\(.*?\bSELECT\b.*?\)', when, re.DOTALL): # Check if WHEN has a subquery + if re.search(r'\(.*?\bSELECT\b.*?\)', when, re.DOTALL): #WHEN has a subquery filtered_dict['subquery_list'].append(when) - if re.search(r'\(.*?\bSELECT\b.*?\)', then, re.DOTALL): # Check if THEN has a subquery + if re.search(r'\(.*?\bSELECT\b.*?\)', then, re.DOTALL): #THEN has a subquery filtered_dict['subquery_list'].append(then) - # Process ELSE clause if exists - if else_clause and re.search(r'\(.*?\bSELECT\b.*?\)', else_clause.group(1), re.DOTALL): + if else_clause and re.search(r'\(.*?\bSELECT\b.*?\)', else_clause.group(1), re.DOTALL): #ELSE has a subquery filtered_dict['subquery_list'].append(else_clause.group(1)) - # Handle simple subqueries outside CASE block elif '(' in element and ')' in element: if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): filtered_dict['subquery_list'].append(element) - # Create placeholders and filter subqueries for i, subquery in enumerate(filtered_dict['subquery_list']): placeholder = f"" + filtered_subquery = self.SQLQueryFilter( sqlparse.parse( - re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', subquery).group(1) - )[0], - self.client_id - ) + re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', subquery).group(1))[0], self.client_id) + filtered_dict['placeholder_value'].append(placeholder) filtered_dict['filtered_subquery'].append(filtered_subquery) @@ -75,6 +73,7 @@ def FROM_subquery(self, FROM_block): } def _handle_joins(): + alias = None for i, token in enumerate(FROM_block): if join_found and isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: previous_token = FROM_block[i - 1] if i > 0 else None @@ -86,14 +85,26 @@ def _handle_joins(): for statement in join_statements: join_statement_str = _cleanup_whitespace(statement) - if self._find_matching_table(join_statement_str, self.schemas): + + for t in sqlparse.parse(join_statement_str)[0].tokens: + if isinstance(t, Identifier): + alias = t.get_alias() + name = t.get_real_name() + + if alias and self._find_matching_table(str(name), self.schemas) or \ + self._find_matching_table(join_statement_str, self.schemas): filtered_table = self.SQLQueryFilter( sqlparse.parse(f'SELECT * FROM {join_statement_str}')[0], self.client_id) - join_dict['filtered_matching_table'].append(f'({filtered_table})') - join_dict['alias'].append(f"AS {join_statement_str}") - join_dict['matching_table'].append(join_statement_str) + + if alias: + join_dict['alias'].append(f"AS {alias}") + join_dict['matching_table'].append(join_statement_str) + else: + + join_dict['alias'].append(f"AS {join_statement_str}") + join_dict['matching_table'].append(join_statement_str) else: if re.match(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): @@ -111,12 +122,14 @@ def _handle_joins(): elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL): subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL) inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', join_statement_str).group(1) - alias = subquery_match.group(1) + alias = subquery_match.group(1) if subquery_match else '' + start, end = subquery_match.span() filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['matching_table'].append(join_statement_str) join_dict['filtered_matching_table'].append(f'({filtered_subquery})') - join_dict['alias'].append(f"AS {alias}") + join_dict['alias'].append(f"{alias}" if alias else "") def _not_handle_joins(): nonlocal exit_early @@ -136,15 +149,22 @@ def _not_handle_joins(): elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL): subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL) inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', FROM_block_str).group(1) - alias = subquery_match.group(1) + alias_match = re.search(r'\)\s*(AS\s+\w+)?', FROM_block_str, re.IGNORECASE) + alias = alias_match.group(1) if alias_match and alias_match.group(1) else '' + start, end = subquery_match.span() - filtered_subquery = self.SQLQueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['matching_table'].append(FROM_block_str) join_dict['filtered_matching_table'].append(f'({filtered_subquery})') - join_dict['alias'].append(f"AS {alias}") + join_dict['alias'].append(alias) - elif self._find_matching_table(str(token), self.schemas): - exit_early = True + else: + for t in sqlparse.parse(FROM_block_str)[0].tokens: + if isinstance(t, Identifier): + name = t.get_real_name() + if self._find_matching_table(str(name), self.schemas): + exit_early = True for token in FROM_block: if isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: @@ -198,10 +218,10 @@ def WHERE_subquery(self, WHERE_block): return filtered_dict + def END_subqueries(self, end_keywords_block): - end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} - # Dictionary to hold the result + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} filtered_dict = { 'subquery_list': [], 'filtered_subquery': [], @@ -219,13 +239,13 @@ def END_subqueries(self, end_keywords_block): if count >= 1: # If there is at least one end keyword for i in range(len(indices)): - start_idx = indices[i] # Start and end indices of each block + start_idx = indices[i] + if i < len(indices) - 1: end_idx = indices[i + 1] # Until the next keyword else: end_idx = len(end_keywords_block) # Until the end of the block - # Extract the block between start_idx and end_idx endsubquery_block = end_keywords_block[start_idx:end_idx] endsubquery_block_str = ' '.join(endsubquery_block) @@ -265,7 +285,6 @@ def handle_subquery(self, parsed): k += 1 from_index = k - k += 1 while k < len(tokens): if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ @@ -304,42 +323,39 @@ def handle_subquery(self, parsed): END_dict = self.END_subqueries(end_keywords_block) SELECT_dict = self.SELECT_subquery(select_block) + FROM_filtering = self.FROM_subquery(from_block) + subquery_dict = { "subqueries": SELECT_dict['subquery_list'] + WHERE_dict['subquery_list'] + END_dict['subquery_list'], "filtered subqueries": SELECT_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'] + END_dict['filtered_subquery'], "placeholder names": SELECT_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + END_dict['placeholder_value'] } - FROM_filtering = self.FROM_subquery(from_block) - - for i in range(len(subquery_dict['filtered subqueries'])): - pattern = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', subquery_dict['subqueries'][i], re.IGNORECASE) - if pattern: - subquery_with_alias = pattern.group(1) - mainquery_str = str(parsed).replace(subquery_with_alias, subquery_dict["placeholder names"][i]) if i == 0 \ - else mainquery_str.replace(subquery_with_alias, subquery_dict["placeholder names"][i]) - - - if FROM_filtering == 0: - if len(subquery_dict['subqueries']) == 1: - filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) - else: - if i == 0: - filtered_mainquery = mainquery_str - elif i == len(subquery_dict['subqueries']) - 1: + if FROM_filtering == 0: + for i in range(len(subquery_dict['filtered subqueries'])): + pattern = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', subquery_dict['subqueries'][i], re.IGNORECASE) + if pattern: + subquery_with_alias = pattern.group(1) + mainquery_str = str(parsed).replace(subquery_with_alias, subquery_dict["placeholder names"][i]) if i == 0 \ + else mainquery_str.replace(subquery_with_alias, subquery_dict["placeholder names"][i]) + + if len(subquery_dict['subqueries']) == 1: filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) - else: - from_start = mainquery_str.upper().find('FROM') - where_start = mainquery_str.upper().find('WHERE') + else: + if i == 0: + filtered_mainquery = mainquery_str + elif i == len(subquery_dict['subqueries']) - 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + else: + mainquery_str = str(parsed) - if where_start == -1: # If there's no WHERE clause - next_clause_starts = [mainquery_str.upper().find(clause) for clause in ['GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT'] if mainquery_str.upper().find(clause) != -1] - where_start = min(next_clause_starts) if next_clause_starts else len(mainquery_str) - - part_to_replace = mainquery_str[from_start:where_start].strip() - filtered_mainquery = mainquery_str.replace(part_to_replace, f"FROM {FROM_filtering}") + from_start = mainquery_str.upper().find('FROM') + where_start = mainquery_str.upper().find('WHERE') + + part_to_replace = mainquery_str[from_start:where_start].strip() + filtered_mainquery = mainquery_str.replace(part_to_replace, f"FROM {FROM_filtering}") for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): - filtered_mainquery = filtered_mainquery.replace(placeholder, filtered_subquery) - + filtered_mainquery = filtered_mainquery.replace(placeholder, filtered_subquery) + return filtered_mainquery \ No newline at end of file