/*
 * Attachment "bn_mp_bz_div.c" to ticket [2814286fff], added by kennykb
 * on 2009-06-30 06:17:22.
 */
#include <stdio.h>
#include <tommath.h>
#ifdef BN_MP_BZ_DIV_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis
*
* LibTomMath is a library that provides multiple-precision
* integer arithmetic as well as number theoretic functionality.
*
* The library was designed directly after the MPI library by
* Michael Fromberger but has been written from scratch with
* additional optimizations in place.
*
* The library is free for all purposes without any express
* guarantee it works.
*
* Tom St Denis, [email protected], http://math.libtomcrypt.com
*
* This module: Kevin B. Kenny, [email protected]
*/
/*
#define TRACE_BZ_DIV
#define TRACE_BZ_2N_1N
#define TRACE_BZ_3N_2N
*/
#if defined(TRACE_BZ_DIV) || defined(TRACE_BZ_2N_1N) || defined(TRACE_BZ_3N_2N)
/*
 * traceprint --
 *
 *	Debugging aid: writes "func: var = <value>" to stderr, rendering the
 *	magnitude of 'val' as a sum of powers of two. Each run of consecutive
 *	one-bits is printed compactly: a lone bit i as 2**i, a two-bit run as
 *	2**i + 2**(i-1), and a longer run from bit i down to bit j as
 *	2**(i+1) - 2**j. Negative values are wrapped in "-( ... )"; zero is
 *	printed as "0".
 */
static void
traceprint(const char* func, const char* var, const mp_int* val) {
    int i, j, did;
    const char* sep = "";
    fprintf(stderr, "%s: %s = ", func, var);
    if (val->sign == MP_NEG) {
        fprintf(stderr, "-( ");
    }
    did = 0;
    for (i = val->used * DIGIT_BIT - 1; i >= 0; --i) {
        /*
         * Test bits by shifting the (unsigned) digit right. Shifting the
         * int constant 1 left by up to DIGIT_BIT-1 is undefined behavior
         * once that count reaches the width of int (e.g. 60-bit digits).
         */
        if ((val->dp[i / DIGIT_BIT] >> (i % DIGIT_BIT)) & 1) {
            did = 1;
            /* Find j, the lowest bit of the run of ones that ends at bit i. */
            for (j = i - 1; j >= 0; --j) {
                if (((val->dp[j / DIGIT_BIT] >> (j % DIGIT_BIT)) & 1) == 0) {
                    break;
                }
            }
            ++j;
            if (j == i) {
                fprintf(stderr, "%s2**%d", sep, i);
            } else if (j == i - 1) {
                fprintf(stderr, "%s2**%d + 2**%d", sep, i, j);
            } else {
                fprintf(stderr, "%s2**%d - 2**%d", sep, i+1, j);
            }
            sep = " + ";
            /*
             * Continue below the run: the loop's --i moves to bit j-1,
             * which is known to be zero (or past the least significant bit).
             */
            i = j;
        }
    }
    if (!did) {
        fprintf(stderr, "0");
    }
    if (val->sign == MP_NEG) {
        fprintf(stderr, " )");
    }
    fprintf(stderr, "\n");
    fflush(stderr);
}
#endif
/*
 * Static functions defined in this file. Both recursive helpers take the
 * block size n (in mp_digits), the dividend a and the divisor b, and store
 * the quotient in *c and the remainder in *d; either output may be NULL.
 */
