Skip to content

Commit 0794e94

Browse files
authored
Merge pull request #303 from aronbierbaum/add_quantile
Add support for quantile and quantileIf functions
2 parents ac9442a + 0cea38d commit 0794e94

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from sqlalchemy.sql import type_api
55
from sqlalchemy.util import inspect_getfullargspec
66

7+
import clickhouse_sqlalchemy.sql.functions # noqa:F401
8+
79
from ... import types
810

911

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, TypeVar
4+
5+
from sqlalchemy.ext.compiler import compiles
6+
from sqlalchemy.sql import coercions, roles
7+
from sqlalchemy.sql.elements import ColumnElement
8+
from sqlalchemy.sql.functions import GenericFunction
9+
10+
from clickhouse_sqlalchemy import types
11+
12+
if TYPE_CHECKING:
13+
from sqlalchemy.sql._typing import _ColumnExpressionArgument
14+
15+
_T = TypeVar('_T', bound=Any)
16+
17+
18+
class quantile(GenericFunction[_T]):
19+
inherit_cache = True
20+
21+
def __init__(
22+
self, level: float, expr: _ColumnExpressionArgument[Any],
23+
condition: _ColumnExpressionArgument[Any] = None, **kwargs: Any
24+
):
25+
arg: ColumnElement[Any] = coercions.expect(
26+
roles.ExpressionElementRole, expr, apply_propagate_attrs=self
27+
)
28+
29+
args = [arg]
30+
if condition is not None:
31+
condition = coercions.expect(
32+
roles.ExpressionElementRole, condition,
33+
apply_propagate_attrs=self
34+
)
35+
args.append(condition)
36+
37+
self.level = level
38+
39+
if isinstance(arg.type, (types.Decimal, types.Float, types.Int)):
40+
return_type = types.Float64
41+
elif isinstance(arg.type, types.DateTime):
42+
return_type = types.DateTime
43+
elif isinstance(arg.type, types.Date):
44+
return_type = types.Date
45+
else:
46+
return_type = types.Float64
47+
48+
kwargs['type_'] = return_type
49+
kwargs['_parsed_args'] = args
50+
super().__init__(arg, **kwargs)
51+
52+
53+
class quantileIf(quantile[_T]):
54+
inherit_cache = True
55+
56+
57+
@compiles(quantile, 'clickhouse')
58+
@compiles(quantileIf, 'clickhouse')
59+
def compile_quantile(element, compiler, **kwargs):
60+
args_str = compiler.function_argspec(element, **kwargs)
61+
return f'{element.name}({element.level}){args_str}'

docs/features.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,17 @@ Tables created in declarative way have lowercase with words separated by
4646
underscores naming convention. But you can easy set you own via SQLAlchemy
4747
``__tablename__`` attribute.
4848

49-
SQLAlchemy ``func`` proxy for real ClickHouse functions can be also used.
49+
50+
Functions
51+
+++++++++
52+
53+
Many of the ClickHouse functions can be called using the SQLAlchemy ``func``
54+
proxy. A few of aggregate functions require special handling though. There
55+
following functions are supported:
56+
57+
* ``func.quantile(0.5, column1)`` becomes ``quantile(0.5)(column1)``
58+
* ``func.quantileIf(0.5, column1, column2 > 10)`` becomes ``quantileIf(0.5)(column1, column2 > 10)``
59+
5060

5161
Dialect-specific options
5262
++++++++++++++++++++++++

tests/sql/test_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from sqlalchemy import Column, func
2+
3+
from clickhouse_sqlalchemy import types, Table
4+
5+
from tests.testcase import CompilationTestCase
6+
7+
8+
class FunctionTestCase(CompilationTestCase):
9+
table = Table(
10+
't1', CompilationTestCase.metadata(),
11+
Column('x', types.Int32, primary_key=True),
12+
Column('time', types.DateTime)
13+
)
14+
15+
def test_quantile(self):
16+
func0 = func.quantile(0.5, self.table.c.x)
17+
self.assertIsInstance(func0.type, types.Float64)
18+
func1 = func.quantile(0.5, self.table.c.time)
19+
self.assertIsInstance(func1.type, types.DateTime)
20+
self.assertEqual(
21+
self.compile(self.session.query(func0)),
22+
'SELECT quantile(0.5)(t1.x) AS quantile_1 FROM t1'
23+
)
24+
25+
func2 = func.quantileIf(0.5, self.table.c.x, self.table.c.x > 10)
26+
27+
self.assertEqual(
28+
self.compile(
29+
self.session.query(func2)
30+
),
31+
'SELECT quantileIf(0.5)(t1.x, t1.x > %(x_1)s) AS ' +
32+
'"quantileIf_1" FROM t1'
33+
)

0 commit comments

Comments
 (0)