2013-08-27 12 views
10

Aby przyspieszyć moje dzielenie bitu, muszę przyspieszyć operację y = x^2 dla dużych obiektów, które są reprezentowane jako dynamiczne tablice niepodpisanych DWORD. Żeby było jasne:Szybkie obliczanie kwadratu binarnego

DWORD x[n+1] = { LSW, ......, MSW }; 
  • gdzie n + 1 jest liczba użytych DWORDs
  • więc wartość liczby x = x[0]+x[1]<<32 + ... x[N]<<32*(n)

Pytanie brzmi: Jak obliczyć y = x^2 jak najszybciej bez precyzji strat? - Do dyspozycji są: arytmetyka liczb całkowitych (32bit z Carry). Używanie C++.

Moje obecne podejście polega na zastosowaniu mnożenia y = x*x i uniknięcia wielokrotnych multiplikacji.

Na przykład:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n) 

Dla uproszczenia, niech mi przepisać go:

x = x0+ x1 + x2 + ... + xn 

gdzie indeks reprezentować adres wewnątrz tablicy, więc:

y = x*x 
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn) 
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn) 

y0  = x0*x0 
y1  = x1*x0 + x0*x1 
y2  = x2*x0 + x1*x1 + x0*x2 
y3  = x3*x0 + x2*x1 + x1*x2 
... 
y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2) 
y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1) 
y(2n-1) = xn(n )*x(n ) 

Po bliższe Wygląda na to, że prawie wszystkie xi*xj pojawia się dwa razy (nie pierwszy i ostatni), co oznacza to, że mnożniki N*N można zastąpić multiplikacjami (N+1)*(N/2). P.S. 32bit*32bit = 64bit, więc wynik każdej operacji mul+add jest obsługiwany jako 64+1 bit.

Czy istnieje lepszy sposób obliczenia tego szybko? Wszystkie znalazłem podczas poszukiwań były sqrts algorytmy, nie SQR ...

Szybka sqr

!!! Uważaj, że wszystkie liczby w moim kodzie są najpierw MSW, ... nie jak w powyższym teście (najpierw są LSW dla prostoty równań, w przeciwnym razie byłby to bałagan indeksowy).

Aktualny fsqr funkcjonalne wdrożenie

void arbnum::sqr(const arbnum &x) 
{ 
    // O((N+1)*N/2) 
    arbnum c; 
    DWORD h, l; 
    int N, nx, nc, i, i0, i1, k; 
    c._alloc(x.siz + x.siz + 1); 
    nx = x.siz - 1; 
    nc = c.siz - 1; 
    N = nx + nx; 
    for (i=0; i<=nc; i++) 
     c.dat[i]=0; 
    for (i=1; i<N; i++) 
     for (i0=0; (i0<=nx) && (i0<=i); i0++) 
     { 
      i1 = i - i0; 
      if (i0 >= i1) 
       break; 
      if (i1 > nx) 
       continue; 
      h = x.dat[nx-i0]; 
      if (!h) 
       continue; 
      l = x.dat[nx-i1]; 
      if (!l) 
       continue; 
      alu.mul(h, l, h, l); 
      k = nc - i; 
      if (k >= 0) 
       alu.add(c.dat[k], c.dat[k], l); 
      k--; 
      if (k>=0) 
       alu.adc(c.dat[k], c.dat[k],h); 
      k--; 
      for (; (alu.cy) && (k>=0); k--) 
       alu.inc(c.dat[k]); 
     } 
     c.shl(1); 
     for (i = 0; i <= N; i += 2) 
     { 
      i0 = i>>1; 
      h = x.dat[nx-i0]; 
      if (!h) 
       continue; 
      alu.mul(h, l, h, h); 
      k = nc - i; 
      if (k >= 0) 
       alu.add(c.dat[k], c.dat[k],l); 
      k--; 
      if (k>=0) 
       alu.adc(c.dat[k], c.dat[k], h); 
      k--; 
      for (; (alu.cy) && (k >= 0); k--) 
       alu.inc(c.dat[k]); 
     } 
     c.bits = c.siz<<5; 
     c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1; 
     c.sig = sig; 
     *this = c; 
    } 

