2016-05-24 12 views
7

Rozważmy dwie ndarrays o długości n, arr1 i arr2. Mam następujące obliczenia sumy produktów, a robi to num_runs razy do odniesienia:Efektywna podwójna suma produktów

import numpy as np 
import time 

num_runs = 1000 
n = 100 

arr1 = np.random.rand(n) 
arr2 = np.random.rand(n) 

start_comp = time.clock() 
for r in xrange(num_runs): 
    sum_prods = np.sum([arr1[i]*arr2[j] for i in xrange(n) 
         for j in xrange(i+1, n)]) 

print "total time for comprehension = ", time.clock() - start_comp 

start_loop = time.clock() 
for r in xrange(num_runs): 
    sum_prod = 0.0 
    for i in xrange(n): 
     for j in xrange(i+1, n): 
      sum_prod += arr1[i]*arr2[j] 

print "total time for loop = ", time.clock() - start_loop 

Wyjście jest

total time for comprehension = 3.23097066953 
total time for comprehension = 3.9045544426 

więc korzystanie listowych pojawia się szybciej.

Czy istnieje o wiele bardziej wydajna implementacja, używając procedur Numpy, aby obliczyć taką sumę produktów?

+0

Czy to może być przydatne? https://stackoverflow.com/questions/9068478/how-to-parallelize-a-sum-calculation-in-python-numpy –

+0

Wydaje się bardzo istotne: ['Mnożenie macierzy z zależnością iteracyjną - NumPy'] (http: // stackoverflow.com/questions/36045510/matrix-multiplication-with-iterator-dependency-numpy). – Divakar

Odpowiedz

12

Przegrupowanie działanie w algorytmie O (n) wykonawczego zamiast O (N^2) i wykorzystać NumPy dla produktów i sum:

# arr1_weights[i] is the sum of all terms arr1[i] gets multiplied by in the 
# original version 
arr1_weights = arr2[::-1].cumsum()[::-1] - arr2 

sum_prods = arr1.dot(arr1_weights) 

czasowy wskazuje, że jest to o 200 razy szybszy niż zrozumienie listy dla n == 100.

In [21]: %%timeit 
    ....: np.sum([arr1[i] * arr2[j] for i in range(n) for j in range(i+1, n)]) 
    ....: 
100 loops, best of 3: 5.13 ms per loop 

In [22]: %%timeit 
    ....: arr1_weights = arr2[::-1].cumsum()[::-1] - arr2 
    ....: sum_prods = arr1.dot(arr1_weights) 
    ....: 
10000 loops, best of 3: 22.8 µs per loop 
+1

Gratulacje. –

+1

Układając terminy, jest to również: 'arr1 [: - 1] .cumsum(). Kropka (arr2 [1:])'. –

3

Można użyć następującego nadawania trick:

a = np.sum(np.triu(arr1[:,None]*arr2[None,:],1)) 
b = np.sum([arr1[i]*arr2[j] for i in xrange(n) for j in xrange(i+1, n)]) 
print a == b # True 

Zasadniczo, płacę cenę obliczania iloczyn wszystkich elementów parami w arr1 i arr2 do skorzystania z prędkością numpy nadawczej/wektoryzacji dzieje się znacznie szybciej w kodzie niskiego poziomu.

A czasy:

%timeit np.sum(np.triu(arr1[:,None]*arr2[None,:],1)) 
10000 loops, best of 3: 55.9 µs per loop 

%timeit np.sum([arr1[i]*arr2[j] for i in xrange(n) for j in xrange(i+1, n)]) 
1000 loops, best of 3: 1.45 ms per loop 
+0

Lub, aby zapisać na pamięci, choć nieco wolniej może być: 'R, C = np.triu_indices (n, 1); output = np.dot (arr1 [R], arr2 [C]) '. – Divakar

8

Vectorized sposób: np.sum(np.triu(np.multiply.outer(arr1,arr2),1)).

do 30x poprawy:

In [9]: %timeit np.sum(np.triu(np.multiply.outer(arr1,arr2),1)) 
1000 loops, best of 3: 272 µs per loop 

In [10]: %timeit np.sum([arr1[i]*arr2[j] for i in range(n) 
         for j in range(i+1, n)] 
100 loops, best of 3: 7.9 ms per loop 

In [11]: allclose(np.sum(np.triu(np.multiply.outer(arr1,arr2),1)), 
np.sum(np.triu(np.multiply.outer(arr1,arr2),1))) 
Out[11]: True 

Kolejna szybka approch jest użycie Numba:

from numba import jit 
@jit 
def t(arr1,arr2): 
    s=0 
    for i in range(n): 
     for j in range(i+1,n): 
      s+= arr1[i]*arr2[j] 
    return s 

do 10x nowy czynnik:

In [12]: %timeit t(arr1,arr2) 
10000 loops, best of 3: 21.1 µs per loop 

i korzystania @ user2357112 minimalnym odpowiedź ,

@jit 
def t2357112(arr1,arr2): 
    s=0 
    c=0 
    for i in range(n-2,-1,-1): 
     c += arr2[i+1] 
     s += arr1[i]*c 
    return s 

dla

In [13]: %timeit t2357112(arr1,arr2) 
100000 loops, best of 3: 2.33 µs per loop 

, po prostu robi niezbędnych operacji.

+0

Rozwiązanie numba jest miłe, ponieważ nie wymaga tworzenia pośrednich macierzy. – JoshAdel

+0

Myślę, że mogłeś uzyskać błędne granice podczas tłumaczenia mojego kodu na numba. Próbujesz uzyskać dostęp do 'arr2 [n]', za ostatnim elementem 'arr2'. – user2357112

+0

masz rację. Ponieważ numba nie sprawdza granic, dało to po cichu odpowiedni wynik na mojej próbie ...... zredagowanej. –