Hack 使用硬件取 128 位数字的平方根

计算科学 浮点
2021-12-01 02:14:27

我需要取整数平方根n(大量)128 位数字n. 调用 gmp 似乎花费了惊人的时间(虽然我不能确定,因为 gmp 例程没有显示在分析器信息中)。

有没有标准的计算方法n在这种情况下使用硬件(尽可能)?

我能想到的最好的方法是通过https://hal.inria.fr/inria-00072854/en/中的算法的单一应用(即该算法减去递归)将问题减少到 64 位; 它基本上只是一个牛顿迭代)然后计算n在遵循 80 位扩展精度标准sqrtl的系统(例如,在 64 位 x86 系统上运行的 gcc)中使用该函数。long double或者sqrtl不保证给出正确的舍入值n,或者至少是一个有效的舍入值(即,n要么n+1)?

2个回答

我建议从测量 的gmp平方根的执行时间开始,以建立基线性能。gmp世界上一些最重要的算法和性能优化专家为. 这是一个成熟的库做出了贡献。一般来说,我会对试图与这样的图书馆竞争持谨慎态度。事实上,OP 链接的论文中描述的 Paul Zimmermann 的 Karatsuba 平方根算法正是gmp根据其在线文档的平方根实现所使用的算法。此外,该算法已被证明是正确的:

