你想利用矢量[1, 2, 3, ..., n]
—其中n
是輸入矩陣A
沿感興趣—軸的尺寸的平均值,方差和標準差,由矩陣A
本身給予的權重。
具體而言,假設您想要考慮這些垂直軸上的質心統計數據(axis=0
)—這與您所寫的公式相對應。對於固定柱j
,你會做
n = A.shape[0]
r = np.arange(1, n+1)
mu = np.average(r, weights=A[:,j])
var = np.average(r**2, weights=A[:,j]) - mu**2
std = np.sqrt(var)
爲了把所有的計算,針對不同的列在一起,你必須堆疊在一起一堆r
(每列一個)的副本,以形成一個矩陣(我在下面的代碼中調用了R
)。仔細一點,您可以使axis=0
和axis=1
都能正常工作。
import numpy as np
def com_stats(A, axis=0):
A = A.astype(float) # if you are worried about int vs. float
n = A.shape[axis]
m = A.shape[(axis-1)%2]
r = np.arange(1, n+1)
R = np.vstack([r] * m)
if axis == 0:
R = R.T
mu = np.average(R, axis=axis, weights=A)
var = np.average(R**2, axis=axis, weights=A) - mu**2
std = np.sqrt(var)
return mu, var, std
例如,
A = np.array([[1, 1, 0], [1, 2, 1], [1, 1, 1]])
print(A)
# [[1 1 0]
# [1 2 1]
# [1 1 1]]
print(com_stats(A))
# (array([ 2. , 2. , 2.5]), # centre-of-mass mean by column
# array([ 0.66666667, 0.5 , 0.25 ]), # centre-of-mass variance by column
# array([ 0.81649658, 0.70710678, 0.5 ])) # centre-of-mass std by column
編輯:
一個可避免產生的r
內存拷貝通過使用numpy.lib.stride_tricks
構建R
:交換線
R = np.vstack([r] * m)
以上與
from numpy.lib.stride_tricks import as_strided
R = as_strided(r, strides=(0, r.itemsize), shape=(m, n))
所得R
爲(跨距)ndarray
其基礎陣列相同r
的—絕對沒有任何值的複製發生。
from numpy.lib.stride_tricks import as_strided
FMT = '''\
Shape: {}
Strides: {}
Position in memory: {}
Size in memory (bytes): {}
'''
def find_base_nbytes(obj):
if obj.base is not None:
return find_base_nbytes(obj.base)
return obj.nbytes
def stats(obj):
return FMT.format(obj.shape,
obj.strides,
obj.__array_interface__['data'][0],
find_base_nbytes(obj))
n=10
m=1000
r = np.arange(1, n+1)
R = np.vstack([r] * m)
S = as_strided(r, strides=(0, r.itemsize), shape=(m, n))
print(stats(r))
print(stats(R))
print(stats(S))
輸出:
Shape: (10,)
Strides: (8,)
Position in memory: 4299744576
Size in memory (bytes): 80
Shape: (1000, 10)
Strides: (80, 8)
Position in memory: 4304464384
Size in memory (bytes): 80000
Shape: (1000, 10)
Strides: (0, 8)
Position in memory: 4299744576
Size in memory (bytes): 80
感謝this SO answer和this one關於如何獲得跨入ndarray
的底層數組的內存地址和大小的解釋。
我明白了。在我提出問題的解決方案之前,它嘗試了類似的方式......但是,我很擔心創建(可能是巨大的)'R'矩陣。由於'R'無論如何只包含'r'的副本,我們能否擺脫複製? – NichtJens
@NichtJens:見編輯。 –
哦,這是一個非常好的技巧。記憶方面,我想我們現在是平等的。你是否還看到潛在的差異?看來,我們應該介紹我們的解決方案,以提高性能 – NichtJens