2016-11-09 41 views
3

我必須找到最佳的解決方案> 10^7方程系統有5個方程,每個變量有2個變量(5次測量,找到2個參數,誤差最小的系列)。 下面的代碼(通常用來做曲線擬合)做什麼,我想:有效地解決大量的線性最小二乘法系統

#Create_example_Data 
n = 100 
T_Arm = np.arange(10*n).reshape(-1, 5, 2) 
Erg = np.arange(5*n).reshape(-1, 5) 
m = np.zeros(n) 
c = np.zeros(n) 
#Run 
for counter in xrange(n): 
    m[counter], c[counter] = np.linalg.lstsq(T_Arm[counter, :, :], 
               Erg[counter, :])[0] 

可惜實在是太慢了。有什麼辦法可以顯着提高代碼的速度嗎?我試圖引導它,但我沒有成功。將最後一個解決方案用作初始猜測也可能是一個好主意。使用scipy.optimize.leastsq也沒有加速。

+0

什麼是'Inputlen'?是'n'嗎? – TuanDT

+0

n是等式系統的數目,等於Inputlen,我更正了代碼 – Okapi575

+0

,我認爲它應該是'xrange(n)'而不是'xrange(len(n))',因爲'n'只是一個整數(100在這種情況下) – TuanDT

回答

3

您可以使用稀疏矩陣A存儲T_Arm的(5,2)項在其對角線上,並求解AX = b,其中b是由堆積條目Erg組成的向量。然後用scipy.sparse.linalg.lsqr(A,b)解決系統問題。

爲了構建A和B我使用n = 3的用於可視化目的:

import numpy as np 
import scipy 
from scipy.sparse import bsr_matrix 
n = 3 
col = np.hstack(5 * [np.arange(10 * n/5).reshape(n, 2)]).flatten() 
array([ 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 2., 3., 2., 
     3., 2., 3., 2., 3., 2., 3., 4., 5., 4., 5., 4., 5., 
     4., 5., 4., 5.]) 

row = np.tile(np.arange(10 * n/2), (2, 1)).T.flatten() 
array([ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4., 5., 
     5., 6., 6., 7., 7., 8., 8., 9., 9., 10., 10., 
     11., 11., 12., 12., 13., 13., 14., 14.]) 

A = bsr_matrix((T_Arm[:n].flatten(), (row, col)), shape=(5 * n, 2 * n)) 
A.toarray() 
array([[ 0, 1, 0, 0, 0, 0], 
     [ 2, 3, 0, 0, 0, 0], 
     [ 4, 5, 0, 0, 0, 0], 
     [ 6, 7, 0, 0, 0, 0], 
     [ 8, 9, 0, 0, 0, 0], 
     [ 0, 0, 10, 11, 0, 0], 
     [ 0, 0, 12, 13, 0, 0], 
     [ 0, 0, 14, 15, 0, 0], 
     [ 0, 0, 16, 17, 0, 0], 
     [ 0, 0, 18, 19, 0, 0], 
     [ 0, 0, 0, 0, 20, 21], 
     [ 0, 0, 0, 0, 22, 23], 
     [ 0, 0, 0, 0, 24, 25], 
     [ 0, 0, 0, 0, 26, 27], 
     [ 0, 0, 0, 0, 28, 29]], dtype=int64) 

b = Erg[:n].flatten() 

然後

scipy.sparse.linalg.lsqr(A, b)[0] 
array([ 5.00000000e-01, -1.39548109e-14, 5.00000000e-01, 
     8.71088538e-16, 5.00000000e-01, 2.35398726e-15]) 

編輯:因爲它似乎A沒有在存儲器中作爲巨大:多個上塊稀疏矩陣here

+0

太棒了! – piRSquared

+0

不錯的想法。但數組A包含(n * 5)*(n * 2)個值,它們都是int64,所以對於n = 1000,需要80 MB,對於n = 10000 8GB等等。所有這些字節需要被讀取和寫入。也許使用n = 100並且以塊處理數據比原始解決方案更快。我會試一試。 – Okapi575

+1

@ Okapi575 A不是一個數組(我只使用A.toarray()來表示它),而是一個稀疏矩陣,所以內存消耗會更低。我試圖看看實施是否有效,但我承認我沒有計時。 –