2014-04-29 12 views
9

Zauważyłem niespójne zachowanie w numpy.dot, gdy zaangażowane są zer i liczby nan.Błąd Numpy.dot? Niekonsekwentne zachowanie NaN

Czy ktokolwiek może to zrozumieć? Czy to błąd? Czy jest to specyficzne dla funkcji dot?

Używam numpy v1.6.1, 64bit, działa na Linuksie (również testowane na v1.6.2). Testowałem także na v1.8.0 na Windowsie 32-bitowym (nie mogę stwierdzić, czy różnice wynikają z wersji lub systemu operacyjnego czy arch.).

from numpy import * 
0*nan, nan*0 
=> (nan, nan) # makes sense 

#1 
a = array([[0]]) 
b = array([[nan]]) 
dot(a, b) 
=> array([[ nan]]) # OK 

#2 -- adding a value to b. the first value in the result is 
#  not expected to be affected. 
a = array([[0]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ 0., 0.]]) # EXPECTED : array([[ nan, 0.]]) 
# (also happens in 1.6.2 and 1.8.0) 
# Also, as @Bill noted, a*b works as expected, but not dot(a,b) 

#3 -- changing a from 0 to 1, the first value in the result is 
#  not expected to be affected. 
a = array([[1]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ nan, 1.]]) # OK 

#4 -- changing shape of a, changes nan in result 
a = array([[0],[0]]) 
b = array([[ nan, 1.]]) 
dot(a, b) 
=> array([[ 0., 0.], [ 0., 0.]]) # EXPECTED : array([[ nan, 0.], [ nan, 0.]]) 
# (works as expected in 1.6.2 and 1.8.0) 

Case # 4 wydaje się działać poprawnie w v1.6.2 i v1.8.0, ale nie przypadku # 2 ...


EDIT: @seberg podkreślić, że jest to kwestia blas , więc tutaj jest informacji o instalacji Blas znalazłem uruchamiając from numpy.distutils.system_info import get_info; get_info('blas_opt'):

1.6.1 linux 64bit 
/usr/lib/python2.7/dist-packages/numpy/distutils/system_info.py:1423: UserWarning: 
    Atlas (http://math-atlas.sourceforge.net/) libraries not found. 
    Directories to search for the libraries can be specified in the 
    numpy/distutils/site.cfg file (section [atlas]) or by setting 
    the ATLAS environment variable. 
    warnings.warn(AtlasNotFoundError.__doc__) 
{'libraries': ['blas'], 'library_dirs': ['/usr/lib'], 'language': 'f77', 'define_macros': [('NO_ATLAS_INFO', 1)]} 

1.8.0 windows 32bit (anaconda) 
c:\Anaconda\Lib\site-packages\numpy\distutils\system_info.py:1534: UserWarning: 
    Blas (http://www.netlib.org/blas/) sources not found. 
    Directories to search for the sources can be specified in the 
    numpy/distutils/site.cfg file (section [blas_src]) or by setting 
    the BLAS_SRC environment variable. 
warnings.warn(BlasSrcNotFoundError.__doc__) 
{} 

(ja osobiście nie wiem, co o tym myśleć)

+1

Jest interesujący dla przypadku 2, "a * b" daje pożądany wynik, ale nie ma wartości 'np.dot (a, b)'. – wflynny

+3

Wynik kropki zależy od używanej biblioteki blas. Na przykład widzę to samo z openblas (ale nie z atlasem), więc albo to nie jest określone, albo błąd w bibliotece blas. Mnożenie jest niezwiązane naprawdę ... – seberg

+2

Hmm, spróbuj 'from numpy.distutils.system_info import get_info; get_info ('blas_opt') ' – seberg

Odpowiedz

3

Myślę, że jak sugerował Seberg, jest to problem związany z biblioteką BLAS. Jeśli przyjrzeć się, w jaki sposób numpy.dot jest zaimplementowany here i here, znajdziesz wywołanie cblas_dgemm() dla przypadku macierzy podwójnej precyzji macierzy czasu macierzy.

Ten program w języku C, który odtwarza niektóre przykłady, daje takie same wyniki, gdy używa "zwykłego" BLASa, i właściwą odpowiedź podczas korzystania z ATLAS.

#include <stdio.h> 
#include <math.h> 

#include "cblas.h" 

void onebyone(double a11, double b11, double expectc11) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=1; 
    int K=1; 
    double alpha=1.0; 
    double A[1]={a11}; 
    int lda=1; 
    double B[1]={b11}; 
    int ldb=1; 
    double beta=0.0; 
    double C[1]; 
    int ldc=1; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g]) -> [%.18g]; expected [%.18g]\n",a11,b11,C[0],expectc11); 
} 

void onebytwo(double a11, double b11, double b12, 
       double expectc11, double expectc12) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=2; 
    int K=1; 
    double alpha=1.0; 
    double A[]={a11}; 
    int lda=1; 
    double B[2]={b11,b12}; 
    int ldb=2; 
    double beta=0.0; 
    double C[2]; 
    int ldc=2; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g, %.18g]) -> [%.18g, %.18g]; expected [%.18g, %.18g]\n", 
     a11,b11,b12,C[0],C[1],expectc11,expectc12); 
} 

int 
main() 
{ 
    onebyone(0, 0, 0); 
    onebyone(2, 3, 6); 
    onebyone(NAN, 0, NAN); 
    onebyone(0, NAN, NAN); 
    onebytwo(0, 0,0, 0,0); 
    onebytwo(2, 3,5, 6,10); 
    onebytwo(0, NAN,0, NAN,0); 
    onebytwo(NAN, 0,0, NAN,NAN); 
    return 0; 
} 

Wyjście Blas:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [0]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [0, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

Wyjście z atlasu:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [nan]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [nan, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

BLAS Wydaje się, że oczekiwane zachowanie, gdy pierwszy operand ma NaN, a nie tak, gdy pierwszy operand wynosi zero, a drugi ma NaN.

W każdym razie, nie sądzę, że ten błąd jest w warstwie Numpy; jest w BLAS. Wygląda na to, że można obejść, używając zamiast tego ATLAS.

Powyżej wygenerowany na Ubuntu 14.04, używając dostarczonego przez Ubuntu gcc, BLAS i ATLAS.