static int mp_bz_div_2n_1n(int n, mp_int* a, mp_int* b, mp_int* c, mp_int* d);
static int mp_bz_div_3n_2n(int n, mp_int* a, mp_int* b, mp_int* c, mp_int* d);
/*
 * mp_bz_div --
 *
 * Division in Karatsuba time: a/b == c remainder d, or a == bc + d
 *
 * The quotient is truncated toward zero. A nonzero remainder takes the
 * sign of the dividend, as libtommath's mp_div does, and a zero result is
 * always MP_ZPOS. Either c or d may be NULL if that result is not wanted.
 * c and d may alias a or b: the input signs are captured before anything
 * is written, and the results are stored last.
 *
 * Returns MP_OKAY on success, MP_VAL if b is zero, or the error code of
 * the first libtommath call (or recursive step) that fails.
 *
 * Source: Christoph Burnikel, Joachim Ziegler. "Fast Recursive Division."
 * Forschungsbericht MPI I 98 1 022, Max-Planck-Institut fuer
 * Informatik, Saarbruecken, Germany (October 1998)
 * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.47.565
 */
int
mp_bz_div(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
{
    int signum;     /* Sign of a nonzero quotient */
    int asign;      /* Sign of the dividend = sign of a nonzero remainder */
    int j;          /* ceil(b->used / m); always 1, since m > b->used */
    int m;          /* Smallest 2**k * BZ_DIV_CUTOFF exceeding b->used */
    int n;          /* Number of mp_digits in each division step */
    int sigma;      /* Number of bit shifts needed to normalize
                     * the divisor */
    int t;          /* Number of n-digit blocks in the shifted dividend */
    int i;
    mp_int aa, bb, cc, r, q;
    mp_digit dig;
    int res;

#ifdef TRACE_BZ_DIV
    fprintf(stderr, "mp_bz_div: a->used=%d, b->used=%d\n",
            a->used, b->used);
    traceprint("mp_bz_div", "a", a);
    traceprint("mp_bz_div", "b", b);
#endif

    /* Algorithm 3. D_{r/s} */

    /*
     * Division by zero: b has no digits to normalize (reading its top
     * digit would be out of bounds and the shift loop would never end).
     */
    if (b->used == 0) {
        return MP_VAL;
    }

    /* Capture the signs now; c or d may alias a and be overwritten. */
    asign = a->sign;
    signum = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;

    /* Determine smallest width that will accommodate the divisor. */
    m = BZ_DIV_CUTOFF;
    while (m <= b->used) {
        m <<= 1;
    }
    j = (b->used + m - 1) / m;
    n = j * m;

    /*
     * Normalize the divisor in an n-mp_digit integer: count the shifts that
     * move its top set bit to the top of digit n-1. The top bit is tested
     * by shifting the unsigned digit right; shifting an int constant left
     * by MP_DIGIT_BIT-1 is undefined behavior for wide (e.g. 60-bit) digits.
     * The dividend is shifted left by the same amount below.
     */
    sigma = MP_DIGIT_BIT * (n - b->used);
    dig = b->dp[b->used - 1];
    while (((dig >> (MP_DIGIT_BIT - 1)) & 1) == 0) {
        ++sigma;
        dig <<= 1;
    }

    /* Allocate memory for shifted operands */
    if ((res = mp_init_multi(&aa, &bb, NULL)) != MP_OKAY) {
        return res;
    }

    /* Shift b left to normalize it. Shift a left by the same amount. */
    if ((res = mp_mul_2d(a, sigma, &aa)) != MP_OKAY) {
        goto LBL_BB;
    }
    aa.sign = MP_ZPOS;
    if ((res = mp_mul_2d(b, sigma, &bb)) != MP_OKAY) {
        goto LBL_BB;
    }
    bb.sign = MP_ZPOS;

    /*
     * bb is now n mp_digits long, and its most significant bit is 1.
     * Find the smallest t >= 2 such that, when aa is viewed as t n-digit
     * blocks, the most significant bit of the top block is *not* 1.
     * aa may be zero (aa.used == 0), so guard the access to its top digit.
     */
    t = (aa.used + n - 1) / n;
    if (aa.used > 0 && (aa.used % n) == 0
            && ((aa.dp[aa.used - 1] >> (DIGIT_BIT - 1)) & 1) != 0) {
        ++t;
    }
    if (t < 2) {
        t = 2;
    }

#ifdef TRACE_BZ_DIV
    fprintf(stderr, "mp_bz_div: t = %d, n = %d sigma = %d\n", t, n, sigma);
    traceprint("mp_bz_div", "shifted dividend aa", &aa);
    traceprint("mp_bz_div", "shifted divisor bb", &bb);
#endif

    /* Allocate memory for temps */
    if ((res = mp_init_size(&cc, t * n)) != MP_OKAY) {
        goto LBL_BB;
    }
    if ((res = mp_init_size(&r, 2 * n)) != MP_OKAY) {
        goto LBL_CC;
    }
    if ((res = mp_init_size(&q, t * n)) != MP_OKAY) {
        goto LBL_R;
    }

    /*
     * Load r with the most significant n-digit block of aa (digits
     * (t-1)*n and up). Each pass of the loop below shifts the next lower
     * block in underneath it, so the first division sees the top 2n
     * digits of aa, viewed as being t*n digits long.
     */
    for (i = (t - 1) * n, j = 0; i < aa.used; ++i, ++j) {
        r.dp[j] = aa.dp[i];
    }
    r.used = j;

    /* Do 'school short division' on 'digits' that are each n mp_digits long. */
    cc.used = 0;
    for (t -= 2; t >= 0; --t) {
        /*
         * Shift r up by n digits to make room for the next block. Copy from
         * the top down so the move is correct even if the ranges overlap.
         */
        if ((res = mp_grow(&r, r.used + n)) != MP_OKAY) {
            goto LBL_Q;
        }
        for (i = r.used - 1; i >= 0; --i) {
            r.dp[i + n] = r.dp[i];
        }
        r.used += n;

        /* Bring block t of aa into the low n digits, zero-padding past aa's end. */
        for (j = 0, i = t * n; j < n && i < aa.used; ++i, ++j) {
            r.dp[j] = aa.dp[i];
        }
        while (j < n) {
            r.dp[j++] = 0;
        }

        /*
         * If r was zero before the shift, its new top digits may be zero;
         * strip them so the recursive step sees a well-formed mp_int.
         */
        mp_clamp(&r);

        /* Perform one division step */
#ifdef TRACE_BZ_DIV
        fprintf(stderr, "mp_bz_div: part %d\n", t);
        traceprint("mp_bz_div", "dividend r", &r);
        traceprint("mp_bz_div", "divisor bb", &bb);
#endif
        if ((res = mp_bz_div_2n_1n(n, &r, &bb, &q, &r)) != MP_OKAY) {
            goto LBL_Q;
        }
#ifdef TRACE_BZ_DIV
        traceprint("mp_bz_div", "quotient q", &q);
        traceprint("mp_bz_div", "remainder r", &r);
#endif

        /* Store this n-digit quotient block at digit position t*n of cc. */
        for (j = t * n, i = 0; i < q.used; ++i, ++j) {
            cc.dp[j] = q.dp[i];
        }
        for (; i < n; ++i, ++j) {
            cc.dp[j] = 0;
        }
        /* The first (most significant) block fixes the length of cc. */
        if (cc.used == 0) {
            cc.used = j;
        }
#ifdef TRACE_BZ_DIV
        traceprint("mp_bz_div", "partial quotient cc", &cc);
#endif
    }

    /*
     * The loop above has put the quotient in cc and remainder in r, but the
     * remainder is shifted left by sigma bits. Shift it back to the right.
     */
    if ((res = mp_div_2d(&r, sigma, &r, NULL)) != MP_OKAY) {
        goto LBL_Q;
    }

    /*
     * Store results for the caller. The quotient carries signum and the
     * remainder the dividend's sign; zero is never negative.
     */
    mp_clamp(&cc);
    cc.sign = (cc.used == 0) ? MP_ZPOS : signum;
    r.sign = (r.used == 0) ? MP_ZPOS : asign;
    if (c != NULL) {
        mp_exch(&cc, c);
    }
    if (d != NULL) {
        mp_exch(&r, d);
    }
    res = MP_OKAY;

LBL_Q:
    mp_clear(&q);
LBL_R:
    mp_clear(&r);
LBL_CC:
    mp_clear(&cc);
LBL_BB:
    mp_clear(&bb);
    mp_clear(&aa);
    return res;
}
/*
 * mp_bz_div_2n_1n --
 *
 * Recursive division of a 2n-mp_digit number by an n-mp_digit number.
 *
 * This procedure corresponds to Algorithm 1 of the Burnikel-Ziegler paper.
 *
 * Viewing the dividend as n/2-digit blocks [A1,A2,A3,A4], [A1,A2,A3] is
 * divided by b with mp_bz_div_3n_2n. That remainder is extended with A4
 * and divided again, and the two half-size quotients are concatenated.
 * Odd n, or n <= BZ_DIV_CUTOFF, is handled directly by mp_div.
 *
 * The quotient is stored in *c and the remainder in *d (either may be
 * NULL). Both are stored only after a has been fully read, so the caller
 * may pass the same mp_int for a and d.
 *
 * Returns MP_OKAY or the error code of the first libtommath call (or
 * recursive step) that fails.
 */
static int
mp_bz_div_2n_1n(int n,      /* Length of the divisor, in mp_digits */
                mp_int* a,  /* Dividend (up to 2n digits) */
                mp_int* b,  /* Divisor (n digits) */
                mp_int* c,  /* Quotient, or NULL */
                mp_int* d   /* Remainder, or NULL */
) {
    mp_int q1, q2, r;
    int res;
    int i, j;
    int halfn;

    /*
     * If this fails, some of q1, q2 and r were never initialized, so
     * return directly instead of clearing them at LBL_R.
     */
    if ((res = mp_init_multi(&q1, &q2, &r, NULL)) != MP_OKAY) {
        return res;
    }

    /*
     * The recursion splits n into two equal halves, so it needs n even;
     * small or odd sizes go straight to schoolbook division.
     */
    if ((n % 2) != 0 || n <= BZ_DIV_CUTOFF) {
        if ((res = mp_div(a, b, &q1, &r)) != MP_OKAY) {
            goto LBL_R;
        }
        if (c != NULL) {
            mp_exch(&q1, c);
        }
        if (d != NULL) {
            mp_exch(&r, d);
        }
        goto LBL_R;
    }

#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d, a->used=%d, b->used=%d\n", n,
            a->used, b->used);
    traceprint("mp_bz_div_2n_1n", "a", a);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif

    halfn = n / 2;

    /* Extract the most significant 3/4 of the digits from A: r = [A1,A2,A3]. */
    if (a->used >= halfn) {
        if ((res = mp_grow(&r, a->used - halfn)) != MP_OKAY) {
            goto LBL_R;
        }
    }
    for (i = 0, j = halfn; j < a->used; ++i, ++j) {
        r.dp[i] = a->dp[j];
    }
    r.used = i;

    /* Generate the most significant digits of the result in q1 */
#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: first division n=%d, r->used=%d, b->used=%d\n", n,
            r.used, b->used);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif
    if ((res = mp_bz_div_3n_2n(halfn, &r, b, &q1, &r)) != MP_OKAY) {
        goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: first division returns n=%d\n", n);
    traceprint("mp_bz_div_2n_1n", "q1", &q1);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d halfn=%d r.alloc=%d need %d\n",
            n, halfn, r.alloc, r.used + halfn);
    fflush(stderr);
