2016-11-09 63 views
0

這是我如何獲得我的ND數據(func是IRL不是矢量化):取結果的一維列表,並將其轉換爲ND xarray.DataArray

import numpy 
import xarray 
import itertools 

xs = numpy.linspace(0, 10, 100) 
ys = numpy.linspace(0, 0.1, 20) 
zs = numpy.linspace(0, 5, 200) 

def func(x, y, z): 
    return x * y/z 

vals = list(itertools.product(xs, ys, zs)) 
result = [func(x, y, z) for x, y, z in vals] 

我有一種感覺,我做什麼可以簡化。我想把這個放在xarray.DataArray而不改變數據。然而,這是我現在該怎麼辦呢:

arr = np.array(result).reshape(len(xs), len(ys), len(zs)) 
da = xarray.DataArray(arr, coords=[('x', xs), ('y', ys), ('z', zs)]) 

這是一個簡單的例子,但通常我〜10D的數據,我通過映射itertools.product(並行)獲得工作。

我的問題:我怎麼能做到這一點沒有重塑我的數據,並通過使用vals並沒有採取的xsyszs的長度?

以類似的方式對你做什麼用:

index = pandas.MultiIndex.from_tuples(vals, names=['x', 'y', 'z']) 
df = pandas.DataFrame(result, columns=['result'], index=index) 

編輯: 這是我如何解決它,通過@hpaulj靈感的答案,謝謝!

import numpy 
import xarray 
import itertools 

coords = dict(x=numpy.linspace(0, 10, 100), 
       y=numpy.linspace(0, 0.1, 20), 
       z=numpy.linspace(0, 5, 200)) 

def func(x, y, z): 
    return x * y/z 

result = [func(x, y, z) for x, y, z in itertools.product(*coords.values())] 

xarray.DataArray(numpy.reshape(result, [len(i) for i in coords.values()]), coords=coords) 
+0

注意你編輯過的問題,不好。低於原始問題的答案。 – Balzola

+0

我的問題的本質並沒有改變。我只是添加了一個維度,使它看起來不那麼瑣碎。請注意,我說:「我使用〜10D數據」。 – johnbaltis

+0

它確實改變了問題的本質!另外你被z分割的事實肯定會改變這個問題。自己決定是否需要任何幫助。 – Balzola

回答

1

有經驗numpy用戶傾向於專注於去除迭代步驟。因此我們放大了您的result計算,並將reshape視爲一件小事。因此,迄今爲止的答案都集中在廣播和計算你的功能。

但我開始懷疑,什麼是真正困擾你的是,如果你有10種這樣的尺寸

reshape(len(xs), len(ys), len(zs)) 

可能變得笨拙,不只是3。它沒有那麼多的計算速度,但這種努力要求輸入len(..) 10次。或者可能是代碼看起來很醜。

無論如何,這裏有一種繞過所有鍵入的方法。關鍵是用一個簡單的列表理解收集維數組列表中的

In [495]: dims = [np.linspace(0,10,4), np.linspace(0,.1,3), np.linspace(0,5,5)] 
In [496]: from itertools import product 
In [497]: vals = list(product(*dims)) 
In [498]: len(vals) 
Out[498]: 60 
In [499]: result = [sum(ijk) for ijk in vals] # a simple func 

現在剛剛拿到len's

In [501]: arr=np.array(result).reshape([len(i) for i in dims]) 
In [502]: arr.shape 
Out[502]: (4, 3, 5) 

另一種可能性是將在列表中的linspace參數在一開始。

In [504]: ldims=[4,3,5] 
In [505]: ends=[10,.1,5] 
In [506]: dims=[np.linspace(0,e,l) for e,l in zip(ends, ldims)] 
In [507]: vals = list(product(*dims)) 
In [508]: result=[sum(ijk) for ijk in vals] 
In [509]: arr=np.array(result).reshape(ldims) 