Wykorzystanie Karatsuba mnożenia

(dzięki Calpis)

I wdrożone Karatsuba mnożenia ale wyniki są znacznie wolniejsze nawet niż przy wykorzystaniu prostych O(N^2) mnożenie, prawdopodobnie z powodu tej strasznej rekurencji, której nie widzę w żaden sposób. Kompromis musi mieć naprawdę duże liczby (większe niż setki cyfr) ... ale nawet wtedy istnieje wiele transferów pamięci. Czy istnieje sposób na uniknięcie wywołań rekurencyjnych (wariant nierekursywny, ... Prawie wszystkie algorytmy rekursywne mogą być wykonane w ten sposób). Mimo to postaram się poprawić i zobaczyć, co się stanie (uniknąć normalizacji, itp.), Może to być również głupi błąd w kodzie). W każdym razie, po rozwiązaniu Karatsuba dla przypadku x*x, nie ma zbyt dużego przyrostu wydajności.

Zoptymalizowany Karatsuba mnożenie

test wydajności dla y = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits:

x = 0.00000009876... | 98*32 bits 
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr 
mul1[ 363.472 ms ] ... O(N^2) classic multiplication 
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication 
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication 

x = 0.00... | 195*32 bits 
sqr [ 883.01 ms ] 
mul1[ 1427.02 ms ] 
mul2[ 1089.84 ms ] 

x = 0.00... | 389*32 bits 
sqr [ 3189.19 ms ] 
mul1[ 5553.23 ms ] 
mul2[ 3159.07 ms ] 

Po optymalizacje dla Karatsuba, kod jest znacznie szybciej niż wcześniej. Jednak dla mniejszych liczb jest to nieco mniej niż połowa prędkości mojego mnożenia O(N^2). Dla większych liczb jest on szybszy ze względu na stopień skomplikowania multipleksacji Bootha. Próg dla mnożenia wynosi około 32 * 98 bitów, a dla sqr około 32 * 389 bitów, więc jeśli suma bitów wejściowych przekroczy ten próg, to mnożenie Karatsuba będzie używane do przyspieszenia mnożenia i to samo będzie podobne dla sqr.

okazji, optymalizacji zawiera:

  • zminimalizować sterty koszu przez zbyt duże rekursji argumentu
  • uniknięcia jakichkolwiek aritmetics bignum (+, -) 32-bitowy aluminium z C jest używany.
  • Pomijając 0*y lub x*0 lub przypadki
  • Ponowne wejściowego x,y liczbę rozmiarów do potęgi dwójki, aby uniknąć ponownego przydzielania
  • Wdrożenie modulo mnożenia dla z1 = (x0 + x1)*(y0 + y1) zminimalizować rekursji

Modified Schönhage-Strassen rozmnożenia SQR implementacja