#endif

    /*
     * Shift the remainder up by halfn digits (top down, since the ranges
     * overlap) and bring in the rest of the dividend, A4, underneath it.
     */
    if ((res = mp_grow(&r, r.used + halfn)) != MP_OKAY) {
        goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    traceprint("mp_bz_div_2n_1n", "a", a);
#endif
    for (i = r.used - 1; i >= 0; --i) {
        r.dp[i + halfn] = r.dp[i];
    }
    for (i = 0; i < halfn && i < a->used; ++i) {
        r.dp[i] = a->dp[i];
    }
    while (i < halfn) {
        r.dp[i++] = 0;
    }
    r.used += halfn;
    /* If the first remainder was zero, the top copied digits may be zero. */
    mp_clamp(&r);

    /* Divide to get the least significant digits */
#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: second division n=%d, r.used=%d, b->used=%d\n", n,
            r.used, b->used);
    traceprint("mp_bz_div_2n_1n", "r", &r);
    traceprint("mp_bz_div_2n_1n", "b", b);
#endif
    if ((res = mp_bz_div_3n_2n(halfn, &r, b, &q2, &r)) != MP_OKAY) {
        goto LBL_R;
    }
#ifdef TRACE_BZ_2N_1N
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
    traceprint("mp_bz_div_2n_1n", "r", &r);
