2017-09-14 21 views
5

Użyłem kodu znalezionego w pytaniu How to apply piecewise linear fit in Python?, aby wykonać segmentację liniową z pojedynczym punktem przerwania.Kawałkowo liniowe dopasowanie z n punktami przerwania

Kod jest w następujący sposób:

from scipy import optimize 
import matplotlib.pyplot as plt 
import numpy as np 
%matplotlib inline 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03]) 

def piecewise_linear(x, x0, y0, k1, k2): 
    return np.piecewise(x, 
         [x < x0], 
         [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 15, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

Próbuję dowiedzieć się, w jaki sposób mogę przedłużyć ten obsłużyć n wartości graniczne.

Próbowałem następujący kod dla metody piecewise_linear() do obsługi 2 punktów przerwania, ale w żaden sposób nie zmienia wartości punktów przerwania.

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158]) 

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3): 
    return np.piecewise(x, 
         [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
         [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 20, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

Każde wejście byłoby bardzo mile widziane

+0

'' 'To nie work''' jest prawie bezużyteczny opis. Myślę też, że nie osiągniesz tego za pomocą curve_fit(), która staje się bardziej złożona, gdy istnieje wiele punktów przerwania (wymagałoby ograniczeń liniowych do obsługi b0 sascha

+0

Myślę, że jeśli najpierw rozprowadzę punkty przerwania równomiernie na osi X, to znalezienie minimalnych lokalnych będzie wystarczające, aby zapewnić przyzwoite, nieoptymalne rozwiązanie. Czy znasz inny moduł optymalizacji, który obsługuje ograniczenia liniowe? –

+0

Jak wam powiedziałem, nie chodzi tylko o to. Ignorując gładkość i potencjalny non-convexity, możesz rozwiązać ten problem za pomocą bardziej ogólnych funkcji optymalizacji scipy, mianowicie COBYLA i SQSLP (jedyne dwa ograniczenia wspierające). Jedyne prawdziwe podejście, jakie widzę, to programowanie wypukłe z integracją całkowitą, ale oprogramowanie jest rzadkie (bonmin i couenne są dwoma rozwiązaniami open-source, których nie można używać z python; pajarito @ julialang, ale to podejście w ogóle wymaga pewnych trywialne sformułowanie). – sascha

Odpowiedz

4

NumPy ma polyfit function co sprawia, że ​​bardzo łatwo znaleźć najlepsze dopasowanie linię przez zbiór punktów:

coefs = npoly.polyfit(xi, yi, 1) 

Tak naprawdę jedyną trudność znajduje punkty przerwania. Dla danego zestawu wartości granicznych znalezienie linii najlepiej dopasowanych przez dane jest banalne.

Więc zamiast próbować znaleźć lokalizację pułapki i współczynników z elementów liniowych wszystkie naraz, wystarczy zminimalizować nad przestrzeń parametrów z pułapki.

Ponieważ punkty przerwania może być określona przez ich całkowitą indeksu do tablicy x, przestrzeni parametrów mogą być traktowane jako punkty na siatki całkowitoliczbowej N wymiarach, gdzie N jest liczba przerwań.

optimize.curve_fit nie jest dobrym wyborem jako minimalizator dla tego problemu , ponieważ przestrzeń parametru jest wartością całkowitą. Jeśli użyjesz curve_fit, algorytm dostraja parametry, aby określić, w którym kierunku przenieść . Jeśli korekta jest mniejsza niż 1 jednostka, wartości x punktów przerwania nie zmieniają się, więc błąd się nie zmienia, więc algorytm nie otrzymuje informacji o poprawnym kierunku przesuwania parametrów. Stąd curve_fit ma tendencję do niepowodzenia, gdy przestrzeń parametru jest w istocie liczbą całkowitą.

Lepszy, ale niezbyt szybki, minimizer będzie siatkowym wyszukiwaniem siły. Jeśli liczba punktów przerwania jest mała (a przestrzeń parametru o wartości x jest mała), może to wystarczyć. Jeśli liczba punktów przerwania jest duża i/lub przestrzeń parametru jest duża, być może skonfiguruj wielostopniowe zgrubne/dokładne przeszukiwanie siatki o siatce brutalnej (ang. Brute-force, ). A może ktoś zaproponuje mądrzejszy minimizer niż brutalną siłę ...


import numpy as np 
import numpy.polynomial.polynomial as npoly 
from scipy import optimize 
import matplotlib.pyplot as plt 
np.random.seed(2017) 

def f(breakpoints, x, y, fcache): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    if breakpoints not in fcache: 
     total_error = 0 
     for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
      total_error += ((f(xi) - yi)**2).sum() 
     fcache[breakpoints] = total_error 
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints])) 
    return fcache[breakpoints] 

def find_best_piecewise_polynomial(breakpoints, x, y): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    xs = np.split(x, breakpoints) 
    ys = np.split(y, breakpoints) 
    result = [] 
    for xi, yi in zip(xs, ys): 
     if len(xi) < 2: continue 
     coefs = npoly.polyfit(xi, yi, 1) 
     f = npoly.Polynomial(coefs) 
     result.append([f, xi, yi]) 
    return result 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
       18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
       126.14, 140.03, 150, 152, 154, 156, 158]) 
# Add some noise to make it exciting :) 
y += np.random.random(len(y))*10 

num_breakpoints = 2 
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None) 

plt.scatter(x, y, c='blue', s=50) 
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
    x_interval = np.array([xi.min(), xi.max()]) 
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval)) 
    plt.plot(x_interval, f(x_interval), 'ro-') 


plt.show() 

drukuje

y = poly([ 4.58801083 2.94476604]) , if x in [1.0, 6.0] 
y = poly([-70.36472935 14.37305793]) , if x in [7.0, 15.0] 
y = poly([ 123.24565235 1.94982153]), if x in [16.0, 20.0] 

i działek

enter image description here

+0

Świetna odpowiedź ... Próbowałem wszystkiego, co mogłem, używając "minimumsq" i "minimize", ale parametry "x0" i "x1" po prostu nie zostały poprawnie zoptymalizowane. –

+0

Doskonały. Dziękuję Ci! –

Powiązane problemy