sum(x.*y',2)
to czyste, krótkie rozwiązanie.
Ma również dobre właściwości prędkości i pamięci. Sztuką jest widok mnożenia macierzy-wektora jako liniowej kombinacji kolumn macierzy skalowanych przez elementy wektorowe. Zamiast wykonywać każdą liniową kombinację dla macierzy x [:,:, i], używamy tej samej skali y [i] dla x [:, i ,:]. W kodzie:
const x = rand(6,6,2^10);
const y = rand(6,1);
function tst(x,y)
z = zeros(6,1,2^10)
for i in 1:2^10
z[:,:,i] = x[:,:,i]*y
end
return z
end
tst2(x,y) = mapslices(i->i*y,x,(1,2))
tst3(x,y) = sum(x.*y',2)
Benchmarking daje:
julia> using BenchmarkTools
julia> z = tst(x,y); z2 = tst2(x,y); z3 = tst3(x,y);
julia> @benchmark tst(x,y)
BenchmarkTools.Trial:
memory estimate: 688.11 KiB
allocs estimate: 8196
--------------
median time: 759.545 μs (0.00% GC)
samples: 6068
julia> @benchmark tst2(x,y)
BenchmarkTools.Trial:
memory estimate: 426.81 KiB
allocs estimate: 10798
--------------
median time: 1.634 ms (0.00% GC)
samples: 2869
julia> @benchmark tst3(x,y)
BenchmarkTools.Trial:
memory estimate: 336.41 KiB
allocs estimate: 12
--------------
median time: 114.060 μs (0.00% GC)
samples: 10000
Więc tst3
użyciu sum
ma lepszą wydajność (~ 7x nad tst
i ~ 15x ponad tst2
).
Używanie StaticArrays
zgodnie z sugestią @DNF jest również opcją i dobrze byłoby porównać je z rozwiązaniami tutaj.
Jak działa "_" w zrozumieniu? –
To po prostu sztuczna zmienna. Mógłbym użyć na przykład 'i', ale często można powiedzieć, że' _' oznacza zmienną jednorazową, która nie jest używana dalej i która nazwa jest nieważna. – DNF