Mam przetestowane użycie FFT i NTT przekształca się, aby przyspieszyć obliczenia sqr. Wyniki są następujące:

  1. FFT

    utratą dokładności i dlatego potrzebujemy wysokiej precyzji liczb zespolonych. To faktycznie znacznie spowalnia działanie, więc nie występuje żadne przyspieszenie. W rezultacie nie jest jasny (niesłusznie mogą być zaokrąglone), tak FFT nadaje się do użytku (do tej pory)

  2. NTT

    NTT jest skończony zakres DFT a więc nie występuje utrata dokładności. Potrzebuje modułowej arytmetyki dla liczb całkowitych bez znaku: modpow, modmul, modadd i modsub.

    Używam DWORD (32-bitowe liczby całkowite bez znaku). Rozmiar wektora wejściowego/otput wektora NTT jest ograniczony z powodu problemów z przepełnieniem!32-bitowe arytmetyki modułowych N jest ograniczony do (2^32)/(max(input[])^2) tak bigint musi być podzielona na mniejsze kawałki (użyć BYTES więc maksymalna wielkość bigint przetwarzane jest

    (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs) 
    

    sqr wykorzystuje tylko 1xNTT + 1xINTT zamiast 2xNTT + 1xINTT mnożenia ale NTT użycie jest zbyt wolne, a rozmiar progu jest zbyt duży, aby mógł być wykorzystany w praktyce (dla mul, a także dla sqr) .

    Jest to możliwe Jest to nawet ponad limitem przelewu, więc należy zastosować 64-bitową modułową arytmetykę, która może jeszcze bardziej spowolnić proces. Więc NTT jest również dla moich celów nie do użytku.

Niektóre pomiary:

a = 0.00 | 389*32 bits 
looped 1x times 
sqr1[ 3.177 ms ] fast sqr 
sqr2[ 720.419 ms ] NTT sqr 
mul1[ 5.588 ms ] simpe mul 
mul2[ 3.172 ms ] karatsuba mul 
mul3[ 1053.382 ms ] NTT mul 

Moje wykonanie:

void arbnum::sqr_NTT(const arbnum &x) 
{ 
    // O(N*log(N)*(log(log(N)))) - 1x NTT 
    // Schönhage-Strassen sqr 
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!! 
    int i, j, k, n; 
    int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2; 
    i = x.siz; 
    for (n = 1; n < i; n<<=1) 
     ; 
    if (n + n > 0x3000) { 
     _error(_arbnum_error_TooBigNumber); 
     zero(); 
     return; 
    } 
    n <<= 3; 
    DWORD *xx, *yy, q, qq; 
    xx = new DWORD[n+n]; 
    #ifdef _mmap_h 
    if (xx) 
     mmap_new(xx, (n+n) << 2); 
    #endif 
    if (xx==NULL) { 
     _error(_arbnum_error_NotEnoughMemory); 
     zero(); 
     return; 
    } 
    yy = xx + n; 

    // Zero padding (and split DWORDs to BYTEs) 
    for (i--, k=0; i >= 0; i--) 
    { 
     q = x.dat[i]; 
     xx[k] = q&0xFF; k++; q>>=8; 
     xx[k] = q&0xFF; k++; q>>=8; 
     xx[k] = q&0xFF; k++; q>>=8; 
     xx[k] = q&0xFF; k++; 
    } 
    for (;k<n;k++) 
     xx[k] = 0; 

    //NTT 
    fourier_NTT ntt; 

    ntt.NTT(yy,xx,n); // init NTT for n 

    // Convolution 
    for (i=0; i<n; i++) 
     yy[i] = modmul(yy[i], yy[i], ntt.p); 

    //INTT 
    ntt.INTT(xx, yy); 

    //suma 
    q=0; 
    for (i = 0, j = 0; i<n; i++) { 
     qq = xx[i]; 
     q += qq&0xFF; 
     yy[n-i-1] = q&0xFF; 
     q>>=8; 
     qq>>=8; 
     q+=qq; 
    } 

    // Merge WORDs to DWORDs and copy them to result 
    _alloc(n>>2); 
    for (i = 0, j = 0; i<siz; i++) 
    { 
     q =(yy[j]<<24)&0xFF000000; j++; 
     q |=(yy[j]<<16)&0x00FF0000; j++; 
     q |=(yy[j]<< 8)&0x0000FF00; j++; 
     q |=(yy[j] )&0x000000FF; j++; 
     dat[i] = q; 
    } 

    #ifdef _mmap_h 
    if (xx) 
     mmap_del(xx); 
    #endif 
    delete xx; 
    bits = siz<<5; 
    sig = s; 
    exp = exp0 + (siz<<5) - 1; 
     // _normalize(); 
    } 

Wnioski

Dla mniejszych zdrętwiałe ers, jest to najlepsza opcja mojego szybkiego podejścia sqr, a po próg Mnożenie Karatsuba jest lepsze. Ale nadal uważam, że powinno być coś trywialnego, co przeoczyliśmy. Ma jakieś inne pomysły?

optymalizacja NTT

Po masywnie intensywnych optymalizacje (głównie NTT): przepełnienie stosu pytanie Modular arithmetics and NTT (finite field DFT) optimizations.

Niektóre wartości uległy zmianie:

a = 0.00 | 1553*32bits 
looped 10x times 
mul2[ 28.585 ms ] Karatsuba mul 
mul3[ 26.311 ms ] NTT mul 

Więc teraz NTT mnożenie jest wreszcie szybciej niż Karatsuba po progu około 1500 * 32-bitowym.

Niektóre pomiary i błąd zauważony

a = 0.99991970486 | 1553*32 bits 
looped: 10x 
sqr1[ 58.656 ms ] fast sqr 
sqr2[ 13.447 ms ] NTT sqr 
mul1[ 102.563 ms ] simpe mul 
mul2[ 28.916 ms ] Karatsuba mul Error 
mul3[ 19.470 ms ] NTT mul 

I okazało się, że moja Karatsuba (powyżej/poniżej) przepływa LSB każdego DWORD segmencie bignum. Kiedy zbadali, będę aktualizować kod ...

Również po dalszym NTT optymalizacje progi zmianie, więc dla NTT sqr jest 310*32 bits = 9920 bits z argumentu, a dla NTT mul go to 1396*32 bits = 44672 bits z wynik (suma bitów operandów).

kod Karatsuba naprawione dzięki @greybeard

//--------------------------------------------------------------------------- 
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n) 
{ 
    // Recursion for Karatsuba 
    // z[2n] = x[n]*y[n]; 
    // n=2^m 
    int i; 
    for (i=0; i<n; i++) 
     if (x[i]) { 
      i=-1; 
      break; 
     } // x==0 ? 

    if (i < 0) 
     for (i = 0; i<n; i++) 
      if (y[i]) { 
       i = -1; 
       break; 
      } // y==0 ? 

    if (i >= 0) { 
     for (i = 0; i < n + n; i++) 
      z[i]=0; 
      return; 
     } // 0.? = 0 

    if (n == 1) { 
     alu.mul(z[0], z[1], x[0], y[0]); 
     return; 
    } 

    if (n< 1) 
     return; 
    int n2 = n>>1; 
    _mul_karatsuba(z+n, x+n2, y+n2, n2);       // z0 = x0.y0 
    _mul_karatsuba(z , x , y , n2);       // z2 = x1.y1 
    DWORD *q = new DWORD[n<<1], *q0, *q1, *qq; 
    BYTE cx,cy; 
    if (q == NULL) { 
     _error(_arbnum_error_NotEnoughMemory); 
     return; 
    } 
    #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0] 
    #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0] 
    qq = q; 
    q0 = x + n2; 
    q1 = x; 
    i = n2 - 1; 
    _add; 
    cx = alu.cy; // =x0+x1 

    qq = q + n2; 
    q0 = y + n2; 
    q1 = y; 
    i = n2 - 1; 
    _add; 
    cy = alu.cy; // =y0+y1 

    _mul_karatsuba(q + n, q + n2, q, n2);      // =(x0+x1)(y0+y1) mod ((2^N)-1) 

    if (cx) { 
     qq = q + n; 
     q0 = qq; 
     q1 = q + n2; 
     i = n2 - 1; 
     _add; 
     cx = alu.cy; 
    }// += cx*(y0 + y1) << n2 

    if (cy) { 
     qq = q + n; 
     q0 = qq; 
     q1 = q; 
     i = n2 -1; 
     _add; 
     cy = alu.cy; 
    }// +=cy*(x0+x1)<<n2 

    qq = q + n; q0 = qq; q1 = z + n; i = n - 1; _sub; // -=z0 
    qq = q + n; q0 = qq; q1 = z;  i = n - 1; _sub; // -=z2 
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1=(x0+x1)(y0+y1)-z0-z2 

    DWORD ccc=0; 

    if (alu.cy) 
     ccc++; // Handle carry from last operation 
    if (cx || cy) 
     ccc++; // Handle carry from before last operation 
    if (ccc) 
    { 
     i = n2 - 1; 
     alu.add(z[i], z[i], ccc); 
     for (i--; i>=0; i--) 
      if (alu.cy) 
       alu.inc(z[i]); 
      else 
       break; 
    } 

    delete[] q; 
    #undef _add 
    #undef _sub 
    } 

