From a182eeb9c1750355a412e0eedf6dd83ef9d650b3 Mon Sep 17 00:00:00 2001 From: Cade Markegard Date: Mon, 25 Mar 2024 12:50:30 -0700 Subject: [PATCH] [RK-23] Apply OpenGov patches to the OpenGov Calcite fork for version 1.36 --- .../org/apache/calcite/sql/SqlDialect.java | 23 ++++++++++++ .../sql/dialect/RedshiftSqlDialect.java | 4 +++ .../rel/rel2sql/RelToSqlConverterTest.java | 35 ++++++++++++++++--- 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java index 594048556010..d5514a7161a3 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java @@ -447,6 +447,22 @@ public void unparseCall(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { SqlOperator operator = call.getOperator(); switch (call.getKind()) { + case IS_TRUE: + case IS_NOT_FALSE: + case IS_FALSE: + case IS_NOT_TRUE: + if (call.operand(0) instanceof SqlBasicCall + && !isOperandParensCandidate(writer, call.operand(0), operator)) { + // Wrap in parentheses the operand associated to any of the above functions + // if the operand is a SqlBasicCall and will not have parenthesis applied when unparsed + final SqlWriter.Frame frame = writer.startList("(", ")"); + call.operand(0).unparse(writer, operator.getLeftPrec(), operator.getRightPrec()); + writer.endList(frame); + writer.keyword(operator.getName()); + break; + } + operator.unparse(writer, call, leftPrec, rightPrec); + break; case ROW: // Remove the ROW keyword if the dialect does not allow that. if (!getConformance().allowExplicitRowValueConstructor()) { @@ -467,6 +483,13 @@ public void unparseCall(SqlWriter writer, SqlCall call, int leftPrec, } } + private boolean isOperandParensCandidate(SqlWriter writer, SqlBasicCall operand, + SqlOperator operator) { + return operator.getLeftPrec() > operand.getOperator().getLeftPrec() + || operator.getRightPrec() >= operand.getOperator().getRightPrec() + || writer.isAlwaysUseParentheses() && operand.isA(SqlKind.EXPRESSION); + } + public void unparseDateTimeLiteral(SqlWriter writer, SqlAbstractDateTimeLiteral literal, int leftPrec, int rightPrec) { writer.literal(literal.toString()); diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java index 4e94977ce306..0324db47f2e4 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java @@ -114,4 +114,8 @@ public RedshiftSqlDialect(Context context) { @Override public boolean supportsAliasedValues() { return false; } + + @Override public boolean supportsAggregateFunctionFilter() { + return false; + } } diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index b73cc8016cfd..c429c0955fda 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -314,20 +314,20 @@ private static String toSql(RelNode root, SqlDialect dialect, + "where \"product_id\" > 0\n" + "group by \"product_id\""; final String expectedDefault = "SELECT" - + " SUM(\"shelf_width\") FILTER (WHERE \"net_weight\" > 0 IS TRUE)," + + " SUM(\"shelf_width\") FILTER (WHERE (\"net_weight\" > 0) IS TRUE)," + " SUM(\"shelf_width\")\n" + "FROM \"foodmart\".\"product\"\n" + "WHERE \"product_id\" > 0\n" + "GROUP BY \"product_id\""; final String expectedBigQuery = "SELECT" - + " SUM(CASE WHEN net_weight > 0 IS TRUE" + + " SUM(CASE WHEN (net_weight > 0) IS TRUE" + " THEN shelf_width ELSE NULL END), " + "SUM(shelf_width)\n" + "FROM foodmart.product\n" + "WHERE product_id > 0\n" + "GROUP BY product_id"; final String expectedFirebolt = "SELECT" - + " SUM(CASE WHEN \"net_weight\" > 0 IS TRUE" + + " SUM(CASE WHEN (\"net_weight\" > 0) IS TRUE" + " THEN \"shelf_width\" ELSE NULL END), " + "SUM(\"shelf_width\")\n" + "FROM \"foodmart\".\"product\"\n" @@ -366,6 +366,33 @@ private static String toSql(RelNode root, SqlDialect dialect, .withBigQuery().ok(expectedBigQuery); } + @Test void testPivotToSqlWhenFilterIsNotSupported() { + String query = "select * from (\n" + + " select \"brand_name\", \"net_weight\", \"product_id\"\n" + + " from \"foodmart\".\"product\")\n" + + " pivot (sum(\"net_weight\") as w, count(*) as c\n" + + " for (\"brand_name\") in ('a', 'b'))"; + final String expected = "SELECT \"product_id\"," + + " SUM(\"net_weight\") FILTER (WHERE (\"brand_name\" = 'a') IS TRUE) AS \"'a'_W\"," + + " COUNT(*) FILTER (WHERE (\"brand_name\" = 'a') IS TRUE) AS \"'a'_C\"," + + " SUM(\"net_weight\") FILTER (WHERE (\"brand_name\" = 'b') IS TRUE) AS \"'b'_W\"," + + " COUNT(*) FILTER (WHERE (\"brand_name\" = 'b') IS TRUE) AS \"'b'_C\"\n" + + "FROM \"foodmart\".\"product\"\n" + + "GROUP BY \"product_id\""; + // Redshift does not support FILTER + final String expectedRedshiftSql = "SELECT \"product_id\"," + + " SUM(CASE WHEN (\"brand_name\" = 'a') IS TRUE " + + "THEN \"net_weight\" ELSE NULL END) AS \"'a'_W\"," + + " COUNT(CASE WHEN (\"brand_name\" = 'a') IS TRUE THEN 1 ELSE NULL END) AS \"'a'_C\"," + + " SUM(CASE WHEN (\"brand_name\" = 'b') IS TRUE " + + "THEN \"net_weight\" ELSE NULL END) AS \"'b'_W\"," + + " COUNT(CASE WHEN (\"brand_name\" = 'b') IS TRUE THEN 1 ELSE NULL END) AS \"'b'_C\"\n" + + "FROM \"foodmart\".\"product\"\n" + + "GROUP BY \"product_id\""; + sql(query).ok(expected) + .withRedshift().ok(expectedRedshiftSql); + } + @Test void testSimpleSelectQueryFromProductTable() { String query = "select \"product_id\", \"product_class_id\" from \"product\""; final String expected = "SELECT \"product_id\", \"product_class_id\"\n" @@ -6282,7 +6309,7 @@ private void checkLiteral2(String expression, String expected) { + "within group (order by \"net_weight\" desc) filter (where \"net_weight\" > 0)" + "from \"product\" group by \"product_class_id\""; final String expected = "SELECT \"product_class_id\", COLLECT(\"net_weight\") " - + "FILTER (WHERE \"net_weight\" > 0 IS TRUE) " + + "FILTER (WHERE (\"net_weight\" > 0) IS TRUE) " + "WITHIN GROUP (ORDER BY \"net_weight\" DESC)\n" + "FROM \"foodmart\".\"product\"\n" + "GROUP BY \"product_class_id\"";