#endif

    /*
     * Compose the quotient: q2 (zero-padded to halfn digits) is the low
     * half, and q1 is placed above it.
     */
#ifdef TRACE_BZ_2N_1N
    fprintf(stderr, "mp_bz_div_2n_1n: combining quotients q1.used %d q2.used %d halfn %d.\n",
            q1.used, q2.used, halfn);
    traceprint("mp_bz_div_2n_1n", "q1", &q1);
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
#endif
    if ((res = mp_grow(&q2, q1.used + halfn)) != MP_OKAY) {
        goto LBL_R;
    }
    for (i = q2.used; i < halfn; ++i) {
        q2.dp[i] = 0;
    }
    for (j = 0; j < q1.used; ++j) {
        q2.dp[i++] = q1.dp[j];
    }
    q2.used = i;
    /* A zero q1 leaves padding zeros at the top; strip them. */
    mp_clamp(&q2);

#ifdef TRACE_BZ_2N_1N_ENTRY_EXIT
    fprintf(stderr, "mp_bz_div_2n_1n: n=%d, q2.used=%d, r.used=%d\n", n,
            q2.used, r.used);
    traceprint("mp_bz_div_2n_1n", "q2", &q2);
    traceprint("mp_bz_div_2n_1n", "r", &r);
#endif
    if (c != NULL) {
        mp_exch(c, &q2);
    }
    if (d != NULL) {
        mp_exch(d, &r);
    }

