2014-11-08 133 views
1

我在Python中使用了很多argmin和argmax。numpy.argmax/argmin更快的替代方法很慢

不幸的是,該功能非常慢。

我做了一些摸索,我可以找到最好的是在這裏:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array): 
    array = list(array) 
    return array.index(max(array)) 

不幸的是,這個解決方案仍然只有一半快np.max,我想我應該能夠找到和np.max一樣快的東西。

x = np.random.randn(10) 
%timeit np.argmax(x) 
10000 loops, best of 3: 21.8 us per loop 

%timeit fastest_argmax(x)  
10000 loops, best of 3: 20.8 us per loop 

作爲一個說明,我申請這一個數據幀大熊貓GROUPBY

E.G.

%timeit grp2[ 'ODDS' ].agg([ fastest_argmax ]) 
100 loops, best of 3: 8.8 ms per loop 

%timeit grp2[ 'ODDS' ].agg([ np.argmax ]) 
100 loops, best of 3: 11.6 ms per loop 

當數據是這樣的:

grp2[ 'ODDS' ].head() 
Out[60]: 
EVENT_ID SELECTION_ID   
104601100 4367029  682508 3.05 
         682509 3.15 
         682510 3.25 
         682511 3.35 
      5319660  682512 2.04 
         682513 2.08 
         682514 2.10 
         682515 2.12 
         682516 2.14 
      5510310  682520 4.10 
         682521 4.40 
         682522 4.50 
         682523 4.80 
         682524 5.30 
      5559264  682526 5.00 
         682527 5.30 
         682528 5.40 
         682529 5.50 
         682530 5.60 
      5585869  682533 1.96 
         682534 1.97 
         682535 1.98 
         682536 2.02 
         682537 2.04 
      6064546  682540 3.00 
         682541 2.74 
         682542 2.76 
         682543 2.96 
         682544 3.05 
104601200 4916112  682548 2.64 
         682549 2.68 
         682550 2.70 
         682551 2.72 
         682552 2.74 
      5315859  682557 2.90 
         682558 2.92 
         682559 3.05 
         682560 3.10 
         682561 3.15 
      5356995  682564 2.42 
         682565 2.44 
         682566 2.48 
         682567 2.50 
         682568 2.52 
      5465225  682573 1.85 
         682574 1.89 
         682575 1.91 
         682576 1.93 
         682577 1.94 
      5773661  682588 5.00 
         682589 4.40 
         682590 4.90 
         682591 5.10 
      6013187  682592 5.00 
         682593 4.20 
         682594 4.30 
         682595 4.40 
         682596 4.60 
104606300 2489827  683438 4.00 
         683439 3.90 
         683440 3.95 
         683441 4.30 
         683442 4.40 
      3602724  683446 2.16 
         683447 2.32 
Name: ODDS, Length: 65, dtype: float64 
+1

我建議查看他們如何做它的numpy源代碼。這裏有一個鏈接:https://github.com/numpy/numpy/blob/1e22553c37e9112bf2426c1b060275419f906d8d/numpy/core/src/multiarray/calculation.c嘗試模仿他們的方法。 – twasbrillig 2014-11-08 11:32:48

+1

「低效率」甚至沒有描述你所謂的「最快的最大值」。 – 2014-11-08 11:49:00

+0

在我的筆記本電腦''timeit np.argmax(x)''x.shape == 10'上只花了1美元,你的numpy版本是什麼? – HYRY 2014-11-08 12:29:05

回答

3

你可以張貼一些代碼?這是在我的電腦結果:

x = np.random.rand(10000) 
%timeit np.max(x) 
%timeit np.argmax(x) 

輸出:

100000 loops, best of 3: 7.43 µs per loop 
100000 loops, best of 3: 11.5 µs per loop 
+0

你的終端是什麼?庫存控制檯字體無法打印MU字母AFAIK。 – 2014-11-08 15:29:04

+1

@ivan_pozdeev - 大多數現代控制檯模擬器至少有一些(通常是完整的)unicode支持(甚至是xterm)。 – 2014-11-08 17:24:12

7

事實證明,np.argmax極快,但與本地numpy的陣列。隨着國外的數據,幾乎所有的時間都花在轉換:

In [194]: print platform.architecture() 
('64bit', 'WindowsPE') 

In [5]: x = np.random.rand(10000) 
In [57]: l=list(x) 
In [123]: timeit numpy.argmax(x) 
100000 loops, best of 3: 6.55 us per loop 
In [122]: timeit numpy.argmax(l) 
1000 loops, best of 3: 729 us per loop 
In [134]: timeit numpy.array(l) 
1000 loops, best of 3: 716 us per loop 

我叫你的函數「低效」,因爲它首先通過這2次轉換一切列表,然後遍歷(實際上,3次迭代+列表構造) 。

我會這麼建議是這樣的,只有迭代一次:

def imax(seq): 
    it=iter(seq) 
    im=0 
    try: m=it.next() 
    except StopIteration: raise ValueError("the sequence is empty") 
    for i,e in enumerate(it,start=1): 
     if e>m: 
      m=e 
      im=i 
    return im 

但是,你的版本原來是更快,因爲它迭代很多次,但做它在C,不如說Python和代碼。 C的只是更快了 - 甚至在考慮到的大量時間都花在轉換,太:

In [158]: timeit imax(x) 
1000 loops, best of 3: 883 us per loop 
In [159]: timeit fastest_argmax(x) 
1000 loops, best of 3: 575 us per loop 

In [174]: timeit list(x) 
1000 loops, best of 3: 316 us per loop 
In [175]: timeit max(l) 
1000 loops, best of 3: 256 us per loop 
In [181]: timeit l.index(0.99991619010758348) #the greatest number in my case, at index 92 
100000 loops, best of 3: 2.69 us per loop 

所以,關鍵的知識,進一步加快這件事是知道哪些格式的序列數據本質上是(例如,您是否可以省略轉換步驟或使用/編寫該格式的原生功能)。

順便說一句,你可能會得到一些加速通過使用aggregate(max_fn)而不是agg([max_fn])

+0

感謝您的幫助伊凡! – Ginger 2014-11-09 09:57:18