2017-09-14 135 views
5

我已經使用How to apply piecewise linear fit in Python?這個問題中發現的一些代碼來執行具有單個斷點的分段線性近似。具有n個斷點的分段線性擬合

的代碼如下:

from scipy import optimize 
import matplotlib.pyplot as plt 
import numpy as np 
%matplotlib inline 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03]) 

def piecewise_linear(x, x0, y0, k1, k2): 
    return np.piecewise(x, 
         [x < x0], 
         [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 15, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

我試圖找出如何我可以擴展處理ñ斷點。

我試着用下面的代碼來處理2斷點的piecewise_linear()方法,但它不以任何方式改變斷點的值。

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158]) 

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3): 
    return np.piecewise(x, 
         [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
         [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 20, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

任何投入將不勝感激

+0

'''它不work'''是幾乎無用的描述。我也認爲你不能通過curve_fit()來實現這一點,當有多個斷點時(需要線性約束來處理b0 sascha

+0

我認爲,如果我最初在x軸上均勻分佈斷點,那麼找到局部最小值就足以提供一個體面的非最優解。你知道另一個支持線性約束的優化模塊嗎? –

+0

正如我告訴你的,這不僅僅是這個。忽略平滑性和潛在的非凸性,你可以用scipy的更一般的優化函數,即COBYLA和SQSLP(唯一的兩個支持約束)來解決這個問題。我看到的唯一真正的方法是混合整數凸規劃,但軟件是稀疏的(bonmin和couenne是兩個開源解決方案,不適合從python使用; pajarito @ julialang;但是這種方法通常需要一些非 - 簡單的公式)。 – sascha

回答

4

NumPy的有polyfit function這使得它很容易通過一組點找到最佳擬合線:

coefs = npoly.polyfit(xi, yi, 1) 

所以,真正唯一的困難正在找到斷點。對於給定的一組 斷點,通過給定數據找到最合適的線是很簡單的。

因此,而不是試圖一下子找到斷點係數線性部分的 的位置,就足夠了斷點的參數空間 減少了。

由於斷點可以通過它們的整數索引值來指定到x陣列, 參數空間可以被認爲是對N尺寸,其中 N是斷點的數目的整數網格點。

optimize.curve_fit不是一個很好的選擇,因爲這個問題的最小值爲 ,因爲參數空間是整數值。如果您要使用curve_fit, ,算法會調整參數以確定 移動的方向。如果調整小於1個單位,則斷點的x值不會變爲 ,因此錯誤不會更改,因此算法不會收到有關正確移動參數方向的信息 。因此,當參數空間基本上是整數值時,curve_fit 往往會失敗。

一個更好但不是很快的最小化器將是一個強力網格搜索。如果 斷點數很少(參數空間x-值小於 ),這可能就足夠了。如果斷點數量很大和/或參數空間很大,則可能會設置多級粗/細網格搜索(蠻力)。或者,也許有人會建議比蠻力更聰明的最小化...


import numpy as np 
import numpy.polynomial.polynomial as npoly 
from scipy import optimize 
import matplotlib.pyplot as plt 
np.random.seed(2017) 

def f(breakpoints, x, y, fcache): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    if breakpoints not in fcache: 
     total_error = 0 
     for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
      total_error += ((f(xi) - yi)**2).sum() 
     fcache[breakpoints] = total_error 
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints])) 
    return fcache[breakpoints] 

def find_best_piecewise_polynomial(breakpoints, x, y): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    xs = np.split(x, breakpoints) 
    ys = np.split(y, breakpoints) 
    result = [] 
    for xi, yi in zip(xs, ys): 
     if len(xi) < 2: continue 
     coefs = npoly.polyfit(xi, yi, 1) 
     f = npoly.Polynomial(coefs) 
     result.append([f, xi, yi]) 
    return result 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
       18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
       126.14, 140.03, 150, 152, 154, 156, 158]) 
# Add some noise to make it exciting :) 
y += np.random.random(len(y))*10 

num_breakpoints = 2 
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None) 

plt.scatter(x, y, c='blue', s=50) 
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
    x_interval = np.array([xi.min(), xi.max()]) 
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval)) 
    plt.plot(x_interval, f(x_interval), 'ro-') 


plt.show() 

打印

y = poly([ 4.58801083 2.94476604]) , if x in [1.0, 6.0] 
y = poly([-70.36472935 14.37305793]) , if x in [7.0, 15.0] 
y = poly([ 123.24565235 1.94982153]), if x in [16.0, 20.0] 

和情節

enter image description here

+0

很好的答案......我儘可能用'leastsq'和'minim'來嘗試一切,但分段參數'x0'和'x1'只是沒有正確優化 –

+0

完美。謝謝! –