LBL_R:
    mp_clear_multi(&r, &q2, &q1, NULL);
    return res;
}
/*
 * mp_bz_div_3n_2n --
 *
 * Divide a 3n-digit number by a 2n-digit number yielding an n-digit
 * quotient and a 2n-digit remainder
 *
 * This procedure corresponds to Algorithm 2 of the Burnikel-Ziegler paper.
 *
 * In n-digit blocks, a = [A1,A2,A3] and b = [B1,B2]. The quotient is first
 * estimated as Qhat = [A1,A2]/B1 (or beta**n - 1 when A1 >= B1, where
 * beta = 2**DIGIT_BIT). It is then corrected (steps 4-6) by forming
 * Rhat = [R1,A3] - Qhat*B2 and adding b back while Rhat is negative.
 *
 * The quotient is stored in *c and the remainder in *d, via mp_exch, only
 * after a has been fully read. Either may be NULL, and the caller may pass
 * the same mp_int for a and d.
 *
 * NOTE(review): the copy into aa assumes a->used <= 3n (aa is sized for
 * 2n digits). The correction loop relies on b being normalized (top bit
 * set), as mp_bz_div arranges - confirm before other uses.
 *
 * Returns MP_OKAY or the error code of the first libtommath call that
 * fails.
 */
static int
mp_bz_div_3n_2n(int n, /* Block length, in mp_digits */
mp_int* a, /* Dividend (3n digits) */
mp_int* b, /* Divisor (2n digits) */
mp_int* c, /* Quotient, or NULL */
mp_int* d /* Remainder, or NULL */
) {
mp_int aa, bb; /* Scratch: aa holds [A1,A2] then D; bb holds B1 then B2 */
mp_int Qhat, Rhat, R1;
int i, j;
int altb; /* Nonzero when A1 < B1 */
int res;
if ((res = mp_init_multi(&Qhat, &R1, &Rhat, NULL)) != MP_OKAY) {
return res;
}
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: n = %d, a->used = %d, b->used = %d\n",
n, a->used, b->used); fflush(stderr);
traceprint ("mp_bz_div_3n_2n", "a", a);
traceprint ("mp_bz_div_3n_2n", "b", b);
#endif
/*
 * 1.-2.
 * Extract the most significant 2n digits of a and the most significant n
 * digits of b.
 */
