2016-08-23 63 views
3

我試圖向量化某個加權總和,但無法弄清楚如何去做。我在下面創建了一個簡單的最小工作示例。我猜這個解決方案涉及到bsxfun或重塑和克羅內克產品,但我仍然沒有設法使它工作。三重加權總和

rng(1); 
N = 200; 
T1 = 5; 
T2 = 7; 
T3 = 10; 


A = rand(N,T1,T2,T3); 
w1 = rand(T1,1); 
w2 = rand(T2,1); 
w3 = rand(T3,1); 

B = zeros(N,1); 

for i = 1:N 
for j1=1:T1 
    for j2=1:T2 
    for j3=1:T3 
    B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3); 
    end 
    end 
end 
end 

A = B; 

對於二維情況,有一個智能答案here

+0

你需要推廣?因爲如果是這樣的話,我會把你的N,T1,T2,T3換成一個數組。 –

+0

我其實只是想要三維的情況。但泛化可能對其他人有用:) – phdstudent

+0

以下概括:) –

回答

5

您可以使用額外的乘法修改前一個答案的w1 * w2'網格,然後再乘以w3。然後,您可以再次使用矩陣乘法與A的「拼合」版本相乘。

W = reshape(w1 * w2.', [], 1) * w3.'; 
B = reshape(A, size(A, 1), []) * W(:); 

你可以換權的創建到它自己的功能,使這個推廣到N權重。由於這使用遞歸,因此N僅限於當前的遞歸限制(默認爲500)。

function W = createWeights(W, varargin) 
    if numel(varargin) > 0 
     W = createWeights(W(:) * varargin{1}(:).', varargin{2:end}); 
    end 
end 

而且隨着使用它:

W = createWeights(w1, w2, w3); 
B = reshape(A, size(A, 1), []) * W(:); 

更新

使用的@ CKT的很好的建議,使用kron一部分,我們可以修改createWeights只是一點點。

function W = createWeights(W, varargin) 
    if numel(varargin) > 0 
     W = createWeights(kron(varargin{1}, W), varargin{2:end}); 
    end 
end 
+0

@Suever基準頂部! :) –

1

這是一個道理:

ww1 = repmat (permute (w1, [4, 1, 2, 3]), [N, 1, T2, T3]); 
ww2 = repmat (permute (w2, [3, 4, 1, 2]), [N, T1, 1, T3]); 
ww3 = repmat (permute (w3, [2, 3, 4, 1]), [N, T1, T2, 1 ]); 

B = ww1 .* ww2 .* ww3 .* A; 
B = sum (B(:,:), 2) 

您可以通過在首位適當的尺寸創建w1w2w3避免permute。此外,您可以使用bsxfun而不是repmat來獲得額外的性能,我只是在此處顯示邏輯,而repmat更容易遵循。

編輯:廣義版本任意輸入尺寸:

Dims = {N, T1, T2, T3}; % add T4, T5, T6, etc as appropriate 
Params = cell (1, length (Dims)); 

Params{1} = rand (Dims{:}); 
for n = 2 : length (Dims) 
    DimSubscripts = ones (1, length (Dims)); DimSubscripts(n) = Dims{n}; 
    RepSubscripts = [Dims{:}]; RepSubscripts(n) = 1; 
    Params{n} = repmat (rand (DimSubscripts), RepSubscripts); 
end 

B = times (Params{:}); 
B = sum (B(:,:), 2) 
1

同樣,你不能概括這一點,以及對ND,除非你做了一些功能來構造克羅內克產品載體,但如何

A = reshape(A, N, []) * kron(w3, kron(w2, w1)); 
1

如果我們想反正有功能的途徑,並偏袒優雅/簡潔的性能,然後再考慮這一點:

function B = weightReduce(A, varargin) 

    B = A; 
    for i = length(varargin):-1:1 
     N = length(varargin{i}); 
     B = reshape(B, [], N) * varargin{i}; 
    end 

end 

這是性能的比較,我看到:

tic; 
for i = 1:10000 
    W = createWeights(w1,w2,w3); 
    B = reshape(A, size(A,1), [])*W(:); 
end 
toc 
Elapsed time is 0.920821 seconds. 
tic; 
for i = 1:10000 
    B2 = weightReduce(A, w1, w2, w3); 
end 
toc 
Elapsed time is 0.484470 seconds.