2017-04-12 105 views
4

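SQLAlchemy sum function with bounds

In sqlalchemy (postgresql DB), I want to create a bounded sum function, for lack of a better term. The goal is to create a running total within a defined range.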

Currently, I have something that works well for computing a running total without bounds. Something like this:

from sqlalchemy.sql import func 

foos = (
    db.query(
        Foo.id,
        Foo.points,
        Foo.timestamp,
        func.sum(Foo.points).over(order_by=Foo.timestamp).label('running_total')
    )
    .filter(...)
    .all()
)

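However, I would like to constrain this running total so that it always stays within a specific range, say [-100, 100]. So we would get something like this (see running_total):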

{'timestamp': 1, 'points': 75, 'running_total': 75} 
{'timestamp': 2, 'points': 50, 'running_total': 100} 
{'timestamp': 3, 'points': -100, 'running_total': 0} 
{'timestamp': 4, 'points': -50, 'running_total': -50} 
{'timestamp': 5, 'points': -75, 'running_total': -100} 

Any ideas?

Answers

2

My initial answer was wrong; please see the edit below:

In raw SQL, you would do this using the greatest & least functions.

Something like this:

LEAST(GREATEST(SUM(myfield) OVER (window_clause), lower_bound), upper_bound) 

SQLAlchemy's expression language lets you write almost the same thing:

import sqlalchemy as sa
import sqlalchemy.ext.declarative as dec
base = dec.declarative_base()

class Foo(base):
    __tablename__ = 'foo'
    id = sa.Column(sa.Integer, primary_key=True)
    points = sa.Column(sa.Integer, nullable=False)
    timestamp = sa.Column('tstamp', sa.Integer)

upper_, lower_ = 100, -100
# plain (unbounded) running total as a window function
win_expr = sa.func.sum(Foo.points).over(order_by=Foo.timestamp)
# clamp the finished running total into [lower_, upper_]
bound_expr = sa.func.least(sa.func.greatest(win_expr, lower_), upper_).label('bounded_running_total')

stmt = sa.select([Foo.id, Foo.points, Foo.timestamp, bound_expr])

str(stmt)
# prints output:
# SELECT foo.id, foo.points, foo.tstamp, least(greatest(sum(foo.points) OVER (ORDER BY foo.tstamp), :greatest_1), :least_1) AS bounded_running_total
# FROM foo


# alternatively using session.query you can also fetch results

from sqlalchemy.orm import sessionmaker
DB = sessionmaker()  # bind an engine here before actually executing queries
db = DB()
foos_stmt = db.query(Foo.id, Foo.points, Foo.timestamp, bound_expr).filter(...)
str(foos_stmt)
# prints output:
# SELECT foo.id, foo.points, foo.tstamp, least(greatest(sum(foo.points) OVER (ORDER BY foo.tstamp), :greatest_1), :least_1) AS bounded_running_total
# FROM foo

foos = foos_stmt.all()

Edit: As user @pozs pointed out in the comments, the above does not produce the expected result.
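To make the difference concrete, here is a small pure-Python sketch (no database involved; the helper name `clamp` and the variable names are made up for illustration) that runs both strategies over the sample data from the question. Clamping the finished running sum only caps the value that is reported, while the desired behavior carries the clamped value forward into the next row.

```python
from itertools import accumulate

points = [75, 50, -100, -50, -75]  # sample data from the question
LOWER, UPPER = -100, 100


def clamp(value, lower=LOWER, upper=UPPER):
    """Restrict value to the closed interval [lower, upper]."""
    return max(lower, min(value, upper))


# What LEAST(GREATEST(SUM(...) OVER (...), lower), upper) computes:
# an ordinary running sum whose result is clamped afterwards.
clamp_after_sum = [clamp(total) for total in accumulate(points)]

# What the question asks for: the clamped total is carried forward.
# (accumulate takes the first value as-is and applies the function from
# the second value on.)
clamp_each_step = list(accumulate(points, lambda total, p: clamp(total + p)))

print(clamp_after_sum)  # [75, 100, 25, -25, -100]
print(clamp_each_step)  # [75, 100, 0, -50, -100]
```

The two lists diverge at the third row: the unbounded sum is 125 before the -100 is applied, whereas the bounded total was already capped at 100.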

@pozs proposed two alternative approaches. Here I have adapted the first one, the recursive query approach, so that it is built with sqlalchemy.

import sqlalchemy as sa 
import sqlalchemy.ext.declarative as dec 
import sqlalchemy.orm as orm 
base = dec.declarative_base() 

class Foo(base): 
    __tablename__ = 'foo' 
    id = sa.Column(sa.Integer, primary_key=True) 
    points = sa.Column(sa.Integer, nullable=False) 
    timestamp = sa.Column('tstamp', sa.Integer) 

upper_, lower_ = 100, -100 
t = sa.select([ 
    Foo.timestamp, 
    Foo.points, 
    Foo.points.label('bounded_running_sum') 
]).order_by(Foo.timestamp).limit(1).cte('t', recursive=True) 

t_aliased = orm.aliased(t, name='ta') 

bounded_sum = t.union_all(
    sa.select([
        Foo.timestamp,
        Foo.points,
        sa.func.greatest(sa.func.least(Foo.points + t_aliased.c.bounded_running_sum, upper_), lower_)
    ])
    # only consider rows after the one produced by the previous step,
    # matching the WHERE clause in @pozs's raw SQL; the CTE column is keyed
    # by the database column name 'tstamp' here (depending on the SQLAlchemy
    # version it may be exposed as 'timestamp' instead)
    .where(Foo.timestamp > t_aliased.c.tstamp)
    .order_by(Foo.timestamp).limit(1)
)
stmt = sa.select([bounded_sum])

# inspect the query:
from sqlalchemy.dialects import postgresql
print(stmt.compile(dialect=postgresql.dialect(),
                   compile_kwargs={'literal_binds': True}))
# prints output:
# WITH RECURSIVE t(tstamp, points, bounded_running_sum) AS
# ((SELECT foo.tstamp, foo.points, foo.points AS bounded_running_sum
# FROM foo ORDER BY foo.tstamp
# LIMIT 1) UNION ALL (SELECT foo.tstamp, foo.points, greatest(least(foo.points + ta.bounded_running_sum, 100), -100) AS greatest_1
# FROM foo, t AS ta
# WHERE foo.tstamp > ta.tstamp ORDER BY foo.tstamp
# LIMIT 1))
# SELECT t.tstamp, t.points, t.bounded_running_sum
# FROM t

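I used this link from the documentation as a reference to construct the above; it also highlights how one can use a session instead to work with recursive CTEs.

This would be the pure SQLAlchemy way to generate the requested result.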

The second approach suggested by @pozs can also be used via sqlalchemy.

That solution should be achievable by following this section from the documentation.
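As a rough sketch of what that could look like: `sa.func` renders any function name it is given, so it can call the custom `bounded_sum` aggregate from @pozs's answer, and `.over()` turns the call into a window function. This assumes the aggregate has already been created in the database; the lightweight `sa.table` stand-in below replaces the `Foo` model only to keep the example self-contained.

```python
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# Lightweight stand-in for the Foo model's table (column names from the answer).
foo = sa.table('foo', sa.column('id'), sa.column('points'), sa.column('tstamp'))

lower_, upper_ = -100, 100

# Call the user-defined aggregate as a window function.
# Assumption: bounded_sum(value, min, max) exists in the target database.
bounded_expr = sa.func.bounded_sum(foo.c.points, lower_, upper_).over(
    order_by=foo.c.tstamp
)

# Render the expression for PostgreSQL to inspect the generated SQL.
sql = str(bounded_expr.compile(
    dialect=postgresql.dialect(),
    compile_kwargs={'literal_binds': True},
))
print(sql)
```

In the ORM query from the question, the same kind of expression, for example `sa.func.bounded_sum(Foo.points, -100, 100).over(order_by=Foo.timestamp).label('running_total')`, could take the place of the `func.sum(...)` column.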

+2

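This does not produce the desired result – pozs

+0

@pozs, you're right. Should I delete the answer, or keep it for historical reasons? –

+0

Or you could correct it. If you know how to do this better/differently/in SQLAlchemy, I'm interested too. – pozs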

5

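Unfortunately, there is no built-in aggregate that can help you achieve your expected output with window function calls.

You can get the expected output by manually calculating the rows one by one with a recursive CTE: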

with recursive t as (
    (select *, points running_total 
    from  foo 
    order by timestamp 
    limit 1) 
    union all 
    (select foo.*, least(greatest(t.running_total + foo.points, -100), 100) 
    from  foo, t 
    where foo.timestamp > t.timestamp 
    order by foo.timestamp 
    limit 1) 
) 
select timestamp, 
     points, 
     running_total 
from t; 
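For readers less familiar with recursive CTEs, the following pure-Python sketch mimics how the query above walks the table. It is only an illustration of the mechanics, not how PostgreSQL executes it internally; the function name `bounded_running_total` and the in-memory `foo` list are made up for the example.

```python
LOWER, UPPER = -100, 100

# (timestamp, points) rows, deliberately unsorted like an unordered table
foo = [(3, -100), (1, 75), (5, -75), (2, 50), (4, -50)]


def bounded_running_total(rows, lower=LOWER, upper=UPPER):
    """Mimic the recursive CTE: each iteration produces exactly one row."""
    result = []
    if not rows:
        return result
    # Non-recursive term: the earliest row; running_total starts as its
    # points value (not clamped, as in the SQL).
    ts, pts = min(rows)
    total = pts
    result.append((ts, pts, total))
    while True:
        # Recursive term: rows strictly later than the previous row ...
        later = [row for row in rows if row[0] > ts]
        if not later:
            break  # no new row was produced, so the recursion ends
        # ... ordered by timestamp, limit 1.
        ts, pts = min(later)
        total = min(max(total + pts, lower), upper)
        result.append((ts, pts, total))
    return result


for row in bounded_running_total(foo):
    print(row)
```

Two details are visible in this form: because of the strict `>` comparison, rows that share a timestamp with the previous row are skipped, and every step searches the table again for the next row.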

Unfortunately, this would be very hard to achieve with SQLAlchemy.

Your other option is to write a custom aggregate for your specific needs, such as:

create function bounded_add(int_state anyelement, next_value anyelement, next_min anyelement, next_max anyelement) 
    returns anyelement 
    immutable 
    language sql 
as $func$ 
    select least(greatest(int_state + next_value, next_min), next_max); 
$func$; 

create aggregate bounded_sum(next_value anyelement, next_min anyelement, next_max anyelement) 
(
    sfunc = bounded_add, 
    stype = anyelement, 
    initcond = '0' 
); 

With that in place, you only need to replace your call to sum with a call to bounded_sum:

select timestamp, 
     points, 
     bounded_sum(points, -100.0, 100.0) over (order by timestamp) running_total 
from foo; 

This latter solution will probably also scale better.
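To see what the aggregate computes, here is a plain-Python model of it (an illustration only; PostgreSQL does not run anything like this code). `bounded_add` plays the role of the state transition function, the starting state mirrors `initcond = '0'`, and the windowed call is modeled by aggregating each prefix of the ordered rows.

```python
from functools import reduce


def bounded_add(int_state, next_value, next_min, next_max):
    """State transition: add the next value, then clamp the new state."""
    return min(max(int_state + next_value, next_min), next_max)


def bounded_sum(values, next_min, next_max, initcond=0):
    """Fold bounded_add over the values, starting from initcond."""
    return reduce(
        lambda state, value: bounded_add(state, value, next_min, next_max),
        values,
        initcond,
    )


points = [75, 50, -100, -50, -75]  # sample data, already ordered by timestamp

# "over (order by timestamp)": each row sees all rows up to and including itself.
running_total = [bounded_sum(points[:i + 1], -100, 100) for i in range(len(points))]
print(running_total)  # [75, 100, 0, -50, -100]
```

One small difference from the recursive CTE: because the state starts at 0 and every value passes through `bounded_add`, the very first row is clamped too, so a first value of 150 would yield 100 here rather than 150.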

http://rextester.com/LKCUK93113