Yves Bertot、Nicolas Magaud 和 Paul Zimmermann,“GMP 平方根的证明”,自动推理杂志,卷。29,第 3-4 期,2002 年 9 月,第 225-252 页(在线稿件

我自己没有gmp安装来计算平方根。鉴于其作为任意精度库的一般性质,gmp相对较短的操作数可能会产生大量开销。x86_64 硬件上超过几百个周期的执行时间将表明这一点,并提供尝试自制版本的动力。

由于 OP 专门要求“黑客”,我做了一个快速的实验黑客代码,用于先前关于 32 位整数平方根到 128 位版本的答案的答案。在我的英特尔至强 E3 1270v2(常春藤桥)上,使用英特尔 C 编译器版本 13.1.3.198 完全优化编译的代码,我测量了大约 440 个周期的执行时间。由于处理器和编译器在过去 7 年中都取得了进步,这无疑应该提供可实现执行时间的上限

在工具链提供原生 128 位整数类型的情况下,应该使用它来代替我的 128 位仿真。此外,可以使用更好的起始近似值,减少完全准确所需的牛顿迭代次数。x86_64就错过的微优化而言,我在用于访问指令mul内联汇编方面遇到了麻烦,div因为我无法获取"a""d"绑定来将操作数直接放入那些使用的raxrdx寄存器中。

在标准数学库实践中,平方根通常通过倒数平方根的 Newton-Raphson(二次收敛)或 Halley 迭代(三次收敛)计算,因为这些迭代无需除法,只需要乘法。但是,以一种健壮的方式制作这样的实现将花费更多的时间,而不是我目前可以负担得起的投资答案。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

#define MODE_INCR   (1)
#define MODE_DECR   (2)
#define MODE_RANDOM (3)

#define TEST_MODE   (MODE_RANDOM)
#define BENCHMARK   (0)

typedef struct {
    uint64_t lo;
    uint64_t hi;
} my_uint128_t;

/* lookup table for low-accuracy sqrt approximation */
const uint64_t sqrt_tab[32] = 
{ 0x0000000000000000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL, 
  0x0000000000000000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
  0x85ffffffffffffffULL, 0x8cffffffffffffffULL, 0x94ffffffffffffffULL, 0x9affffffffffffffULL, 
  0xa1ffffffffffffffULL, 0xa7ffffffffffffffULL, 0xadffffffffffffffULL, 0xb3ffffffffffffffULL,
  0xb9ffffffffffffffULL, 0xbeffffffffffffffULL, 0xc4ffffffffffffffULL, 0xc9ffffffffffffffULL, 
  0xceffffffffffffffULL, 0xd3ffffffffffffffULL, 0xd8ffffffffffffffULL, 0xdcffffffffffffffULL, 
  0xe1ffffffffffffffULL, 0xe6ffffffffffffffULL, 0xeaffffffffffffffULL, 0xeeffffffffffffffULL, 
  0xf3ffffffffffffffULL, 0xf7ffffffffffffffULL, 0xfbffffffffffffffULL, 0xffffffffffffffffULL
};

uint32_t clz_128 (my_uint128_t x);
my_uint128_t shl_128 (my_uint128_t x, uint32_t shift);
uint64_t udiv_128_64 (my_uint128_t x, uint64_t y);
my_uint128_t umul_64_128 (uint64_t a, uint64_t b);
int gt_128 (my_uint128_t x, my_uint128_t y);
uint64_t avg_64 (uint64_t a, uint64_t b);

/* compute integer square root by initial table lookup refined by division-based
   Newton iterations 
*/ 
uint64_t isqrt_128 (my_uint128_t x)
{
    my_uint128_t t;
    uint64_t q, y;
    uint32_t lz, i;

    if ((x.hi | x.lo) == 0) return x.lo; // early out

    // initial guess based on leading 5 bits of normalized argument
    lz = clz_128 (x);
    t = shl_128 (x, (lz & ~1));
    i = t.hi >> (64 - 5);
    y = sqrt_tab[i] >> (lz >> 1);

    // 1st Newton iteration
    q = 0xffffffffffffffffULL;
    if (x.hi < y) q = udiv_128_64 (x, y);
    y = avg_64 (y, q);

    if (lz < 106) { // 2nd Newton iteration
        q = 0xffffffffffffffffULL;
        if (x.hi < y) q = udiv_128_64 (x, y);
        y = avg_64 (y, q);

        if (lz < 86) { // 3rd Newton iteration
            q = 0xffffffffffffffffULL; 
            if (x.hi < y) q = udiv_128_64 (x, y);
            y = avg_64 (y, q);

            if (lz < 42) { // 4th Newton iteration
                q = 0xffffffffffffffffULL; 
                if (x.hi < y) q = udiv_128_64 (x, y);
                y = avg_64 (y, q);
            }
        }
    }

    if (gt_128 (umul_64_128 (y, y), x)) y--; // adjust quotient if too large

    return y; // (int)sqrt(x)
}

uint32_t clz_128 (my_uint128_t x)
{
    int r = 0;
    if (!(x.lo | x.hi)) return 128;
    if (!(x.hi & 0xffffffffffffffffULL)) { x.hi = x.lo; r += 64; }
    if (!(x.hi & 0xffffffff00000000ULL)) { x.hi <<= 32; r += 32; }
    if (!(x.hi & 0xffff000000000000ULL)) { x.hi <<= 16; r += 16; }
    if (!(x.hi & 0xff00000000000000ULL)) { x.hi <<=  8; r +=  8; }
    if (!(x.hi & 0xf000000000000000ULL)) { x.hi <<=  4; r +=  4; }
    if (!(x.hi & 0xc000000000000000ULL)) { x.hi <<=  2; r +=  2; }
    if (!(x.hi & 0x8000000000000000ULL)) { x.hi <<=  1; r +=  1; }
    return r;
}

my_uint128_t shl_128 (my_uint128_t x, uint32_t shift)
{
    my_uint128_t r;
    if (shift > 127) {
        r.lo = 0ULL;
        r.hi = 0ULL;
    } else if (shift > 63) {
        r.lo = 0ULL;
        r.hi = x.lo << (shift - 64);;
    } else if (shift > 0) {
        r.lo = x.lo << shift;
        r.hi = (x.hi << shift) | (x.lo >> (64 - shift));
    } else {
        r = x;
    }
    return r;
}

/* 128/64->64 bit division. Note: Will overflow if x[127:64] >= y */
uint64_t udiv_128_64 (my_uint128_t x, uint64_t y)
{
    uint64_t quot, rem;
    __asm__ (
        "movq %2, %%rax\n\t"
        "movq %3, %%rdx\n\t"
        "divq %4\n\t"
        "movq %%rax, %0\n\t"
        "movq %%rdx, %1\n\t"
        : "=r" (quot), "=r" (rem)
        : "r" (x.lo), "r" (x.hi), "r" (y)
        : "rax", "rdx");
    return quot;
}

/* 64x64->128 bit multiply */
my_uint128_t umul_64_128 (uint64_t a, uint64_t b)
{
    my_uint128_t r;
    __asm__(
        "movq %2, %%rax\n\t"
        "mulq %3\n\t"
        "movq %%rax, %0\n\t"
        "movq %%rdx, %1\n\t"
        : "=r" (r.lo), "=r" (r.hi) 
        : "r" (a), "r" (b) 
        : "rax", "rdx");
    return r;
}

/* macros for multi-word arithmetic */
#define ADDCcc(a,b,cy,t0,t1) (t0=(b)+cy, t1=(a), cy=t0<cy, t0=t0+t1, t1=t0<t1, cy=cy+t1, t0=t0)
#define ADDcc(a,b,cy,t0,t1) (t0=(b), t1=(a), t0=t0+t1, cy=t0<t1, t0=t0)
#define ADDC(a,b,cy,t0,t1) (t0=(b)+cy, t1=(a), t0+t1)
#define SUBCcc(a,b,cy,t0,t1,t2) (t0=(b)+cy, t1=(a), cy=t0<cy, t2=t1<t0, cy=cy+t2, t1-t0)
#define SUBcc(a,b,cy,t0,t1) (t0=(b), t1=(a), cy=t1<t0, t1-t0)
#define SUBC(a,b,cy,t0,t1) (t0=(b)+cy, t1=(a), t1-t0)

my_uint128_t add_128 (my_uint128_t x, my_uint128_t y)
{
    my_uint128_t r;
    uint64_t cy, t0, t1;
    r.lo = ADDcc (x.lo, y.lo, cy, t0, t1);
    r.hi = ADDC (x.hi, y.hi, cy, t0, t1);
    return r;
}

my_uint128_t sub_128 (my_uint128_t x, my_uint128_t y)
{
    my_uint128_t r;
    uint64_t cy, t0, t1;
    r.lo = SUBcc (x.lo, y.lo, cy, t0, t1);
    r.hi = SUBC (x.hi, y.hi, cy, t0, t1);
    return r;   
}

int gt_128 (my_uint128_t x, my_uint128_t y)
{
    return (x.hi == y.hi) ? (x.lo > y.lo) : (x.hi > y.hi);
}

int lt_128 (my_uint128_t x, my_uint128_t y)
{
    return (x.hi == y.hi) ? (x.lo < y.lo) : (x.hi < y.hi);
}

int eq_128 (my_uint128_t x, my_uint128_t y)
{
    return (x.hi == y.hi) && (x.lo == y.lo);
}

int ge_128 (my_uint128_t x, my_uint128_t y)
{
    return gt_128 (x, y) || eq_128 (x, y);
}

int le_128 (my_uint128_t x, my_uint128_t y)
{
    return lt_128 (x, y) || eq_128 (x, y);
}

/* compute average of a and b rounded towards zero, preventing overflow */
uint64_t avg_64 (uint64_t a, uint64_t b)
{
    return (a & b) + ((a ^ b) >> 1);
}

/*
  https://groups.google.com/forum/#!original/comp.lang.c/qFv18ql_WlU/IK8KGZZFJx4J
  From: geo <gmars...@gmail.com>
  Newsgroups: sci.math,comp.lang.c,comp.lang.fortran
  Subject: 64-bit KISS RNGs
  Date: Sat, 28 Feb 2009 04:30:48 -0800 (PST)

  This 64-bit KISS RNG has three components, each nearly
  good enough to serve alone.    The components are:
  Multiply-With-Carry (MWC), period (2^121+2^63-1)
  Xorshift (XSH), period 2^64-1
  Congruential (CNG), period 2^64
*/

static uint64_t kiss64_x = 1234567890987654321ULL;
static uint64_t kiss64_c = 123456123456123456ULL;
static uint64_t kiss64_y = 362436362436362436ULL;
static uint64_t kiss64_z = 1066149217761810ULL;
static uint64_t kiss64_t;

#define MWC64  (kiss64_t = (kiss64_x << 58) + kiss64_c, \
                kiss64_c = (kiss64_x >> 6), kiss64_x += kiss64_t, \
                kiss64_c += (kiss64_x < kiss64_t), kiss64_x)