if ((res = mp_init_size(&aa, 2 * n)) != MP_OKAY) {
goto LBL_RHAT;
}
if ((res = mp_init_size(&bb, n)) != MP_OKAY) {
goto LBL_AA;
}
/* aa = [A1,A2] = floor(a / beta**n); empty if a has at most n digits. */
for (i = n; i < a->used; ++i) {
aa.dp[i-n] = a->dp[i];
}
aa.used = i-n;
/* bb = B1 = floor(b / beta**n) */
for (i = n; i < b->used; ++i) {
bb.dp[i-n] = b->dp[i];
}
bb.used = i-n;
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: aa.used = %d bb.used = %d\n", aa.used, bb.used); fflush(stderr);
traceprint ("mp_bz_div_3n_2n", "aa", &aa);
traceprint ("mp_bz_div_3n_2n", "bb", &bb);
#endif
/*
 * At this point, aa is [A1,A2] and bb is B1. Compare A1 and B1.
 * A1 occupies the digits of aa from position n up. Compare digit counts
 * first; on a tie, compare digit by digit from the top. Equal values
 * leave altb == 0 (A1 is not less than B1).
 */
altb = 0;
if (aa.used > bb.used + n) {
altb = 0;
} else if (bb.used + n > aa.used) {
altb = 1;
} else {
for (i = aa.used-1, j = bb.used-1; j >= 0; i--, j--) {
if (aa.dp[i] < bb.dp[j]) {
altb = 1;
break;
} else if (bb.dp[j] < aa.dp[i]) {
altb = 0;
break;
}
}
}
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: A1<B1? %s\n",
(altb ? "yes" : "no")); fflush(stderr);
#endif
/*
 * 3. If A1 < B1, compute c = (A1,A2)/B1 with remainder d recursively.
 * Otherwise, let c = 2**(n*DIGIT_BIT)-1, and d = (A1,A2)-(B1,0)+(0,B1).
 * Here the estimate goes to Qhat and its remainder to R1.
 */
