Skip to content

Commit

Permalink
feat: Sort raw sql fields and conditions [Experiment] (#2988)
Browse files Browse the repository at this point in the history
* feat: Sort raw sql fields and conditions [Experiment]

Context:
To help improve the frequency of raw sql cache hits, this PR is
responsible for sorting column fields and expression conditions in raw
SQL queries. Note: this code should not interfere with the current
operations of Snuba query caching. This is an experiment to investigate
the consistency of the natural ordering of columns and conditions.

Test Plan:
* Provide functionality for sorting, and select a certain percentage
  (defined in runtime config) of raw queries to be part of the
  experiment.
* Exclude queries from an explicit set of referrers to avoid
  some computational overhead.
* Record metrics of cache hits of sorted and unsorted queries.
* Determine whether or not sorting should be implemented for all queries.

* Revise sorting implementation to extend from formatter class

* Move comparison function to parent class and add more tests

* clean up
  • Loading branch information
enochtangg authored Aug 15, 2022
1 parent 4656668 commit f211aa3
Show file tree
Hide file tree
Showing 8 changed files with 572 additions and 13 deletions.
10 changes: 8 additions & 2 deletions snuba/clickhouse/formatter/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,17 @@ def visit_function_call(self, exp: FunctionCall) -> str:
return self._alias(f"({self.__visit_params(exp.parameters)})", exp.alias)

elif exp.function_name == BooleanFunctions.AND:
formatted = (c.accept(self) for c in get_first_level_and_conditions(exp))
current_level = get_first_level_and_conditions(exp)
if self._parsing_context.sort_fields and isinstance(current_level, list):
current_level.sort()
formatted = (c.accept(self) for c in current_level)
return " AND ".join(formatted)

elif exp.function_name == BooleanFunctions.OR:
formatted = (c.accept(self) for c in get_first_level_or_conditions(exp))
current_level = get_first_level_or_conditions(exp)
if self._parsing_context.sort_fields and isinstance(current_level, list):
current_level.sort()
formatted = (c.accept(self) for c in current_level)
return f"({' OR '.join(formatted)})"

ret = f"{escape_identifier(exp.function_name)}({self.__visit_params(exp.parameters)})"
Expand Down
27 changes: 19 additions & 8 deletions snuba/clickhouse/formatter/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
FormattableQuery = Union[Query, CompositeQuery[Table]]


def format_query(query: FormattableQuery) -> FormattedQuery:
def format_query(query: FormattableQuery, sort_fields: bool = False) -> FormattedQuery:
"""
Formats a Clickhouse Query from the AST representation into an
intermediate structure that can either be serialized into a string
Expand All @@ -37,11 +37,17 @@ def format_query(query: FormattableQuery) -> FormattedQuery:
This is the entry point for any type of query, whether simple or
composite.
"""
return FormattedQuery(_format_query_content(query, ClickhouseExpressionFormatter))
return FormattedQuery(
_format_query_content(query, ClickhouseExpressionFormatter, sort_fields)
)


def format_query_anonymized(query: FormattableQuery) -> FormattedQuery:
return FormattedQuery(_format_query_content(query, ExpressionFormatterAnonymized))
def format_query_anonymized(
query: FormattableQuery, sort_fields: bool = False
) -> FormattedQuery:
return FormattedQuery(
_format_query_content(query, ExpressionFormatterAnonymized, sort_fields)
)


class DataSourceFormatter(DataSourceVisitor[FormattedNode, Table]):
Expand Down Expand Up @@ -87,6 +93,7 @@ def _visit_composite_query(
def _format_query_content(
query: FormattableQuery,
expression_formatter_type: Type[ExpressionFormatterBase],
sort_fields: bool = False,
) -> Sequence[FormattedNode]:
"""
Produces the content of the formatted query.
Expand All @@ -95,7 +102,7 @@ def _format_query_content(
Should we have more differences going on we should break this
method into smaller ones.
"""
parsing_context = ParsingContext()
parsing_context = ParsingContext(sort_fields=sort_fields)
formatter = expression_formatter_type(parsing_context)

return [
Expand Down Expand Up @@ -126,9 +133,13 @@ def _format_query_content(
def _format_select(
query: AbstractQuery, formatter: ExpressionVisitor[str]
) -> StringNode:
selected_cols = [
e.expression.accept(formatter) for e in query.get_selected_columns()
]
selected_columns = query.get_selected_columns()
if (
isinstance(formatter, ExpressionFormatterBase)
and formatter._parsing_context.sort_fields
):
selected_columns = sorted(selected_columns)
selected_cols = [e.expression.accept(formatter) for e in selected_columns]
return StringNode(f"SELECT {', '.join(selected_cols)}")


Expand Down
10 changes: 10 additions & 0 deletions snuba/query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ class SelectedExpression:
name: Optional[str]
expression: Expression

def __gt__(self, other: SelectedExpression) -> bool:
    """
    Return whether this selected expression sorts after ``other``.

    The decision is delegated to the ``>`` operator of the wrapped
    expressions. Two columns are therefore still ordered by their
    comparison operator, and every other pairing uses the ordering
    defined on the expression classes instead of a constant answer.
    """
    # A constant True here (together with a constant True in __lt__)
    # makes both `a > b` and `a < b` hold, so the result of sorting
    # would depend on the order the fields happened to arrive in.
    return self.expression > other.expression

def __lt__(self, other: SelectedExpression) -> bool:
    """
    Return whether this selected expression sorts before ``other``.

    Mirror image of ``__gt__``: the wrapped expressions are compared
    with ``<`` so that both operators agree on the same ordering.
    """
    return self.expression < other.expression


TExp = TypeVar("TExp", bound=Expression)

Expand Down
146 changes: 145 additions & 1 deletion snuba/query/expressions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from datetime import date, datetime
from typing import Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union

from snuba import settings

Expand Down Expand Up @@ -85,6 +86,34 @@ def functional_eq(self, other: Expression) -> bool:
"""
raise NotImplementedError

# Fallback comparison operators: used (for example when a FunctionCall compares its parameters) whenever a subclass does not define its own overrides
def __gt__(self, other: Any) -> bool:
    """
    Base-class ``>``: defer to the class-based ranking implemented by
    ``compare_expressions``.
    """
    greater_than = operator.gt
    return self.compare_expressions(other, greater_than)

def __lt__(self, other: Any) -> bool:
    """
    Base-class ``<``: defer to the class-based ranking implemented by
    ``compare_expressions``.
    """
    less_than = operator.lt
    return self.compare_expressions(other, less_than)

# Sort by Class first and then alphabetically
# Class order: Literals < Columns < FunctionCalls (nested) < other
def compare_expressions(
    self, other: Expression, op: Callable[[Any, Any], Any]
) -> bool:
    """
    Rank two expressions purely by their class.

    The classes Literal, Column and FunctionCall are tried in that
    order, and for each class ``self`` is tested before ``other``. The
    first match decides the result: a match on ``self`` means ``self``
    sorts first (True for ``operator.lt``, False for ``operator.gt``),
    a match on ``other`` yields the negated answer. When neither
    operand belongs to one of these classes, True is returned for
    either operator.

    ``op`` is expected to be ``operator.lt`` or ``operator.gt``; any
    other callable raises KeyError as soon as a class matches.
    """
    answer_when_self_first = {operator.lt: True, operator.gt: False}

    for ranked_class in (Literal, Column, FunctionCall):
        if isinstance(self, ranked_class):
            return answer_when_self_first[op]
        if isinstance(other, ranked_class):
            return not answer_when_self_first[op]

    # Neither side is a Literal, Column or FunctionCall.
    return True


class ExpressionVisitor(ABC, Generic[TVisited]):
"""
Expand Down Expand Up @@ -281,6 +310,23 @@ def functional_eq(self, other: Expression) -> bool:
return False
return self.value == other.value

# Only sort string literals, otherwise default
def __gt__(self, other: Any) -> bool:
    """
    ``>`` for literals.

    Against a non-Literal the class-based ranking of
    ``compare_expressions`` is used. Two literals are compared by value
    only when both values are strings; for any other pair of literals
    the answer is always True.
    """
    if not isinstance(other, Literal):
        return self.compare_expressions(other, operator.gt)

    both_strings = isinstance(self.value, str) and isinstance(other.value, str)
    return self.value > other.value if both_strings else True

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for literals.

    Against a non-Literal the class-based ranking of
    ``compare_expressions`` is used. Two literals are compared by value
    only when both values are strings; for any other pair of literals
    the answer is always True.
    """
    if not isinstance(other, Literal):
        return self.compare_expressions(other, operator.lt)

    both_strings = isinstance(self.value, str) and isinstance(other.value, str)
    return self.value < other.value if both_strings else True


@dataclass(frozen=True, repr=_AUTO_REPR)
class Column(Expression):
Expand Down Expand Up @@ -308,6 +354,17 @@ def functional_eq(self, other: Expression) -> bool:
and self.column_name == other.column_name
)

# Override comparison operators for sorting SQL fields and conditions
def __gt__(self, other: Any) -> bool:
    """
    ``>`` for columns: two columns are ordered by ``column_name``; any
    other operand falls back to the class-based ranking.
    """
    if not isinstance(other, Column):
        return self.compare_expressions(other, operator.gt)
    mine, theirs = self.column_name, other.column_name
    return mine > theirs

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for columns: two columns are ordered by ``column_name``; any
    other operand falls back to the class-based ranking.
    """
    if not isinstance(other, Column):
        return self.compare_expressions(other, operator.lt)
    mine, theirs = self.column_name, other.column_name
    return mine < theirs


@dataclass(frozen=True, repr=_AUTO_REPR)
class SubscriptableReference(Expression):
Expand Down Expand Up @@ -354,6 +411,17 @@ def functional_eq(self, other: Expression) -> bool:
other.key
)

# Sort class by the left side key (literals)
def __gt__(self, other: Any) -> bool:
    """
    ``>`` for subscriptable references: two references are ordered by
    comparing their ``key`` attributes; any other operand falls back to
    the class-based ranking.
    """
    if not isinstance(other, SubscriptableReference):
        return self.compare_expressions(other, operator.gt)
    return self.key > other.key

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for subscriptable references: two references are ordered by
    comparing their ``key`` attributes; any other operand falls back to
    the class-based ranking.
    """
    if not isinstance(other, SubscriptableReference):
        return self.compare_expressions(other, operator.lt)
    return self.key < other.key


@dataclass(frozen=True, repr=_AUTO_REPR)
class FunctionCall(Expression):
Expand Down Expand Up @@ -417,6 +485,41 @@ def functional_eq(self, other: Expression) -> bool:
return False
return True

def operator_helper(
    self,
    other: FunctionCall,
    main_op: Callable[[Any, Any], Any],
    other_op: Callable[[Any, Any], Any],
) -> bool:
    """
    Shared implementation of ``<`` and ``>`` for function calls.

    ``main_op`` is the operator being evaluated and ``other_op`` its
    opposite (``operator.lt`` / ``operator.gt``). The comparison goes
    through these steps, stopping at the first one that decides:

    1. ``other`` is not a FunctionCall: use the class-based ranking of
       ``compare_expressions``.
    2. The function names differ: the names decide.
    3. ``self`` has no parameters: ``self`` sorts first.
    4. Only ``other`` has no parameters: ``other`` sorts first.
    5. Otherwise the first parameter of each call is compared with
       ``main_op``; parameters after the first are not looked at.
    """
    if not isinstance(other, FunctionCall):
        return self.compare_expressions(other, main_op)

    # Same class on both sides: try the function names first.
    if main_op(self.function_name, other.function_name):
        return True
    if other_op(self.function_name, other.function_name):
        return False

    # Names are tied, so look at the parameter lists.
    if not self.parameters:
        return {operator.lt: True, operator.gt: False}[main_op]
    if not other.parameters:
        return not {operator.lt: True, operator.gt: False}[main_op]

    # Relies on the operator overrides of the parameter's own class, or
    # on the base-class fallback when it has none.
    first_mine, first_theirs = self.parameters[0], other.parameters[0]
    return bool(main_op(first_mine, first_theirs))

def __gt__(self, other: Any) -> bool:
    """
    ``>`` for function calls: run ``operator_helper`` with
    ``operator.gt`` as the operator under evaluation and ``operator.lt``
    as its opposite.
    """
    return self.operator_helper(
        other,
        main_op=operator.gt,
        other_op=operator.lt,
    )

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for function calls: run ``operator_helper`` with
    ``operator.lt`` as the operator under evaluation and ``operator.gt``
    as its opposite.
    """
    return self.operator_helper(
        other,
        main_op=operator.lt,
        other_op=operator.gt,
    )


@dataclass(frozen=True, repr=_AUTO_REPR)
class CurriedFunctionCall(Expression):
Expand Down Expand Up @@ -476,6 +579,17 @@ def functional_eq(self, other: Expression) -> bool:
return False
return True

# Sort by left side of expression (FunctionCall)
def __gt__(self, other: Any) -> bool:
    """
    ``>`` for curried function calls: two curried calls are ordered by
    comparing their ``internal_function``; any other operand falls back
    to the class-based ranking.
    """
    if not isinstance(other, CurriedFunctionCall):
        return self.compare_expressions(other, operator.gt)
    return self.internal_function > other.internal_function

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for curried function calls: two curried calls are ordered by
    comparing their ``internal_function``; any other operand falls back
    to the class-based ranking.
    """
    if not isinstance(other, CurriedFunctionCall):
        return self.compare_expressions(other, operator.lt)
    return self.internal_function < other.internal_function


@dataclass(frozen=True, repr=_AUTO_REPR)
class Argument(Expression):
Expand All @@ -500,6 +614,16 @@ def functional_eq(self, other: Expression) -> bool:
return False
return self.name == other.name

def __gt__(self, other: Any) -> bool:
    """
    ``>`` for arguments: two arguments are ordered by ``name``; any
    other operand falls back to the class-based ranking.
    """
    if not isinstance(other, Argument):
        return self.compare_expressions(other, operator.gt)
    return self.name > other.name

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for arguments: two arguments are ordered by ``name``; any
    other operand falls back to the class-based ranking.
    """
    if not isinstance(other, Argument):
        return self.compare_expressions(other, operator.lt)
    return self.name < other.name


@dataclass(frozen=True, repr=_AUTO_REPR)
class Lambda(Expression):
Expand Down Expand Up @@ -540,3 +664,23 @@ def functional_eq(self, other: Expression) -> bool:
if not self.transformation.functional_eq(other.transformation):
return False
return True

def __gt__(self, other: Any) -> bool:
    """
    ``>`` for lambdas.

    Against a non-Lambda the class-based ranking is used. Between two
    lambdas: a lambda without parameters is never greater; a lambda
    with parameters is greater than one without; when both have
    parameters only the first parameter of each is compared.
    """
    if not isinstance(other, Lambda):
        return self.compare_expressions(other, operator.gt)
    if not self.parameters:
        return False
    if not other.parameters:
        return True
    return self.parameters[0] > other.parameters[0]

def __lt__(self, other: Any) -> bool:
    """
    ``<`` for lambdas.

    Against a non-Lambda the class-based ranking is used. Between two
    lambdas: a lambda without parameters is always less; a lambda with
    parameters is not less than one without; when both have parameters
    only the first parameter of each is compared.
    """
    if not isinstance(other, Lambda):
        return self.compare_expressions(other, operator.lt)
    if not self.parameters:
        return True
    if not other.parameters:
        return False
    return self.parameters[0] < other.parameters[0]
3 changes: 2 additions & 1 deletion snuba/query/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ class ParsingContext:
alias cache).
"""

def __init__(self) -> None:
def __init__(self, sort_fields: bool = False) -> None:
self.__alias_cache: List[str] = []
self.sort_fields = sort_fields

def add_alias(self, alias: str) -> None:
self.__alias_cache.append(alias)
Expand Down
Loading

0 comments on commit f211aa3

Please sign in to comment.