Okazuje się, że czysta pętla w języku Python może być o wiele szybsza niż indeksowanie NumPy (lub wywoływanie do np.where) w tym przypadku.
Rozważmy następujące alternatywy:
import numpy as np
import collections
import itertools as IT
shape = (2600,5200)
# shape = (26,52)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)
def using_where():
max = np.max
where = np.where
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
return MAX_EMISS
def using_index():
max = np.max
MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
return MAX_EMISS
def using_max():
MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
return MAX_EMISS
def using_loop():
result = collections.defaultdict(list)
for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
result[idx].append(val)
return [max(result[idx]) for idx in UNIQ_IDS]
def using_sort():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
start = 0
end = 0
out = np.empty(count.shape[0])
for ind, x in np.ndenumerate(count):
end += x
out[ind] = np.max(np.take(emiss_data, vals[start:end]))
start += x
return out
def using_split():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
return [np.take(emiss_data, item).max()
for item in np.split(vals, count.cumsum())[:-1]]
for func in (using_index, using_max, using_loop, using_sort, using_split):
assert using_where() == func()
Oto benchmarki, z shape = (2600,5200)
:
In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop
In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop
In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop
In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop
In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop
In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop
Zatem using_loop
(czysty Python) okazuje się być ponad 11x szybciej niż using_where
.
Nie jestem do końca pewien, dlaczego czysty Python jest szybszy niż NumPy. Domyślam się, że zamki w czystej wersji Pythona (tak, kalambur przeznaczone) przez jeden raz. Wykorzystuje fakt, że pomimo wszystkich wymyślnych indeksów, , naprawdę chcemy po prostu odwiedzić każdą wartość raz. W ten sposób rozwiązuje problem z koniecznością dokładnego określenia, do której grupy przypada każda wartość w emiss_data
. Jest to jednak tylko niejasna spekulacja. Nie wiedziałem, że to będzie szybsze, dopóki nie będę porównywać.
Czy jesteś obliczanie "UNIQ_IDS" w tym skrypcie lub jest to z góry określone? – Daniel
UNIQ_IDS jest z góry ustaloną ... listą ints of len = 800.To tylko fragment kodu, przepraszam za zamieszanie. –