2016-10-13 9 views
5

próbowałem przedłużyć scikit-learn na RidgeCV modelu za pomocą dziedziczenia:Dziedziczenie z scikit-learn na LassoCV modelu

from sklearn.linear_model import RidgeCV, LassoCV 

class Extended(RidgeCV): 
    def __init__(self, *args, **kwargs): 
     super(Extended, self).__init__(*args, **kwargs) 

    def example(self): 
     print 'Foo' 


x = [[1,0],[2,0],[3,0],[4,0], [30, 1]] 
y = [2,4,6,8, 60] 
model = Extended(alphas = [float(a)/1000.0 for a in range(1, 10000)]) 
model.fit(x,y) 
print model.predict([[5,1]]) 

pracował idealnie w porządku, ale kiedy próbowałem dziedziczą LassoCV, że przyniosły następujące traceback:

Traceback (most recent call last): 
    File "C:/Python27/so.py", line 14, in <module> 
    model.fit(x,y) 
    File "C:\Python27\lib\site-packages\sklearn\linear_model\coordinate_descent.py", line 1098, in fit 
    path_params = self.get_params() 
    File "C:\Python27\lib\site-packages\sklearn\base.py", line 214, in get_params 
    for key in self._get_param_names(): 
    File "C:\Python27\lib\site-packages\sklearn\base.py", line 195, in _get_param_names 
    % (cls, init_signature)) 
RuntimeError: scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class '__main__.Extended'> with constructor (<self>, *args, **kwargs) doesn't follow this convention. 

Czy ktoś może wyjaśnić, jak to naprawić?

Odpowiedz

5

Prawdopodobnie chcesz uczynić zgodny z scikit-learn model, aby móc go dalej używać z dostępnym działaniem naukowym scikit. Jeśli nie - musisz przeczytać ten pierwszy: http://scikit-learn.org/stable/developers/contributing.html#rolling-your-own-estimator

Krótko: scikit-learn ma wiele funkcji, takich jak estymatora klonowania (clone() funkcji), meta algorytmów jak GridSearch, Pipeline, walidacji Cross. Wszystkie te rzeczy muszą być w stanie uzyskać wartości pól wewnątrz twojego estymatora i zmienić wartość tych pól (na przykład GridSearch musi zmienić parametry wewnątrz twojego estymatora przed każdą oceną), jak parametr alpha w SGDClassifier. Aby zmienić wartość jakiegoś parametru, musi znać jego nazwę. Aby uzyskać nazwy wszystkich pól w każdej klasie klasy get_params z klasy BaseEstimator (Które dziedziczy się niejawnie) wymaga określenia wszystkich parametrów w metodzie klasy __init__, ponieważ łatwo jest introspekować wszystkie nazwy parametrów metody __init__ (Spójrz na BaseEstimator , to jest klasa, która rzuca ten błąd).

Więc po prostu chce, aby usunąć wszystkie varargs jak

*args, **kwargs 

z __init__ podpisu. Musisz podać wszystkie parametry swojego modelu w podpisie __init__ i zainicjować wszystkie wewnętrzne pola obiektu.

Oto przykład __init__ metody SGDClassifier, który jest dziedziczony z BaseSGDClassifier:

def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15, 
      fit_intercept=True, n_iter=5, shuffle=True, verbose=0, 
      epsilon=DEFAULT_EPSILON, n_jobs=1, random_state=None, 
      learning_rate="optimal", eta0=0.0, power_t=0.5, 
      class_weight=None, warm_start=False, average=False): 
    super(SGDClassifier, self).__init__(
     loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio, 
     fit_intercept=fit_intercept, n_iter=n_iter, shuffle=shuffle, 
     verbose=verbose, epsilon=epsilon, n_jobs=n_jobs, 
     random_state=random_state, learning_rate=learning_rate, eta0=eta0, 
     power_t=power_t, class_weight=class_weight, warm_start=warm_start, average=average) 
+1

byłem hopimg że nie było bardziej eleganckie rozwiązanie, ale jak widać, jest to najprostszy. Dziękuję Ci bardzo! –

Powiązane problemy