Tcl Source Code

Artifact [c9ba1da724]
Login

Artifact c9ba1da724d5e2d6fd69bc61d2e7dc5cc3c4d8ba:

Attachment "bn_mp_bz_div.c" to ticket [2814286fff] added by kennykb 2009-06-30 06:17:22.
#include <stdio.h>
#include <tommath.h>
#ifdef BN_MP_BZ_DIV_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis
 *
 * LibTomMath is a library that provides multiple-precision
 * integer arithmetic as well as number theoretic functionality.
 *
 * The library was designed directly after the MPI library by
 * Michael Fromberger but has been written from scratch with
 * additional optimizations in place.
 *
 * The library is free for all purposes without any express
 * guarantee it works.
 *
 * Tom St Denis, [email protected], http://math.libtomcrypt.com
 *
 * This module:  Kevin B. Kenny, [email protected]
 */

/*
#define TRACE_BZ_DIV
#define TRACE_BZ_2N_1N
#define TRACE_BZ_3N_2N
*/

#if defined(TRACE_BZ_DIV) || defined(TRACE_BZ_2N_1N) || defined(TRACE_BZ_3N_2N)
/*
 * traceprint --
 *
 *	Debugging aid, compiled only when one of the TRACE_BZ_* macros is
 *	defined.  Writes "func: var = <value>" to stderr, expressing the
 *	magnitude of val as a sum of powers of two, and wraps negative values
 *	in "-( ... )".  A run of consecutive one bits whose top bit is 2**hi
 *	and whose bottom bit is 2**lo is printed as "2**hi" (one bit),
 *	"2**hi + 2**lo" (two bits) or "2**(hi+1) - 2**lo" (longer runs).
 *	Prints "0" when no bit is set.
 */
static void
traceprint(const char* func, const char* var, const mp_int* val) {
    int i, j, did;
    const char* sep = "";
    fprintf(stderr, "%s: %s = ", func, var);
    if (val->sign == MP_NEG) {
	fprintf(stderr, "-( ");
    }
    did = 0;
    for (i = val->used * DIGIT_BIT - 1; i >= 0; --i) {

	/*
	 * Test bits with an mp_digit-typed mask: when DIGIT_BIT is 32 or
	 * more, shifting the int constant 1 by up to DIGIT_BIT-1 places
	 * would be undefined behavior.
	 */

	if (val->dp[i / DIGIT_BIT] & ((mp_digit) 1 << (i % DIGIT_BIT))) {
	    did = 1;

	    /* Walk down to the lowest bit j of this run of one bits. */

	    for (j = i-1; j >= 0; --j) {
		if ((val->dp[j / DIGIT_BIT]
		     & ((mp_digit) 1 << (j % DIGIT_BIT))) == 0) {
		    break;
		}
	    }
	    ++j;
	    if (j == i) {
		fprintf(stderr, "%s2**%d", sep, i);
	    } else if (j == i - 1) {
		/* Two-bit run: bits i and i-1 */
		fprintf(stderr, "%s2**%d + 2**%d", sep, i, j);
	    } else {
		fprintf(stderr, "%s2**%d - 2**%d", sep, i+1, j);
	    }
	    sep = " + ";

	    /* Resume scanning just below the run (the loop's --i does it). */

	    i = j;
	}
    }
    if (!did) {
	fprintf(stderr, "0");
    }
    if (val->sign == MP_NEG) {
	fprintf(stderr, " )");
    }
    fprintf(stderr, "\n");
    fflush(stderr);
}
#endif

/*
 * Static functions defined in this file: the two mutually recursive
 * Burnikel-Ziegler steps.  Each divides a by b, leaving the quotient in c
 * and the remainder in d (either output pointer may be NULL); n is the
 * digit-count parameter of the step.
 */

static int mp_bz_div_2n_1n(int n, mp_int *a, mp_int *b, mp_int *c, mp_int *d);
static int mp_bz_div_3n_2n(int n, mp_int *a, mp_int *b, mp_int *c, mp_int *d);

