Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonwang-db committed Jan 13, 2025
1 parent b35f721 commit 46fb145
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2444,10 +2444,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// Extract the function input project list from the SQL function plan and
// inline the SQL function expression.
plan match {
case Project(body :: Nil, Project(aliases, _: OneRowRelation)) =>
val inputs = aliases.map(stripOuterReference)
projectList ++= inputs
SQLScalarFunction(f.function, inputs.map(_.toAttribute), body)
case Project(body :: Nil, Project(aliases, _: LocalRelation)) =>
projectList ++= aliases
SQLScalarFunction(f.function, aliases.map(_.toAttribute), body)
case o =>
throw new AnalysisException(
errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.catalog.SQLFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, SQL_SCALAR_FUNCTION, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.DataType

/**
Expand All @@ -45,7 +44,7 @@ case class SQLFunctionExpression(
* A wrapper node for a SQL scalar function expression.
*/
case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression], child: Expression)
extends UnaryExpression with UnaryLike[Expression] with Unevaluable {
extends UnaryExpression with Unevaluable {
override def dataType: DataType = child.dataType
override def toString: String = s"${function.name}(${inputs.mkString(", ")})"
override def sql: String = s"${function.name}(${inputs.map(_.sql).mkString(", ")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, OuterReference, ScalarSubquery, UpCast}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, ScalarSubquery, UpCast}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LogicalPlan, NamedParametersSupport, OneRowRelation, Project, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LocalRelation, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils}
import org.apache.spark.sql.connector.catalog.CatalogManager
Expand Down Expand Up @@ -1577,39 +1577,27 @@ class SessionCatalog(
*
* SELECT area(a, b) FROM t;
*
* Analyzed SQL function plan:
* SQL function plan:
*
* Project [CAST(width * height AS DOUBLE) AS area]
* +- Project [CAST(outer(a) AS DOUBLE) AS width, CAST(outer(b AS DOUBLE) AS height]
* +- OneRowRelation
*
* Analyzed plan:
*
* Project [area(width, height) AS area]
* +- Project [a, b, CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
* +- Relation [a, b]
* +- Project [CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
* +- LocalRelation [a, b]
*
* Example scalar SQL function with a subquery:
*
* CREATE FUNCTION foo(x INT) RETURNS INT
* RETURN SELECT SUM(a) FROM t WHERE x = a;
* RETURN SELECT SUM(b) FROM t WHERE x = a;
*
* SELECT foo(a) FROM t;
*
* Analyzed SQL function plan:
* SQL function plan:
*
* Project [scalar-subquery AS foo]
* : +- Aggregate [] [sum(a)]
* : +- Aggregate [] [sum(b)]
* : +- Filter [outer(x) = a]
* : +- Relation [a, b]
* +- Project [CAST(outer(a) AS INT) AS x]
* +- OneRowRelation
*
* Analyzed plan:
*
* Project [foo(x) AS foo]
* +- Project [a, b, CAST(a AS INT) AS x]
* +- Relation [a, b]
* +- Project [CAST(a AS INT) AS x]
* +- LocalRelation [a, b]
*/
def makeSQLFunctionPlan(
name: String,
Expand Down Expand Up @@ -1657,22 +1645,16 @@ class SessionCatalog(

paddedInput.zip(param.fields).map {
case (expr, param) =>
// Add outer references to all attributes and outer references in the function input.
// Outer references also need to be wrapped because the function input may already
// contain outer references.
val outer = expr.transform {
case a: Attribute => OuterReference(a)
case o: OuterReference => OuterReference(o)
}
Alias(Cast(outer, param.dataType), param.name)(
Alias(Cast(expr, param.dataType), param.name)(
qualifier = qualifier,
// mark the alias as function input
explicitMetadata = Some(metaForFuncInputAlias))
}
}.getOrElse(Nil)

val body = if (query.isDefined) ScalarSubquery(query.get) else expression.get
Project(Alias(Cast(body, returnType), funcName)() :: Nil, Project(inputs, OneRowRelation()))
Project(Alias(Cast(body, returnType), funcName)() :: Nil,
Project(inputs, LocalRelation(inputs.flatMap(_.references))))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSparkSession

/**
* Test suite for SQL user-defined functions (UDFs).
*/
class SQLFunctionSuite extends QueryTest with SharedSparkSession {
import testImplicits._

protected override def beforeAll(): Unit = {
super.beforeAll()
Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
}

test("SQL scalar function") {
withUserDefinedFunction("area" -> false) {
sql(
"""
|CREATE FUNCTION area(width DOUBLE, height DOUBLE)
|RETURNS DOUBLE
|RETURN width * height
|""".stripMargin)
checkAnswer(sql("SELECT area(1, 2)"), Row(2))
checkAnswer(sql("SELECT area(a, b) FROM t"), Seq(Row(0), Row(2)))
}
}

test("SQL scalar function with subquery in the function body") {
withUserDefinedFunction("foo" -> false) {
withTable("tbl") {
sql("CREATE TABLE tbl AS SELECT * FROM VALUES (1, 2), (1, 3), (2, 3) t(a, b)")
sql(
"""
|CREATE FUNCTION foo(x INT) RETURNS INT
|RETURN SELECT SUM(b) FROM tbl WHERE x = a;
|""".stripMargin)
checkAnswer(sql("SELECT foo(1)"), Row(5))
checkAnswer(sql("SELECT foo(a) FROM t"), Seq(Row(null), Row(5)))
}
}
}
}

0 comments on commit 46fb145

Please sign in to comment.