Skip to content

Commit

Permalink
[CALCITE-3697] Implement BITCOUNT scalar function
Browse files Browse the repository at this point in the history
* All libraries now have a BITCOUNT() function that supports integer and binary values
* The MySQL and BigQuery libraries also hav BIT_COUNT() which is an alias for BITCOUNT()
* The MySQL version of BIT_COUNT() also supports decimal values (only looks at the integer portion)
* BITCOUNT() counts all of the bits set in an integer or bytestring
* Return NULL if the argument is NULL
  • Loading branch information
normanj-bitquill authored and mihaibudiu committed Sep 9, 2024
1 parent 9723741 commit 67405c3
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.ATANH;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BITAND_AGG;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BITOR_AGG;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BIT_COUNT_BIG_QUERY;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BIT_COUNT_MYSQL;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BIT_GET;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BIT_LENGTH;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.BOOLAND_AGG;
Expand Down Expand Up @@ -339,6 +341,7 @@
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASIN;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ATAN;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ATAN2;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BITCOUNT;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BIT_AND;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BIT_OR;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BIT_XOR;
Expand Down Expand Up @@ -754,6 +757,11 @@ Builder populate() {
/** Second step of population. The {@code populate} method grew too large,
* and we factored this out. Feel free to decompose further. */
Builder populate2() {
// bitwise
defineMethod(BITCOUNT, BuiltInMethod.BITCOUNT.method, NullPolicy.STRICT);
defineMethod(BIT_COUNT_BIG_QUERY, BuiltInMethod.BITCOUNT.method, NullPolicy.STRICT);
defineMethod(BIT_COUNT_MYSQL, BuiltInMethod.BITCOUNT.method, NullPolicy.STRICT);

// datetime
map.put(DATETIME_PLUS, new DatetimeArithmeticImplementor());
map.put(MINUS_DATE, new DatetimeArithmeticImplementor());
Expand Down
37 changes: 37 additions & 0 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2840,6 +2840,43 @@ public static ByteString bitAnd(ByteString b0, ByteString b1) {
return binaryOperator(b0, b1, (x, y) -> (byte) (x & y));
}

/** Helper function for implementing <code>BITCOUNT</code>. Counts the number
* of bits set in an integer value. */
public static long bitCount(long b) {
return Long.bitCount(b);
}

private static final BigDecimal BITCOUNT_MAX = new BigDecimal(2).pow(64)
.subtract(new BigDecimal(1));
private static final BigDecimal BITCOUNT_MIN = new BigDecimal(2).pow(63).negate();

/** Helper function for implementing <code>BITCOUNT</code>. Counts the number
* of bits set in the integer portion of a decimal value. */
public static long bitCount(BigDecimal b) {
final int comparison = b.compareTo(BITCOUNT_MAX);
if (comparison < 0) {
if (b.compareTo(BITCOUNT_MIN) <= 0) {
return 1;
} else {
return bitCount(b.setScale(0, RoundingMode.DOWN).longValue());
}
} else if (comparison == 0) {
return 64;
} else {
return 63;
}
}

/** Helper function for implementing <code>BITCOUNT</code>. Counts the number
* of bits set in a ByteString value. */
public static long bitCount(ByteString b) {
long bitsSet = 0;
for (int i = 0; i < b.length(); i++) {
bitsSet += Integer.bitCount(0xff & b.byteAt(i));
}
return bitsSet;
}

/** Bitwise function <code>BIT_OR</code> applied to integer values. */
public static long bitOr(long b0, long b1) {
return b0 | b1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import static org.apache.calcite.sql.fun.SqlLibrary.REDSHIFT;
import static org.apache.calcite.sql.fun.SqlLibrary.SNOWFLAKE;
import static org.apache.calcite.sql.fun.SqlLibrary.SPARK;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BITCOUNT;
import static org.apache.calcite.sql.type.OperandTypes.STRING_FIRST_OBJECT_REPEAT;
import static org.apache.calcite.sql.type.OperandTypes.STRING_FIRST_STRING_ARRAY_REPEAT;
import static org.apache.calcite.util.Static.RESOURCE;
Expand Down Expand Up @@ -1123,6 +1124,21 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
public static final SqlSpecialOperator NOT_RLIKE =
new SqlLikeOperator("NOT RLIKE", SqlKind.RLIKE, true, true);

/** Alias for {@link SqlStdOperatorTable#BITCOUNT}. */
@LibraryOperator(libraries = {BIG_QUERY, SPARK})
public static final SqlFunction BIT_COUNT_BIG_QUERY =
BITCOUNT.withName("BIT_COUNT");

@LibraryOperator(libraries = {MYSQL})
public static final SqlFunction BIT_COUNT_MYSQL =
new SqlFunction(
"BIT_COUNT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.BIGINT_NULLABLE,
null,
OperandTypes.NUMERIC.or(OperandTypes.BINARY),
SqlFunctionCategory.NUMERIC);

/** The "CONCAT(arg, ...)" function that concatenates strings.
* For example, "CONCAT('a', 'bc', 'd')" returns "abcd".
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,11 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
public static final SqlAggFunction VARIANCE =
new SqlAvgAggFunction("VARIANCE", SqlKind.VAR_SAMP);

public static final SqlBasicFunction BITCOUNT =
SqlBasicFunction
.create("BITCOUNT", ReturnTypes.BIGINT_NULLABLE,
OperandTypes.INTEGER.or(OperandTypes.BINARY), SqlFunctionCategory.NUMERIC);

/**
* <code>BIT_AND</code> aggregate function.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ public enum BuiltInMethod {
LT(SqlFunctions.class, "lt", boolean.class, boolean.class),
GT(SqlFunctions.class, "gt", boolean.class, boolean.class),
BIT_AND(SqlFunctions.class, "bitAnd", long.class, long.class),
BITCOUNT(SqlFunctions.class, "bitCount", BigDecimal.class),
BIT_OR(SqlFunctions.class, "bitOr", long.class, long.class),
BIT_XOR(SqlFunctions.class, "bitXor", long.class, long.class),
MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION(ModifiableTable.class,
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/resources/sql/functions.iq
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@
!use mysqlfunc
!set outputformat mysql

# BIT Functions

# BIT_COUNT
select bit_count(8);
+--------+
| EXPR$0 |
+--------+
| 1 |
+--------+
(1 row)

!ok

# MATH Functions

# CBRT
Expand Down
4 changes: 4 additions & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,10 @@ In the following:
| * | ATANH(numeric) | Returns the inverse hyperbolic tangent of *numeric*
| f | BITAND_AGG(value) | Equivalent to `BIT_AND(value)`
| f | BITOR_AGG(value) | Equivalent to `BIT_OR(value)`
| * | BITCOUNT(value) | Returns the bitwise COUNT of *value* or NULL if *value* is NULL. *value* must be and integer or binary value.
| b s | BIT_COUNT(integer) | Returns the bitwise COUNT of *integer* or NULL if *integer* is NULL
| m | BIT_COUNT(numeric) | Returns the bitwise COUNT of the integer portion of *numeric* or NULL if *numeric* is NULL
| b m s | BIT_COUNT(binary) | Returns the bitwise COUNT of *binary* or NULL if *binary* is NULL
| s | BIT_LENGTH(binary) | Returns the bit length of *binary*
| s | BIT_LENGTH(string) | Returns the bit length of *string*
| s | BIT_GET(value, position) | Returns the bit (0 or 1) value at the specified *position* of numeric *value*. The positions are numbered from right to left, starting at zero. The *position* argument cannot be negative
Expand Down
65 changes: 65 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15531,6 +15531,71 @@ void checkBitAnd(SqlOperatorFixture f0, FunctionAlias functionAlias) {
f0.forEachLibrary(list(functionAlias.libraries), consumer);
}

@Test void testBitCountFunc() {
checkBitCount(SqlStdOperatorTable.BITCOUNT, null, false);
}

@Test void testBitCountBigQueryFunc() {
checkBitCount(SqlLibraryOperators.BIT_COUNT_BIG_QUERY,
list(SqlLibrary.BIG_QUERY, SqlLibrary.SPARK), false);
}

@Test void testBitCountMySQLFunc() {
checkBitCount(SqlLibraryOperators.BIT_COUNT_MYSQL, list(SqlLibrary.MYSQL), true);
}

void checkBitCount(SqlFunction function, @Nullable List<SqlLibrary> libraries,
boolean testDecimal) {
final SqlOperatorFixture f0 = fixture();
f0.setFor(function, VmName.EXPAND);
final String functionName = function.getName();
final Consumer<SqlOperatorFixture> consumer = f -> {
f.checkFails(functionName + "(^*^)", "Unknown identifier '\\*'", false);
f.checkType(functionName + "(1)", "BIGINT NOT NULL");
f.checkType(functionName + "(CAST(2 AS TINYINT))", "BIGINT NOT NULL");
f.checkType(functionName + "(CAST(2 AS SMALLINT))", "BIGINT NOT NULL");
f.checkFails(
"^" + functionName + "()^",
"Invalid number of arguments to function '" + functionName
+ "'. Was expecting 1 arguments",
false);
f.checkFails(
"^" + functionName + "(1, 2)^",
"Invalid number of arguments to function '" + functionName
+ "'. Was expecting 1 arguments",
false);
f.checkScalar(functionName + "(8)", "1", "BIGINT NOT NULL");
f.checkScalar(functionName + "(CAST(x'ad' AS BINARY(1)))", "5", "BIGINT NOT NULL");
f.checkScalar(functionName + "(CAST(x'ad' AS VARBINARY(1)))", "5", "BIGINT NOT NULL");
f.checkScalar(functionName + "(-1)", "64", "BIGINT NOT NULL");
f.checkNull(functionName + "(cast(NULL as TINYINT))");
f.checkNull(functionName + "(cast(NULL as BINARY))");
f.checkNull(functionName + "(NULL)");
if (testDecimal) {
f.checkType(functionName + "(CAST(2 AS DOUBLE))", "BIGINT NOT NULL");
// Verify that only bits in the integer portion of a decimal value are counted
f.checkScalar(functionName + "(5.23)", "2", "BIGINT NOT NULL");
f.checkScalar(functionName + "(CAST('-9223372036854775808' AS DECIMAL(19, 0)))", "1",
"BIGINT NOT NULL");
f.checkScalar(functionName + "(CAST('-9223372036854775809' AS DECIMAL(19, 0)))", "1",
"BIGINT NOT NULL");
} else {
f.checkType(functionName + "(CAST(x'ad' AS BINARY(1)))", "BIGINT NOT NULL");
f.checkFails("^" + functionName + "(1.2)^",
"Cannot apply '" + functionName + "' to arguments of type '" + functionName
+ "\\(<DECIMAL\\(2, 1\\)>\\)'\\. Supported form\\(s\\): '" + functionName
+ "\\(<INTEGER>\\)'\n"
+ "'" + functionName + "\\(<BINARY>\\)'",
false);
}
};
if (libraries == null) {
consumer.accept(f0);
} else {
f0.forEachLibrary(libraries, consumer);
}
}

@Test void testBitOrAggFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlLibraryOperators.BITOR_AGG, VmName.EXPAND);
Expand Down

0 comments on commit 67405c3

Please sign in to comment.