/*
 * mp_bz_div --
 *
 *	Division in Karatsuba time: a/b == c remainder d, or a == bc + d.
 *	Either c or d may be NULL when the caller does not need that result.
 *
 *	Signs follow the mp_div convention: a nonzero quotient is negative
 *	when a and b differ in sign, a nonzero remainder carries the sign of
 *	the dividend a, and a zero result is always MP_ZPOS.
 *
 * Results:
 *	MP_OKAY on success; MP_VAL if b is zero; otherwise the error code
 *	returned by the libtommath primitive that failed.
 *
 * Source: Christoph Burnikel, Joachim Ziegler. "Fast Recursive Division."
 *	   Forschungsbericht MPI-I-98-1-022, Max-Planck-Institut fuer
 *	   Informatik, Saarbruecken, Germany (October 1998)
 *	   http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.47.565
 */

int
mp_bz_div(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
{

    int signum;			/* Sign of a nonzero quotient */
    int rsign;			/* Sign of a nonzero remainder (that of a) */
    int j;			/* Number of m-digit blocks holding the
				 * divisor */
    int m;			/* BZ_DIV_CUTOFF * 2**k, the smallest such
				 * value exceeding b->used */
    int n;			/* Number of mp_digits in each step */
    int sigma;			/* Number of bit shifts needed to normalize
				 * the divisor */
    int t;			/* Number of n-digit pieces in the dividend */
    int i;
    mp_int aa, bb, cc, r, q;
    mp_digit dig;
    int res;

    /*
     * Reject division by zero before touching b->dp: with b->used == 0
     * the normalization step below would read b->dp[-1].
     */

    if (b->used == 0) {
	return MP_VAL;
    }

#ifdef TRACE_BZ_DIV
    fprintf(stderr, "mp_bz_div: a->used=%d, b->used=%d\n",
	    a->used, b->used);
    traceprint("mp_bz_div", "a", a);
    traceprint("mp_bz_div", "b", b);
#endif

    /* Algorithm 3. D_{r/s} */

    signum = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
    rsign = a->sign;

    /* Determine smallest width that will accommodate the divisor. */

    m = BZ_DIV_CUTOFF;
    while (m <= b->used) {
	m <<= 1;
    }
    j = (b->used + m - 1) / m;
    n = j * m;

    /*
     * Normalize the divisor in an n-mp_digit integer. Shift the dividend
     * left by the same amount.  The top-bit mask is built from an mp_digit
     * because 1 << (MP_DIGIT_BIT-1) on an int is undefined behavior when
     * MP_DIGIT_BIT is 32 or more.
     */

    sigma = MP_DIGIT_BIT * (n - b->used);
    dig = b->dp[b->used - 1];
    while (dig < ((mp_digit) 1 << (MP_DIGIT_BIT - 1))) {
	++sigma;
	dig <<= 1;
    }

    /* Allocate memory for shifted operands and for results */

    if ((res = mp_init_multi(&aa, &bb, NULL)) != MP_OKAY) {
	return res;
    }

    /*
     * Shift b left to normalize it.  Shift a left by the same amount.
     * Only magnitudes are divided; the result signs were captured above.
     */

    if ((res = mp_mul_2d(a, sigma, &aa)) != MP_OKAY) {
	goto LBL_BB;
    }
    aa.sign = MP_ZPOS;
    if ((res = mp_mul_2d(b, sigma, &bb)) != MP_OKAY) {
	goto LBL_BB;
    }
    bb.sign = MP_ZPOS;

    /*
     * bb is now n mp_digits long, and its most significant bit is 1.
     * Find out how many n-digit pieces there are in aa such that
     * the most significant bit of the most significant piece is *not* 1.
     * aa.used is 0 when a is zero, so check it before reading aa.dp.
     */

    t = (aa.used + n - 1) / n;
    if (aa.used > 0 && (aa.used % n == 0)
	&& ((aa.dp[aa.used-1] & ((mp_digit) 1 << (MP_DIGIT_BIT - 1))) != 0)) {
	++t;
    }
    if (t < 2) {
	t = 2;
    }

#ifdef TRACE_BZ_DIV
    fprintf(stderr, "mp_bz_div: t = %d, n = %d sigma = %d\n", t, n, sigma);
    traceprint("mp_bz_div", "shifted dividend aa", &aa);
    traceprint("mp_bz_div", "shifted divisor bb", &bb);
#endif
    /* Allocate memory for temps */

    if ((res = mp_init_size(&cc, t * n)) != MP_OKAY) {
	goto LBL_BB;
    }
    if ((res = mp_init_size(&r, 2 * n)) != MP_OKAY) {
	goto LBL_CC;
    }
    if ((res = mp_init_size(&q, t * n)) != MP_OKAY) {
	goto LBL_R;
    }

    /*
     * Seed r with the most significant n mp_digits of aa, when aa is
     * viewed as being t*n digits long.  The loop below brings in the next
     * n digits before dividing, so the first division step sees the top
     * 2*n digits.
     */

    for (i = (t-1) * n, j = 0; i < aa.used; ++i, ++j) {
	r.dp[j] = aa.dp[i];
    }
    r.used = j;

    /* Do 'school short division' on 'digits' that are each n mp_digits long. */

    cc.used = 0;
    for (t -= 2; t >= 0; --t) {

	/*
	 * Shift the partial remainder up by n mp_digits.  Copy from the top
	 * down so the move stays correct even if the ranges overlap.
	 */

	if ((res = mp_grow(&r, r.used + n)) != MP_OKAY) {
	    goto LBL_Q;
	}
	for (i = r.used - 1; i >= 0; --i) {
	    r.dp[i + n] = r.dp[i];
	}
	r.used += n;

	/* Bring in the next n digits of aa, zero-padding past its end. */

	for (j = 0, i = t * n; j < n && i < aa.used; ++i, ++j) {
	    r.dp[j] = aa.dp[i];
	}
	while (j < n) {
	    r.dp[j++] = 0;
	}

	/*
	 * If the partial remainder was zero, the top of the new digits may
	 * be zero; libtommath routines expect a clamped operand.
	 */

	mp_clamp(&r);

	/* Perform one division step */

#ifdef TRACE_BZ_DIV
	fprintf(stderr, "mp_bz_div: part %d\n", t);
	traceprint("mp_bz_div", "dividend r", &r);
	traceprint("mp_bz_div", "divisor bb", &bb);
#endif
	if ((res = mp_bz_div_2n_1n(n, &r, &bb, &q, &r)) != MP_OKAY) {
	    goto LBL_Q;
	}
#ifdef TRACE_BZ_DIV
	traceprint("mp_bz_div", "quotient q", &q);
	traceprint("mp_bz_div", "remainder r", &r);
#endif

	/* Accumulate the quotient digits in cc */

	for (j = t * n, i = 0; i < q.used; ++i, ++j) {
	    cc.dp[j] = q.dp[i];
	}
	for (; i < n; ++i, ++j) {
	    cc.dp[j] = 0;
	}
	if (cc.used == 0) {
	    cc.used = j;
	}
#ifdef TRACE_BZ_DIV
	traceprint("mp_bz_div", "partial quotient cc", &cc);
#endif
    }

    /*
     * The loop above has put the quotient in cc and remainder in r, but the
     * remainder is shifted left by sigma bits.  Shift it back to the right.
     */

    if ((res = mp_div_2d(&r, sigma, &r, NULL)) != MP_OKAY) {
	goto LBL_Q;
    }

    /*
     * Store results for the caller.  The quotient takes the combined sign,
     * the remainder the sign of the dividend, and zero is never negative.
     */

    mp_clamp(&cc);
    cc.sign = (cc.used == 0) ? MP_ZPOS : signum;
    r.sign = (r.used == 0) ? MP_ZPOS : rsign;
    if (c != NULL) {
	mp_exch(&cc, c);
    }
    if (d != NULL) {
	mp_exch(&r, d);
    }
    res = MP_OKAY;

 LBL_Q:
    mp_clear(&q);
 LBL_R:
    mp_clear(&r);
 LBL_CC:
    mp_clear(&cc);
 LBL_BB:
    mp_clear(&bb);
    mp_clear(&aa);
    return res;
}

/*
 * mp_bz_div_2n_1n --
 *
 *	Recursive division of a 2n-mp_digit number by an n-mp_digit number.
 *	On success stores the quotient in c and the remainder in d (either
 *	may be NULL).  d may be the same mp_int as a: a is only read before
 *	the results are stored.
 *
 * This procedure corresponds to Algorithm 1 of the Burnikel-Ziegler paper.
 * When n is odd or no larger than BZ_DIV_CUTOFF it falls back to mp_div.
 */

static int
mp_bz_div_2n_1n(int n,		/* Length of the numbers */
		mp_int* a,	/* Dividend */
		mp_int* b,	/* Divisor */
		mp_int* c,	/* Quotient */
		mp_int* d	/* Remainder */
) {
    mp_int q1, q2, r;
    int res;
    int i, j;
    int halfn;

    /*
     * If initialization fails, the temporaries are not usable.  Return
     * directly rather than jumping to LBL_R, which would clear them.
     */

    if ((res = mp_init_multi(&q1, &q2, &r, NULL)) != MP_OKAY) {
	return res;
    }

    /*
     * Base case: the recursion below splits n into two halves, so it
     * needs n to be even; small or odd n goes to schoolbook division.
     */

    if ((n % 2) != 0 || n <= BZ_DIV_CUTOFF) {
	if ((res = mp_div(a, b, &q1, &r)) != MP_OKAY) {
	    goto LBL_R;
	}
	if (c != NULL) {
	    mp_exch(&q1, c);
	}
	if (d != NULL) {
	    mp_exch(&r, d);
	}
	goto LBL_R;
    }

#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d, a->used=%d, b->used=%d\n", n,
	    a->used, b->used);
    traceprint("mp_bz_div_2n_1n", "a", a);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif

    halfn = n / 2;

    /* Extract the most significant 3/4 of the digits from A. */

    if (a->used >= halfn) {
	if ((res = mp_grow(&r, a->used-halfn)) != MP_OKAY) {
	    goto LBL_R;
	}
    }
    for (i = 0, j = halfn; j < a->used; ++i, ++j) {
	r.dp[i] = a->dp[j];
    }
    r.used = i;

    /* Generate the most significant digits of the result in q1 */

#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: first division n=%d, r->used=%d, b->used=%d\n", n,
	    r.used, b->used);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif
    if ((res = mp_bz_div_3n_2n(halfn, &r, b, &q1, &r)) != MP_OKAY) {
	goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: first division returns n=%d\n", n);
    traceprint("mp_bz_div_2n_1n", "q1", &q1);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d halfn=%d r.alloc=%d need %d\n",
	    n, halfn, r.alloc, r.used + halfn);
    fflush(stderr);
