2013-08-26 9 views
10

Używam numpy, gdzie funkcja wiele razy wewnątrz kilku pętli for, ale staje się zbyt wolny. Czy są jakieś sposoby na szybsze wykonanie tej funkcji? Przeczytałem, że powinieneś spróbować zrobić pętlę w linii, a także utworzyć zmienne lokalne dla funkcji przed pętlami for, ale nic nie poprawia prędkości o wiele (< 1%). Numery ndarrays o kształcie = (2600.5200) są numpy ndarrays. Użyłem import profile, aby uzyskać uchwyt tam, gdzie są wąskie gardła, i where w for pętli jest duża.szybkie python numpy gdzie funkcjonalność?

import numpy as np 
max = np.max 
where = np.where 
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS)] 
+0

Czy jesteś obliczanie "UNIQ_IDS" w tym skrypcie lub jest to z góry określone? – Daniel

+0

UNIQ_IDS jest z góry ustaloną ... listą ints of len = 800.To tylko fragment kodu, przepraszam za zamieszanie. –

Odpowiedz

2

nie można po prostu zrobić

emiss_data[obj_data == i] 

? Nie wiem, dlaczego w ogóle używasz where.

+0

Cóż, to działa i poprawia się o ~ 45%. Dzięki. Myślę, że używam gdzie, ponieważ jestem tak przyzwyczajony do IDL i próbuję przekonwertować do Pythona. Jednak wciąż jest bardzo powolny. Wykonanie tego 800 razy zajmuje 75 sekund, podczas gdy IDL wykona to w 2 sekundach. A co, jeśli faktycznie potrzebujesz lokalizacji/indeksów dla przyszłych operacji? Nie wyobrażam sobie, że byłoby to bardzo wydajne, gdybyś używał go kilka razy w pętli for zamiast w instrukcji where w pętli for. –

+0

Wygląda na to, że powinien istnieć sposób grupowania wartości 'emiss_data' przez wartości' obj_data' z wbudowanymi numpy. Ale nie znalazłem. – user2357112

+0

Możesz użyć 'np.lexsort'; jednak samo "lexsort" jest wąskim gardłem prowadzącym do nieoptymalnego rozwiązania. – Daniel

0

Przypisywanie krotki jest znacznie szybsze niż przypisywanie listy, zgodnie z Are tuples more efficient than lists in Python?, więc może po prostu budowanie krotki zamiast listy poprawi efektywność.

+1

Wątpię. Krotki mają zalety w niektórych przypadkach, ale żaden z nich nie ma tu zastosowania. To pytanie (lub raczej akceptowana tam odpowiedź) nie pokazuje, że krotki są szybsze do skonstruowania, pokazuje, że * literalne * krotki można skonstruować raz i użyć wiele razy. Nawet jeśli tworzenie krotek * było * szybsze niż tworzenie listy, nie ma możliwości, żeby to było wąskie gardło. – delnan

+0

Dzięki za informacje! – Jblasco

7

Okazuje się, że czysta pętla w języku Python może być o wiele szybsza niż indeksowanie NumPy (lub wywoływanie do np.where) w tym przypadku.

Rozważmy następujące alternatywy:

import numpy as np 
import collections 
import itertools as IT 

shape = (2600,5200) 
# shape = (26,52) 
emiss_data = np.random.random(shape) 
obj_data = np.random.random_integers(1, 800, size=shape) 
UNIQ_IDS = np.unique(obj_data) 

def using_where(): 
    max = np.max 
    where = np.where 
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_index(): 
    max = np.max 
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_max(): 
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_loop(): 
    result = collections.defaultdict(list) 
    for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()): 
     result[idx].append(val) 
    return [max(result[idx]) for idx in UNIQ_IDS] 

def using_sort(): 
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals = uind.argsort() 
    count = np.bincount(uind) 
    start = 0 
    end = 0 
    out = np.empty(count.shape[0]) 
    for ind, x in np.ndenumerate(count): 
     end += x 
     out[ind] = np.max(np.take(emiss_data, vals[start:end])) 
     start += x 
    return out 

def using_split(): 
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals = uind.argsort() 
    count = np.bincount(uind) 
    return [np.take(emiss_data, item).max() 
      for item in np.split(vals, count.cumsum())[:-1]] 

for func in (using_index, using_max, using_loop, using_sort, using_split): 
    assert using_where() == func() 

Oto benchmarki, z shape = (2600,5200):

In [57]: %timeit using_loop() 
1 loops, best of 3: 9.15 s per loop 

In [90]: %timeit using_sort() 
1 loops, best of 3: 9.33 s per loop 

In [91]: %timeit using_split() 
1 loops, best of 3: 9.33 s per loop 

In [61]: %timeit using_index() 
1 loops, best of 3: 63.2 s per loop 

In [62]: %timeit using_max() 
1 loops, best of 3: 64.4 s per loop 

In [58]: %timeit using_where() 
1 loops, best of 3: 112 s per loop 