if (altb) {
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: recursing\n"); fflush(stderr);
#endif
if ((res = mp_bz_div_2n_1n(n, &aa, &bb, &Qhat, &R1)) != MP_OKAY) {
goto LBL_BB;
}
} else {
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: big quotient\n");
#endif
/* Qhat=(beta**n) - 1: n digits, each MP_DIGIT_MAX */
if ((res = mp_grow(&Qhat, n)) != MP_OKAY) {
goto LBL_BB;
}
for (i = 0; i < n; ++i) {
Qhat.dp[i] = MP_DIGIT_MAX;
}
Qhat.used = n;
#ifdef TRACE_BZ_3N_2N
traceprint("mp_bz_div_3n_2n", "Qhat", &Qhat);
#endif
/*
 * R1 = [A1, A2] + [0, B1] - [B1, 0]
 * mp_lshd shifts bb to B1*beta**n in place; bb is rebuilt as B2 below.
 */
if ((res = mp_add(&aa, &bb, &R1)) != MP_OKAY) {
goto LBL_BB;
}
if ((res = mp_lshd(&bb, n)) != MP_OKAY) {
goto LBL_BB;
}
if ((res = mp_sub(&R1, &bb, &R1)) != MP_OKAY) {
goto LBL_BB;
}
#ifdef TRACE_BZ_3N_2N
traceprint("mp_bz_div_3n_2n", "R1", &R1);
#endif
}
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n (n = %d): first division returns:\n", n);
traceprint("mp_bz_div_3n_2n", "[A1,A2]", &aa);
traceprint("mp_bz_div_3n_2n", "[B1]", &bb);
traceprint("mp_bz_div_3n_2n", "Qhat", &Qhat);
traceprint("mp_bz_div_3n_2n", "R1", &R1);
#endif
/*
 * Get the less significant digits of b: reuse bb as B2, the low n digits
 * of b. Digits of bb at positions n and up are cleared; in the A1 >= B1
 * branch they still hold the shifted B1.
 */
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "bb's alloc is %d, need %d\n", bb.alloc, n); fflush(stderr);
#endif
for (i = 0; i < n; ++i) {
bb.dp[i] = b->dp[i];
}
while (i < bb.used) {
bb.dp[i++] = 0;
}
bb.used = n;
mp_clamp(&bb);
/* 4. D = Qhat * B2 (stored in aa; [A1,A2] is no longer needed) */
if ((res = mp_mul(&Qhat, &bb, &aa)) != MP_OKAY) {
goto LBL_BB;
}
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n (n = %d): trial multiplication:\n", n);
traceprint("mp_bz_div_3n_2n", "multiplicand Qhat", &Qhat);
traceprint("mp_bz_div_3n_2n", "multiplier B2", &bb);
traceprint("mp_bz_div_3n_2n", "product", &aa);
#endif
/*
 * 5. Rhat = R1 * beta**n + A3 - D
 * (Paper has a typo, it has A4 at step 5)
 * A3 is obtained as a mod beta**n, i.e. the low n digits of a.
 */
#ifdef TRACE_BZ_3N_2N
fprintf(stderr, "mp_bz_div_3n_2n: making rhat\n");
traceprint("mp_bz_div_3n_2n", "R1", &R1);
traceprint("mp_bz_div_3n_2n", "a", a);
traceprint("mp_bz_div_3n_2n", "D", &aa);
#endif
if ((res = mp_lshd(&R1, n)) != MP_OKAY) {
goto LBL_BB;
}
if ((res = mp_mod_2d(a, MP_DIGIT_BIT * n, &Rhat)) != MP_OKAY) {
goto LBL_BB;
}
if ((res = mp_add(&Rhat, &R1, &Rhat)) != MP_OKAY) {
goto LBL_BB;
}
#ifdef TRACE_BZ_3N_2N
traceprint("mp_bz_div_3n_2n", "[R1,A3]", &Rhat);
#endif
if ((res = mp_sub(&Rhat, &aa, &Rhat)) != MP_OKAY) {
goto LBL_BB;
}
#ifdef TRACE_BZ_3N_2N
traceprint("mp_bz_div_3n_2n", "[R1,A3] - D", &Rhat);
#endif
/*
 * 6. While Rhat < 0, Rhat += B and Qhat -= 1
 * Qhat can only overestimate the true quotient, so each correction
 * lowers it by one.
 */
while (Rhat.sign == MP_NEG) {
if ((res = mp_add(&Rhat, b, &Rhat)) != MP_OKAY) {
goto LBL_BB;
}
if ((res = mp_sub_d(&Qhat, 1, &Qhat)) != MP_OKAY) {
goto LBL_BB;
}
}
/* Hand the results to the caller; a is not read after this point. */
if (c != NULL) {
mp_exch(c,&Qhat);
}
if (d != NULL) {
mp_exch(d,&Rhat);
}
LBL_BB:
mp_clear(&bb);
LBL_AA:
mp_clear(&aa);
LBL_RHAT:
mp_clear_multi(&Rhat, &R1, &Qhat, NULL);
return res;
}
#endif