#endif

    /*
     * Shift the remainder up by halfn digits in r and bring in the least
     * significant halfn digits of the dividend below it.
     */

    if ((res = mp_grow(&r, r.used + halfn)) != MP_OKAY) {
	goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    traceprint("mp_bz_div_2n_1n", "a", a);
#endif
    for (i = r.used-1; i >= 0; --i) {
	r.dp[i + halfn] = r.dp[i];
    }
    for (i = 0; i < halfn && i < a->used; ++i) {
	r.dp[i] = a->dp[i];
    }
    while (i < halfn) {
	r.dp[i++] = 0;
    }
    r.used += halfn;

    /* A zero remainder leaves leading zero digits; clamp before dividing. */

    mp_clamp(&r);

    /* Divide to get the least significant digits */

#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: second division n=%d, r.used=%d, b->used=%d\n", n,
	    r.used, b->used);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif
    if ((res = mp_bz_div_3n_2n(halfn, &r, b, &q2, &r)) != MP_OKAY) {
	goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
    traceprint("mp_bz_div_2n_1n", "r", &r);
#endif

    /* Compose the quotient: q2 = q1 * beta**halfn + q2 */

#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: combining quotients q1.used %d q2.used %d halfn %d.\n",
	    q1.used, q2.used, halfn);
    traceprint("mp_bz_div_2n_1n", "q1", &q1);
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
#endif
    if ((res = mp_grow(&q2, q1.used+halfn)) != MP_OKAY) {
	goto LBL_R;
    }
    for (i = q2.used; i < halfn; ++i) {
	q2.dp[i] = 0;
    }
    for (j = 0; j < q1.used; ++j) {
	q2.dp[i++] = q1.dp[j];
    }
    q2.used = i;

    /* If q1 was zero, the padding above leaves leading zero digits. */

    mp_clamp(&q2);