#define XSH64  (kiss64_y ^= (kiss64_y << 13), kiss64_y ^= (kiss64_y >> 17), \
                kiss64_y ^= (kiss64_y << 43))
#define CNG64  (kiss64_z = 6906969069ULL * kiss64_z + 1234567ULL)
#define KISS64 (MWC64 + XSH64 + CNG64)

static void error_abort (const char * filename, int line, const char *expr)
{
    fprintf (stderr, "\n%s line %d. Assertion failed: %s\n Aborting.\n", 
             filename, line, expr);
    exit (EXIT_FAILURE);
}

#define MY_ASSERT(expr)\
    ((expr)?((void)0):((void)error_abort(__FILE__,__LINE__,#expr)))

#if !BENCHMARK
int main (void)
{
    const my_uint128_t zero = {0ULL, 0ULL}, one = {1ULL, 0ULL};
    my_uint128_t x, arg, t;
    uint64_t res;

#if TEST_MODE == MODE_DECR
    printf ("Test integer square root sequentially; decrementing\n");
#elif TEST_MODE == MODE_INCR
    printf ("Test integer square root sequentially; incrementing\n");
#elif TEST_MODE == MODE_RANDOM
    printf ("Test integer square root using purely random argument\n");
#else
#error unsupported TEST_MODE
#endif // TEST_MODE

    x = zero;
    do {
#if TEST_MODE == MODE_RANDOM
        arg.lo = KISS64;
        arg.hi = KISS64 & 0x7fffffffffffffffULL;
#elif TEST_MODE == MODE_DECR
        arg = sub_128 (zero, x);
#elif TEST_MODE == MODE_INCR
        arg = x;
#endif // TEST_MODE
        res = isqrt_128 (arg);
        /* Check correctness: res * res  must be less than or equal to arg.
           (res + 1) * (res + 1) must be greater than arg.
        */
        t.lo = res;
        t.hi = 0ULL;

        MY_ASSERT ((lt_128 (arg, shl_128 (one, 127)) && 
                    le_128 (umul_64_128 (res, res), arg) &&
                    gt_128 (add_128 (add_128 (add_128 (umul_64_128 (res, res), t), t), one), arg))
                   ||
                   (ge_128 (arg, shl_128 (one, 127)) && 
                    le_128 (umul_64_128 (res, res), arg) && 
                    gt_128 (umul_64_128 (res, res), sub_128 (sub_128 (sub_128 (arg, t), t), one)))
);
        if ((x.lo & 0xffffffULL) == 0) printf ("\r%016llx_%016llx", x.hi, x.lo);
        x = add_128 (x, one);
    } while (!(eq_128 (x, zero)));
    return EXIT_SUCCESS;
}

#else // BENCHMARK

// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

#define N 200000000
int main (void)
{
    my_uint128_t x, zero = {0ULL, 0ULL};
    uint64_t r = 0ULL;
    int i, k;
    double start, stop;

    printf ("Running benchmark (%d iterations), please wait ...\n", N);
    for (k = 0; k < 2; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            x.hi = KISS64;
            x.lo = r ^ x.hi;
            r = isqrt_128 (x);
        }
        stop = second();
    }
    printf ("r=%016llx\n", r);
    printf ("elapsed = %23.16e seconds per isqrt_128\n", (stop-start)/N);
    return EXIT_SUCCESS;        
}

#endif // BENCHMARK

这是我的(周六下午,无论是字面上还是比喻上)代码。没有任何优雅或最佳的主张 - 请拍摄!你会看到 main() 需要109使用我提出的混合 Zimmermann(-Karatsuba(-Newton))/硬件程序的平方根,并使用 gmp 获取它们;它还检查这两种方法在所有测试中是否给出相同的答案 (pseudorandom) 。我的混合程序似乎快了 8.5 倍多一点。

我确信可以比我下面的代码做得更好。哦,是的,当然它是不可移植的(它适用于在 x86-64 上运行的 gcc),但我们暂时不用担心。请注意,我不假设 FPU 是向上、向下还是四舍五入到最接近的可表示数字。

#include <stdlib.h>
#include <stdio.h>
#include <gmpxx.h>
#include <math.h>
#include <time.h>
#include <inttypes.h>

typedef unsigned __int128 I;

#define LOG2(X) ((unsigned) (8*sizeof(ulong) - 1 - __builtin_clzl((X))))
/* log2 rounded down; works only in gcc */


inline long log2l(I x)
{
 ulong high, low;

 high = x>>64;
 if(high)
   return LOG2(high)+64;
 else {
  low = x&(~((ulong) 0));
  return LOG2(x);
 }
}

ulong sqrt128(I n)
/* returns floor of sqrt(n) */
/* works for 0<=n<2^{125} */
{
 int flag=0, k, be;
 ulong head, q, a1, num, den, s,sp,sp2,rp;
 I r;

 k = log2l(n); be=k/4+1;
 if((k%4)<2) {
   flag = 1;
   n <<= 2;
 }
 head = n>>(2*be);
 a1 = n - (head<<(2*be));
 a1 >>= be;

 sp = sqrtl(head); sp2 = sp*sp;
 if(head>sp2)
   rp = head-sp2;
 else {
   rp=(head+2*sp-1)-sp2;
   sp--;
 }

 /* So: sp = integer part of sqrt(head), rp = head-sp*sp */
 num = ((rp<<be)+a1); den = 2*sp;
 q = num/den;
 s = (sp<<be) + q;

 if(((I) s)*((I) s) > n)
   s--;

 if(flag) 
   s >>= 1;

 return s;
}

const I twop32 = ((I) 4294967296)*((I) 4294967296);
inline mpz_class mpzify(I x)
{
  mpz_class twp32 = ((mpz_class) 4294967296)*((mpz_class) 4294967296);

  return mpz_class((ulong) (x/twop32))*twp32+mpz_class((ulong) (x%twop32));
}

main()
{
  ulong A0,A1,A2,A3,i,j,B;
  I n;
  time_t t0,t1,t2,t3;
  mpz_class Qz;

  t0 = time(NULL);

  srand48(t0);
  for(i=0; i<1000000; i++) {
    A0 = lrand48(); A1 = lrand48();
    A2 = lrand48(); A3 = lrand48()>>3;
    n = (((I) (A2 + (A3<<32)))<<64) + A0 + (A1<<32);
    for(j=0; j<1000; j++)
      B =  sqrt128(n+j);
  }
  t1 = time(NULL);
  printf("Wall time for homebrewed sqrt: %ld\n",t1-t0);

  srand48(t0);
  for(i=0; i<1000000; i++) {
    A0 = lrand48(); A1 = lrand48();
    A2 = lrand48(); A3 = lrand48()>>3;
    n = (((I) (A2 + (A3<<32)))<<64) + A0 + (A1<<32);
    for(j=0; j<1000; j++) {
      Qz =  sqrt(mpzify(n+j));
      B = Qz.get_ui();
  }
  t2 = time(NULL);
  printf("Wall time for gmp sqrt: %ld\n",t2-t1);

  srand48(t0);
  for(i=0; i<1000000; i++) {
    A0 = lrand48(); A1 = lrand48();
    A2 = lrand48(); A3 = lrand48()>>3;
    n = (((I) (A2 + (A3<<32)))<<64) + A0 + (A1<<32);
    for(j=0; j<1000; j++)
      B=A0;
  }
  t3=time(NULL);
  printf("Wall time just for looping: %ld\n",t3-t2);

  printf("Checking that the two algorithms always give the same result...\n");
  srand48(t0);
  for(i=0; i<1000000; i++) {
    A0 = lrand48(); A1 = lrand48();
    A2 = lrand48(); A3 = lrand48()>>3;
    n = (((I) (A2 + (A3<<32)))<<64) + A0 + (A1<<32);
    for(j=0; j<1000; j++) {
      Qz =  sqrt(mpzify(n+j));
      B = Qz.get_ui();
      if(B!=sqrt128(n+j))
        printf("Unfortunately not! %lu %lu\n",B,sqrt128(n+j));
    }
  }
  printf("Done.\n");
}