2012-05-22 10 views
30

Dla mojego unittest, chcę sprawdzić, czy dwie tablice są identyczne. Zredukowany przykład:porównanie numpy tablice zawierające NaN

a=np.array([1, 2, np.NaN]) 
b=np.array([1, 2, np.NaN]) 
if np.all(a==b): 
    print 'arrays are equal' 

To nie działa, ponieważ nan! = Nan. Jaki jest najlepszy sposób postępowania?

Z góry dziękuję.

Odpowiedz

20

Alternatywnie można użyć numpy.testing.assert_equal lub numpy.testing.assert_array_equal z try/except:

In : import numpy as np 

In : def nan_equal(a,b): 
...:  try: 
...:   np.testing.assert_equal(a,b) 
...:  except AssertionError: 
...:   return False 
...:  return True 

In : a=np.array([1, 2, np.NaN]) 

In : b=np.array([1, 2, np.NaN]) 

In : nan_equal(a,b) 
Out: True 

In : a=np.array([1, 2, np.NaN]) 

In : b=np.array([3, 2, np.NaN]) 

In : nan_equal(a,b) 
Out: False 

Edit

Ponieważ używasz tego dla unittesting, gołe assert (zamiast owijania go, aby uzyskać True/False) może być bardziej naturalne.

+0

Doskonały, to najbardziej eleganckie i wbudowane rozwiązanie. Właśnie dodałem 'np.testing.assert_equal (a, b)' w moim unittest, a jeśli podniesie wyjątek, test się nie powiedzie (brak błędu), a nawet dostaję ładny wydruk z różnicami i niedopasowaniem. Dzięki. – saroele

+3

Należy zauważyć, że to rozwiązanie działa, ponieważ 'numpy.testing.assert_ *' nie stosuje tej samej semantyki python 'assert's. W zwykłym Pythonie powstają wyjątki 'AssertionError' iff' __debug__ ma wartość True' tzn. Jeśli skrypt jest uruchamiany niezoptymalizowany (brak flagi -O), zobacz [docs] (http://docs.python.org/3.3/reference /simple_stmts.html#grammar-token-assert_stmt). Z tego powodu zdecydowanie odradzałbym pakowanie "AssertionErrors" do kontroli przepływu. Oczywiście, ponieważ jesteśmy w pakiecie testowym, najlepszym rozwiązaniem jest pozostawienie samego numpy.testing.assert. –

8

Można użyć NumPy zamaskowanych tablice, maska ​​wartości NaN a następnie użyć numpy.ma.all lub numpy.ma.allclose:

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.all.html

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.allclose.html

Na przykład:

a=np.array([1, 2, np.NaN]) 
b=np.array([1, 2, np.NaN]) 
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True 
+1

dzięki za uczynienie mnie świadomość użycia zamaskowanych tablic. Wolę jednak rozwiązanie Avarisa. – saroele

+0

Powinieneś użyć 'np.ma.masked_where (np.isnan (a), a)' else nie porównasz wartości nieskończonych. –

+0

Testowałem z 'a = np.array ([1, 2, np.NaN])' i 'b = np.array ([1, np.NaN, 2])' które wyraźnie nie są równe i 'np. ma.all (np.ma.masked_invalid (a) == np.ma.masked_invalid (b)) 'nadal zwraca True, więc bądź świadomy tego, jeśli użyjesz tej metody. – tavo

20

nie jestem na pewno jest to najlepsze sposób postępowania, ale jest sposób:

>>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all() 
True 
+0

+1 To rozwiązanie wydaje się być nieco szybsze niż rozwiązanie, które zamieściłem z maskowanymi tablicami, chociaż jeśli tworzyłeś maskę do użytku w innych częściach kodu, obciążenie związane z tworzeniem maski stałoby się mniejszym czynnikiem w ogólna efektywność strategii. – JoshAdel

+0

Dzięki.Twoje rozwiązanie działa, ale wolę wbudowany test w numpy, jak sugeruje to Avaris – saroele

+1

Bardzo podoba mi się prostota tego. Wydaje się także szybsze niż rozwiązanie @Avaris. Włączając to do lambdafunction, testowanie za pomocą '% timeit' programu Ipython daje 23,7 μs vs 1,01 ms. – AllanLRH

1

Kiedy stosować powyższą odpowiedź:

((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all() 

To dało mi jakieś erros podczas oceny listę ciągów.

Ten typ jest bardziej ogólna:

def EQUAL(a,b): 
    return ((a == b) | ((a != a) & (b != b))) 
6

Najłatwiej jest skorzystać numpy.allclose() metody, które pozwalają na określenie zachowania, gdy o wartości NAN. Twój przykład będzie wyglądał następująco:

a = np.array([1, 2, np.nan]) 
b = np.array([1, 2, np.nan]) 

if np.allclose(a, b, equal_nan=True): 
    print 'arrays are equal' 

Następnie zostanie wydrukowane arrays are equal.

można znaleźć here odpowiedniej dokumentacji

+0

+1, ponieważ Twoje rozwiązanie nie wymyśla koła. Działa to jednak tylko w przypadku elementów podobnych do liczb. W przeciwnym razie otrzymasz nieprzyjemne "TypeError: ufunc" isfinite "nieobsługiwane dla typów danych wejściowych, a danych wejściowych nie można bezpiecznie przekonać do obsługiwanych typów zgodnie z zasadą" bezpiecznego "castingu. – MLguy

Powiązane problemy