#ifdef TRACE_BZ_2N_1N_ENTRY_EXIT
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d, q2.used=%d, r.used=%d\n", n,
	    q2.used, r.used);
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
    traceprint("mp_bz_div_2n_1n", "r", &r);
#endif

    if (c != NULL) {
	mp_exch(c, &q2);
    }
    if (d != NULL) {
	mp_exch(d, &r);
    }

 LBL_R:
    mp_clear_multi(&r, &q2, &q1, NULL);
    return res;
}

/*
 * mp_bz_div_3n_2n --
 *
 *	Divide a 3n-digit number by a 2n-digit number yielding an n-digit
 *	quotient and a 2n-digit remainder.  The quotient is stored in c and
 *	the remainder in d (either may be NULL).  d may be the same mp_int
 *	as a: a is only read before the results are stored.
 *
 * This procedure corresponds to Algorithm 2 of the Burnikel-Ziegler paper.
 */

static int
mp_bz_div_3n_2n(int n,		/* Length of the numbers */
		mp_int* a,	/* Dividend */
		mp_int* b,	/* Divisor */
		mp_int* c,	/* Quotient */
		mp_int* d	/* Remainder */
) {
    mp_int aa, bb;
    mp_int Qhat, Rhat, R1;
    int i, j;
    int asize, bsize;		/* Initial digit capacity of aa and bb */
    int altb;			/* Nonzero when A1 < B1 */
    int res;

    if ((res = mp_init_multi(&Qhat, &R1, &Rhat, NULL)) != MP_OKAY) {
	return res;
    }

#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n: n = %d, a->used = %d, b->used = %d\n",
	    n, a->used, b->used); fflush(stderr);
    traceprint ("mp_bz_div_3n_2n", "a", a);
    traceprint ("mp_bz_div_3n_2n", "b", b);