Zatem using_loop (czysty Python) okazuje się być ponad 11x szybciej niż using_where.

Nie jestem do końca pewien, dlaczego czysty Python jest szybszy niż NumPy. Domyślam się, że zamki w czystej wersji Pythona (tak, kalambur przeznaczone) przez jeden raz. Wykorzystuje fakt, że pomimo wszystkich wymyślnych indeksów, , naprawdę chcemy po prostu odwiedzić każdą wartość raz. W ten sposób rozwiązuje problem z koniecznością dokładnego określenia, do której grupy przypada każda wartość w emiss_data. Jest to jednak tylko niejasna spekulacja. Nie wiedziałem, że to będzie szybsze, dopóki nie będę porównywać.

+0

co to jest "lista" w using_loop? –

+0

['collections.defaultdict (list)'] (http://docs.python.org/2/library/collections.html#collections.defaultdict) tworzy obiekt podobny do dyktafonu, który zwraca listę jako wartość domyślną. – unutbu

7

można użyć np.unique z return_index:

def using_sort(): 
    #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) 
    uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals=uind.argsort() 
    count=np.bincount(uind) 

    start=0 
    end=0 

    out=np.empty(count.shape[0]) 
    for ind,x in np.ndenumerate(count): 
     end+=x 
     out[ind]=np.max(np.take(emiss_data,vals[start:end])) 
     start+=x 
    return out 

Korzystanie odpowiedź @ unutbu jako punkt odniesienia dla shape = (2600,5200):

np.allclose(using_loop(),using_sort()) 
True 

%timeit using_loop() 
1 loops, best of 3: 12.3 s per loop 

#With np.unique inside the definition 
%timeit using_sort() 
1 loops, best of 3: 9.06 s per loop 

#With np.unique outside the definition 
%timeit using_sort() 
1 loops, best of 3: 2.75 s per loop 

#Using @Jamie's suggestion for uind 
%timeit using_sort() 
1 loops, best of 3: 6.74 s per loop 
+2

Myślę, że jeśli 'UNIQ_IDS' rzeczywiście ma unikatowe wpisy' obj_data' wstępnie obliczone, możesz wywołać 'np.digitize (obj_data, UNIQ_IDS) - 1', aby uzyskać ten sam wynik co' uind' w mniej więcej połowie czasu. – Jaime

+0

Twoja metoda jest naprawdę sprytna, ale niestety nie mogę uzyskać takiego samego przyrostu prędkości. (Dodałem benchmark dla 'using_sort', gdy uruchamiam na moim komputerze w moim poście.) Dla mnie' using_loop' jest wciąż nieco szybszy.) Być może różnica wynika z wersji Pythona lub systemu operacyjnego? Używam Pythona 2.7 na Ubuntu 11.10. Czego używasz? – unutbu

+0

@unutbu Używam OSX i w pełni zaktualizowanej instalacji anakonda (ma przyspieszenie, o którym wiem, że w przeszłości skręcało czasy). Próbowałem również z python 2.7.4 i numpy 1.7.1 na pudełku OSX i uzyskałem te same wyniki; jednak próbowałem na systemie Ubuntu z układem AMD z numpy 1.6.1 i stwierdziłem, że czasy są równoważne. Nienawidzę nadal zadawać pytanie [this] (http://stackoverflow.com/questions/18365073/why-is-numpys-einsum-faster-than-numpys-built-infunction), ale wydaje się, że coś się dzieje z czasami, których nie rozumiem. – Daniel

5

wierzę najszybszym sposobem osiągnięcia tego jest użycie operacji w groupby()pandas pakiet. Porównując do @using_sort() funkcję Ophion, w Pandy jest około 10-krotnie szybciej:

import numpy as np 
import pandas as pd 

shape = (2600,5200) 
emiss_data = np.random.random(shape) 
obj_data = np.random.random_integers(1, 800, size=shape) 
UNIQ_IDS = np.unique(obj_data) 

def using_sort(): 
    #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) 
    uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals=uind.argsort() 
    count=np.bincount(uind) 

    start=0 
    end=0 

    out=np.empty(count.shape[0]) 
    for ind,x in np.ndenumerate(count): 
     end+=x 
     out[ind]=np.max(np.take(emiss_data,vals[start:end])) 
     start+=x 
    return out 

def using_pandas(): 
    return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max() 

print('same results:', np.allclose(using_pandas(), using_sort())) 
# same results: True 

%timeit using_sort() 
# 1 loops, best of 3: 3.39 s per loop 

%timeit using_pandas() 
# 1 loops, best of 3: 397 ms per loop 
0

Jeśli obj_data składa się ze stosunkowo małych liczb całkowitych, można wykorzystać numpy.maximum.at (od v1.8.0):

def using_maximumat(): 
    n = np.max(UNIQ_IDS) + 1 
    temp = np.full(n, -np.inf) 
    np.maximum.at(temp, obj_data, emiss_data) 
    return temp[UNIQ_IDS]