reshape本身不是一個昂貴的操作。通常它會創建一個視圖,這是您可以對數組執行的最快事情之一。

@Divakar在他刪除的答案中暗示了這種解決方案,*np.meshgrid(*A)替代了您的product(xs,ys)

順便說一句,我的答案不涉及xarray要麼 - 因爲我沒有安裝該軟件包。我假設你在將3d形狀的arr傳遞給它時知道自己在做什麼,而不是更長的1d陣列。查看標籤號碼,關於numpy的5k追隨者,對於xarray的23追隨者。

xarraycoords參數也可以從dims(帶有附加名稱列表)構建。

如果這個答案不符合你的喜好,我會建議關閉這個問題,然後用xarray標籤開始一個新問題。這樣你就不會吸引衆多的蒼蠅numpy

+0

謝謝!這啓發了我:)我認爲'dims''dict'甚至會更好,因爲我將能夠使用'coords'的鍵,並且將保留列表/ numpy數組的名稱。 – johnbaltis

+0

或者使用單獨的'dims'和'coords'參數。我花了一分鐘時間查看API,http://xarray.pydata.org/en/stable/generated/xarray.DataArray。html – hpaulj

+0

我編輯了我的答案並添加了一個潛在的解決方案:) – johnbaltis

-1

我的回答:[刪除,因爲問題意外更改]

+0

謝謝,但這不是問題的答案。我正在尋找一種無需重塑的方法。另外,你的答案根本不涉及「xarray」。 – johnbaltis

+0

重塑一個numpy數組是一個微不足道的操作。 – hpaulj

0

第二編輯我已經忘記了einsum!如果你可以折磨你的功能,以適應這將是更快(下面使用timeit爲1.5ms)

result = np.einsum('i,j,k', xs, ys, 1.0/zs) 

你需要重塑和廣播相同形狀的陣列。正如巴爾佐拉所說,如果它在每個方向上都是10D和100(10 ** 20個元素),這將非常大。正如hpaulj所說,重塑一個numpy數組通常是微不足道的,在這種情況下,雖然廣播確實需要一些工作。但比itertools.product()方法少得多。對於你的例子

import numpy as np 

xs = np.linspace(0, 10, 100) 
ys = np.linspace(0, 0.1, 20) 
zs = np.linspace(0.1, 5, 200) 

xn, yn, zn = len(xs), len(ys), len(zs) 

xs_b = np.broadcast_to(xs.reshape(xn, 1, 1), (xn, yn, zn)) 
ys_b = np.broadcast_to(ys.reshape(1, yn, 1), (xn, yn, zn)) 
zs_b = np.broadcast_to(zs.reshape(1, 1, zn), (xn, yn, zn)) 

result = xs_b * ys_b/zs_b 

使用timeit如下我得到numpy計算爲4ms和itertools方法150ms。我認爲更多維度的差異會更大。

import timeit 

init = ''' 
import itertools 
import numpy as np 

def func(x, y, z): 
    return x * y/z 

xs = np.linspace(0, 10, 100) 
ys = np.linspace(0, 0.1, 20) 
zs = np.linspace(0.1, 5, 200) 

xn, yn, zn = len(xs), len(ys), len(zs) 
''' 
funcs = [''' 
xs_b = np.broadcast_to(xs.reshape(xn, 1, 1), (xn, yn, zn)) 
ys_b = np.broadcast_to(ys.reshape(1, yn, 1), (xn, yn, zn)) 
zs_b = np.broadcast_to(zs.reshape(1, 1, zn), (xn, yn, zn)) 

result = xs_b * ys_b/zs_b 
''',''' 
vals = list(itertools.product(xs, ys, zs)) 
result = [func(x, y, z) for x, y, z in vals] 
'''] 

for f in funcs: 
    print(timeit.timeit(f, setup=init, number=100)) 

編輯PS。我改變了你的zs,通過除以零來防止numpy警告,因爲這可能影響了時間比較。

相關問題