#endif

    /* 
     * 1.-2.
     * Extract the most significant 2n digits of a and the most significant n
     * digits of b.  The buffers are sized from the actual operand lengths
     * as well, so an operand longer than 3n (resp. 2n) digits cannot
     * overrun them.
     */

    asize = (a->used - n > 2 * n) ? (a->used - n) : (2 * n);
    bsize = (b->used - n > n) ? (b->used - n) : n;
    if ((res = mp_init_size(&aa, asize)) != MP_OKAY) {
	goto LBL_RHAT;
    }
    if ((res = mp_init_size(&bb, bsize)) != MP_OKAY) {
	goto LBL_AA;
    }
    for (i = n; i < a->used; ++i) {
	aa.dp[i-n] = a->dp[i];
    }
    aa.used = i-n;
    for (i = n; i < b->used; ++i) {
	bb.dp[i-n] = b->dp[i];
    }
    bb.used = i-n;

#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n: aa.used = %d bb.used = %d\n", aa.used, bb.used); fflush(stderr);
    traceprint ("mp_bz_div_3n_2n", "aa", &aa);
    traceprint ("mp_bz_div_3n_2n", "bb", &bb);
#endif

    /*
     * At this point, aa is [A1,A2] and bb is B1. Compare A1 and B1: first
     * by length, then digit by digit from the top.  Equal values leave
     * altb at 0.
     */

    if (aa.used != bb.used + n) {
	altb = (aa.used < bb.used + n);
    } else {
	altb = 0;
	for (i = aa.used-1, j = bb.used-1; j >= 0; i--, j--) {
	    if (aa.dp[i] != bb.dp[j]) {
		altb = (aa.dp[i] < bb.dp[j]);
		break;
	    }
	}
    }
#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n: A1<B1? %s\n",
	    (altb ? "yes" : "no")); fflush(stderr);
#endif

    /* 
     * 3. If A1 < B1, compute c = (A1,A2)/B1 with remainder d recursively. 
     * Otherwise, let c = 2**(n*DIGIT_BIT)-1, and d = (A1,A2)-(B1,0)+(0,B1).
     */

    if (altb) {
#ifdef TRACE_BZ_3N_2N
	fprintf(stderr, "mp_bz_div_3n_2n: recursing\n"); fflush(stderr);
#endif
	if ((res = mp_bz_div_2n_1n(n, &aa, &bb, &Qhat, &R1)) != MP_OKAY) {
	    goto LBL_BB;
	}
    } else {
#ifdef TRACE_BZ_3N_2N
	fprintf(stderr, "mp_bz_div_3n_2n: big quotient\n");
#endif
	/* Qhat=(beta**n) - 1 */
	if ((res = mp_grow(&Qhat, n)) != MP_OKAY) {
	    goto LBL_BB;
	}
	for (i = 0; i < n; ++i) {
	    Qhat.dp[i] = MP_DIGIT_MAX;
	}
	Qhat.used = n;
#ifdef TRACE_BZ_3N_2N
	traceprint("mp_bz_div_3n_2n", "Qhat", &Qhat);
#endif

	/* R1 = [A1, A2] + [0, B1] - [B1, 0] */
	if ((res = mp_add(&aa, &bb, &R1)) != MP_OKAY) {
	    goto LBL_BB;
	}
	if ((res = mp_lshd(&bb, n)) != MP_OKAY) {
	    goto LBL_BB;
	}
	if ((res = mp_sub(&R1, &bb, &R1)) != MP_OKAY) {
	    goto LBL_BB;
	}
#ifdef TRACE_BZ_3N_2N
	traceprint("mp_bz_div_3n_2n", "R1", &R1);
#endif
    }