//--------------------------------------------------------------------------- 
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y) 
{ 
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585)) 
    // Karatsuba multiplication 
    // 
    int s = x.sig*y.sig; 
    arbnum a, b; 
    a = x; 
    b = y; 
    a.sig = +1; 
    b.sig = +1; 
    int i, n; 
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1) 
     ; 
    a._realloc(n); 
    b._realloc(n); 
    _alloc(n + n); 
    for (i=0; i < siz; i++) 
     dat[i]=0; 
    _mul_karatsuba(dat, a.dat, b.dat, n); 
    bits = siz << 5; 
    sig = s; 
    exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1; 
    // _normalize(); 
    } 
//--------------------------------------------------------------------------- 

Moje arbnum numer reprezentacji:

// dat is MSDW first ... LSDW last 
DWORD *dat; int siz,exp,sig,bits; 
  • dat[siz] jest mantisa. LSDW oznacza najmniej znaczący DWORD.
  • exp jest wykładnikiem MSB dat[0]
  • Pierwszy niezerowe bit jest obecny w mantysie !!!

    // |-----|---------------------------|---------------|------| 
    // | sig | MSB  mantisa  LSB | exponent | bits | 
    // |-----|---------------------------|---------------|------| 
    // | +1 | 0.(0  ...   0) | 2^0   | 0 | +zero 
    // | -1 | 0.(0  ...   0) | 2^0   | 0 | -zero 
    // |-----|---------------------------|---------------|------| 
    // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp   | n | +number 
    // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp   | n | -number 
    // |-----|---------------------------|---------------|------| 
    // | +1 | 1.0      | 2^+0x7FFFFFFE | 1 | +infinity 
    // | -1 | 1.0      | 2^+0x7FFFFFFE | 1 | -infinity 
    // |-----|---------------------------|---------------|------| 
    
