2016-06-04 15 views
7

Mam numpy tablicy rozmiar (4, X, Y), gdzie pierwszy wymiar oznacza (R, G, B, A) czworak. Moim celem jest przetransponowanie każdego kwadratu RGBA o wartości X*Y na wartości zmiennoprzecinkowe X*Y, z uwzględnieniem odpowiadającego im słownika.Poprawianie wydajności operacji numpy mapowania

Mój bieżący kod wygląda następująco:

codeTable = { 
    (255, 255, 255, 127): 5.5, 
    (128, 128, 128, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

Gdzie data jest do tablicy numpy wielkości (4, rows, cols) i new_data jest od wielkości (rows, cols).

Kod działa poprawnie, ale zajmuje dość dużo czasu. Jak zoptymalizować ten fragment kodu?

Oto pełny przykład:

import numpy 

codeTable = { 
    (253, 254, 255, 127): 5.5, 
    (128, 129, 130, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

# test data 
rows = 2 
cols = 2 
data = numpy.array([ 
    [[253, 0], [128, 0], [128, 0]], 
    [[254, 0], [129, 144], [129, 0]], 
    [[255, 0], [130, 243], [130, 5]], 
    [[127, 0], [255, 120], [255, 5]], 
]) 

new_data = numpy.zeros((rows,cols), numpy.float32) 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

# expected result for `new_data`: 
# array([[ 5.50000000e+00, 7.50000000e+00], 
#  [ 6.50000000e+00, -9.99900000e+03], 
#  [ 6.50000000e+00, -9.99900000e+03], dtype=float32) 
+0

Jak jest wiele 'wierszy' i' cols'? – Will

+0

@Will Wiele tysięcy dla każdego. –

+0

Może to pomoże: http://stackoverflow.com/questions/36480358/whats-a-fast-non-loop-way-to-apply-a-dict-to-a-ndarray-meaning-use-elements – hpaulj

Odpowiedz

1

Oto podejście, które zwraca swoją oczekiwanego rezultatu, ale z tak małej ilości danych trudno wiedzieć, czy to będzie szybciej dla Ciebie. Ponieważ jednak uniknąłem podwójnej pętli, wyobrażam sobie, że zobaczysz całkiem przyzwoite przyspieszenie.

import numpy 
import pandas as pd 


codeTable = { 
    (253, 254, 255, 127): 5.5, 
    (128, 129, 130, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

# test data 
rows = 3 
cols = 2 
data = numpy.array([ 
    [[253, 0], [128, 0], [128, 0]], 
    [[254, 0], [129, 144], [129, 0]], 
    [[255, 0], [130, 243], [130, 5]], 
    [[127, 0], [255, 120], [255, 5]], 
]) 

new_data = numpy.zeros((rows,cols), numpy.float32) 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

def create_output(data): 
    # Reshape your two data sources to be a bit more sane 
    reshaped_data = data.reshape((4, -1)) 
    df = pd.DataFrame(reshaped_data).T 

    reshaped_codeTable = [] 
    for key in codeTable.keys(): 
     reshaped = list(key) + [codeTable[key]] 
     reshaped_codeTable.append(reshaped) 
    ct = pd.DataFrame(reshaped_codeTable) 

    # Merge on the data, replace missing merges with -9999 
    result = df.merge(ct, how='left') 
    newest_data = result[4].fillna(-9999) 

    # Reshape 
    output = newest_data.reshape(rows, cols) 
    return output 

output = create_output(data) 
print(output) 
# array([[ 5.50000000e+00, 7.50000000e+00], 
#  [ 6.50000000e+00, -9.99900000e+03], 
#  [ 6.50000000e+00, -9.99900000e+03]) 

print(numpy.array_equal(new_data, output)) 
# True 
+0

Twoje rozwiązanie wydaje się działać tylko dla kwadratowych danych wejściowych i nie działa, gdy 'cols! = Rows'. Ale dzięki za pomysły, zbadam sprawę. W każdym razie prędkość jest o wiele bardziej satysfakcjonująca niż moje naiwne rozwiązanie z podwójną pętlą. –

+0

Naprawiono! Spowoduje to teraz pobranie wymaganej liczby wierszy i kolumn. –

+0

Cóż, twój kod nie działa dla innych kształtów danych. Zaktualizowałem swoją pierwszą wiadomość bardziej skomplikowanym przykładem. Twój kod zwraca poprawne wyniki, ale w niewłaściwej pozycji w tablicy wyjściowej. –

1

Pakiet numpy_indexed (disclaimer: Jestem jego autorem) zawiera vectorized ND-array zdolny wariant list.index, które mogą być używane do rozwiązywania problemu sprawnie i zwięźle:

import numpy_indexed as npi 
map_keys = np.array(list(codeTable.keys())) 
map_values = np.array(list(codeTable.values())) 
indices = npi.indices(map_keys, data.reshape(4, -1).T, missing='mask') 
remapped = np.where(indices.mask, -9999, map_values[indices.data]).reshape(data.shape[1:]) 
+0

Twoje rozwiązanie działa jak czar. Dzięki! Później omówię ulepszenia wydajności. –

+0

Czekamy na porównanie wydajności! –

Powiązane problemy