#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n (n = %d): first division returns:\n", n);
    traceprint("mp_bz_div_3n_2n", "[A1,A2]", &aa);
    traceprint("mp_bz_div_3n_2n", "[B1]", &bb);
    traceprint("mp_bz_div_3n_2n", "Qhat", &Qhat);
    traceprint("mp_bz_div_3n_2n", "R1", &R1);
#endif

    /*
     * Get the less significant digits of b (B2) into bb.  Never read past
     * b->used, and zero every digit up to both n and the old bb.used so
     * that no stale digit survives above the new length.  bb's capacity
     * is at least n (bsize) and at least its old used count.
     */

#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "bb's alloc is %d, need %d\n", bb.alloc, n); fflush(stderr);
#endif
    for (i = 0; i < n && i < b->used; ++i) {
	bb.dp[i] = b->dp[i];
    }
    while (i < n || i < bb.used) {
	bb.dp[i++] = 0;
    }
    bb.used = n;
    mp_clamp(&bb);

    /* 4. D = Qhat * B2 */

    if ((res = mp_mul(&Qhat, &bb, &aa)) != MP_OKAY) {
	goto LBL_BB;
    }
#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n (n = %d): trial multiplication:\n", n);
    traceprint("mp_bz_div_3n_2n", "multiplicand Qhat", &Qhat);
    traceprint("mp_bz_div_3n_2n", "multiplier B2", &bb);
    traceprint("mp_bz_div_3n_2n", "product", &aa);
#endif

    /* 
     * 5. Rhat = R1 * beta**n + A3 - D 
     * (Paper has a typo, it has A4 at step 5) 
     */

#ifdef TRACE_BZ_3N_2N
    fprintf(stderr, "mp_bz_div_3n_2n: making rhat\n");
    traceprint("mp_bz_div_3n_2n", "R1", &R1);
    traceprint("mp_bz_div_3n_2n", "a", a);
    traceprint("mp_bz_div_3n_2n", "D", &aa);
#endif

    if ((res = mp_lshd(&R1, n)) != MP_OKAY) {
	goto LBL_BB;
    }

    /* A3 is the low n digits of a. */

    if ((res = mp_mod_2d(a, MP_DIGIT_BIT * n, &Rhat)) != MP_OKAY) {
	goto  LBL_BB;
    }
    if ((res = mp_add(&Rhat, &R1, &Rhat)) != MP_OKAY) {
	goto LBL_BB;
    }
#ifdef TRACE_BZ_3N_2N
    traceprint("mp_bz_div_3n_2n", "[R1,A3]", &Rhat);
#endif
    if ((res = mp_sub(&Rhat, &aa, &Rhat)) != MP_OKAY) {
	goto LBL_BB;
    }
#ifdef TRACE_BZ_3N_2N
    traceprint("mp_bz_div_3n_2n", "[R1,A3] - D", &Rhat);
#endif

    /* 6. While Rhat < 0, Rhat += B and Qhat -= 1 */

    while (Rhat.sign == MP_NEG) {
	if ((res = mp_add(&Rhat, b, &Rhat)) != MP_OKAY) {
	    goto LBL_BB;
	}
	if ((res = mp_sub_d(&Qhat, 1, &Qhat)) != MP_OKAY) {
	    goto LBL_BB;
	}
    }
    if (c != NULL) {
	mp_exch(c,&Qhat);
    }
    if (d != NULL) {
	mp_exch(d,&Rhat);
    }

 LBL_BB:
    mp_clear(&bb);
 LBL_AA:
    mp_clear(&aa);
 LBL_RHAT:
    mp_clear_multi(&Rhat, &R1, &Qhat, NULL);
    return res;
}

#endif