+4

Moje pytanie brzmi: dlaczego zdecydowałeś się wdrożyć własną implementację bignum? [Biblioteka arytmetyczna wielokrotnej precyzji GNU] (http://gmplib.org/) jest prawdopodobnie jedną z najczęściej używanych bibliotek bignum i powinna być całkiem optymalna przy wszystkich operacjach. –

+0

Używam własnych bibliotek bignum ze względu na kompatybilność. Przeniesienie całego kodu do różnych bibliotek jest bardziej czasochłonne, niż mogłoby się wydawać na pierwszy rzut oka (a czasami nawet niemożliwe z powodu niezgodności kompilatora, zwłaszcza z kodem gcc). Obecnie poprawiam wszystko, ... wszystko działa tak, jak powinno, ale zawsze potrzebna jest większa prędkość :) – Spektre

+0

P.S. dla NTT zdecydowanie zalecam, aby NTT było obliczane z 4 razy większą precyzją niż wartości wejściowe (tak dla liczb 8-bitowych trzeba konwertować je na liczby 32-bitowe), aby uzyskać kompromis między maksymalnym rozmiarem macierzy i prędkością – Spektre

Odpowiedz

2

Jeśli rozumiem Twój algorytm poprawnie, wydaje O(n^2) gdzie n jest liczbą cyfr.

Czy obejrzałeś Karatsuba Algorithm? Przyspiesza mnożenie za pomocą metody dziel i podbijaj. Być może warto się przyjrzeć.

+0

miło to przyspiesza rzeczy dużo ... z powodu x = y ... trudno założyć złożoność przed zakodowaniem. – Spektre

+0

z drugiej strony, rozwiązanie karatsuba dla x * x ma taki sam skutek, jak moje podejście :(spróbuję, jeśli bardziej rekurencyjne podejście jest lepsze ... moja złożoność teraz przechodzi od O (n^2) do ~ O (0,5 * N^2), ale według tej strony powinna być niższa – Spektre

+0

OK Sprawdziłem algorytm karatsuba.To jest dobre dla przyspieszenia multiplikacji, ale dla x^2 ma zastosowanie tylko dla naprawdę dużych liczb. Myślę, że powinno być coś prostego i dużo szybciej niż ogólne mnożenie, – Spektre