[crypto] Use Montgomery reduction for modular exponentiation

Speed up modular exponentiation by using Montgomery reduction rather
than direct modular reduction.

Montgomery reduction in base 2^n requires the modulus to be coprime to
2^n, which would limit us to requiring that the modulus is an odd
number.  Extend the implementation to include support for
exponentiation with even moduli via Garner's algorithm as described in
"Montgomery reduction with even modulus" (Koç, 1994).

Since almost all use cases for modular exponentation require a large
prime (and hence odd) modulus, the support for even moduli could
potentially be removed in future.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
pull/1354/head
Michael Brown 2024-11-25 15:59:22 +00:00
parent 4f7dd7fbba
commit 83ac98ce22
5 changed files with 164 additions and 29 deletions

View File

@ -505,25 +505,142 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
*exponent = ( ( const void * ) exponent0 );
bigint_t ( size ) __attribute__ (( may_alias )) *result =
( ( void * ) result0 );
size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus );
const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
struct {
bigint_t ( size ) base;
bigint_t ( exponent_size ) exponent;
uint8_t mod_multiply[mod_multiply_len];
union {
bigint_t ( 2 * size ) padded_modulus;
struct {
bigint_t ( size ) modulus;
bigint_t ( size ) stash;
};
};
union {
bigint_t ( 2 * size ) full;
bigint_t ( size ) low;
} product;
} *temp = tmp;
static const uint8_t start[1] = { 0x01 };
const uint8_t one[1] = { 1 };
bigint_t ( 1 ) modinv;
bigint_element_t submask;
unsigned int subsize;
unsigned int scale;
unsigned int max;
unsigned int bit;
memcpy ( &temp->base, base, sizeof ( temp->base ) );
memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) );
bigint_init ( result, start, sizeof ( start ) );
/* Sanity check */
assert ( sizeof ( *temp ) == bigint_mod_exp_tmp_len ( modulus ) );
while ( ! bigint_is_zero ( &temp->exponent ) ) {
if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) {
bigint_mod_multiply ( result, &temp->base, modulus,
result, temp->mod_multiply );
/* Handle degenerate case of zero modulus */
if ( ! bigint_max_set_bit ( modulus ) ) {
memset ( result, 0, sizeof ( *result ) );
return;
}
/* Factor modulus as (N * 2^scale) where N is odd */
bigint_grow ( modulus, &temp->padded_modulus );
for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ;
scale++ ) {
bigint_shr ( &temp->modulus );
}
subsize = ( ( scale + width - 1 ) / width );
submask = ( ( 1UL << ( scale % width ) ) - 1 );
if ( ! submask )
submask = ~submask;
/* Calculate inverse of (scaled) modulus N modulo element size */
bigint_mod_invert ( &temp->modulus, &modinv );
/* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */
memset ( &temp->product.full, 0, sizeof ( temp->product.full ) );
bigint_subtract ( &temp->padded_modulus, &temp->product.full );
bigint_reduce ( &temp->padded_modulus, &temp->product.full );
bigint_copy ( &temp->product.low, &temp->stash );
/* Initialise result = Montgomery(1, R^2 mod N) */
bigint_montgomery ( &temp->modulus, &modinv,
&temp->product.full, result );
/* Convert base into Montgomery form */
bigint_multiply ( base, &temp->stash, &temp->product.full );
bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
&temp->stash );
/* Calculate x1 = base^exponent modulo N */
max = bigint_max_set_bit ( exponent );
for ( bit = 1 ; bit <= max ; bit++ ) {
/* Square (and reduce) */
bigint_multiply ( result, result, &temp->product.full );
bigint_montgomery ( &temp->modulus, &modinv,
&temp->product.full, result );
/* Multiply (and reduce) */
bigint_multiply ( &temp->stash, result, &temp->product.full );
bigint_montgomery ( &temp->modulus, &modinv,
&temp->product.full, &temp->product.low );
/* Conditionally swap the multiplied result */
bigint_swap ( result, &temp->product.low,
bigint_bit_is_set ( exponent, ( max - bit ) ) );
}
/* Convert back out of Montgomery form */
bigint_grow ( result, &temp->product.full );
bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
result );
/* Handle even moduli via Garner's algorithm */
if ( subsize ) {
const bigint_t ( subsize ) __attribute__ (( may_alias ))
*subbase = ( ( const void * ) base );
bigint_t ( subsize ) __attribute__ (( may_alias ))
*submodulus = ( ( void * ) &temp->modulus );
bigint_t ( subsize ) __attribute__ (( may_alias ))
*substash = ( ( void * ) &temp->stash );
bigint_t ( subsize ) __attribute__ (( may_alias ))
*subresult = ( ( void * ) result );
union {
bigint_t ( 2 * subsize ) full;
bigint_t ( subsize ) low;
} __attribute__ (( may_alias ))
*subproduct = ( ( void * ) &temp->product.full );
/* Calculate x2 = base^exponent modulo 2^k */
bigint_init ( substash, one, sizeof ( one ) );
for ( bit = 1 ; bit <= max ; bit++ ) {
/* Square (and reduce) */
bigint_multiply ( substash, substash,
&subproduct->full );
bigint_copy ( &subproduct->low, substash );
/* Multiply (and reduce) */
bigint_multiply ( subbase, substash,
&subproduct->full );
/* Conditionally swap the multiplied result */
bigint_swap ( substash, &subproduct->low,
bigint_bit_is_set ( exponent,
( max - bit ) ) );
}
bigint_shr ( &temp->exponent );
bigint_mod_multiply ( &temp->base, &temp->base, modulus,
&temp->base, temp->mod_multiply );
/* Calculate N^-1 modulo 2^k */
bigint_mod_invert ( submodulus, &subproduct->low );
bigint_copy ( &subproduct->low, submodulus );
/* Calculate y = (x2 - x1) * N^-1 modulo 2^k */
bigint_subtract ( subresult, substash );
bigint_multiply ( substash, submodulus, &subproduct->full );
subproduct->low.element[ subsize - 1 ] &= submask;
bigint_grow ( &subproduct->low, &temp->stash );
/* Reconstruct N */
bigint_mod_invert ( submodulus, &subproduct->low );
bigint_copy ( &subproduct->low, submodulus );
/* Calculate x = x1 + N * y */
bigint_multiply ( &temp->modulus, &temp->stash,
&temp->product.full );
bigint_add ( &temp->product.low, result );
}
}

View File

@ -57,8 +57,7 @@ int dhe_key ( const void *modulus, size_t len, const void *generator,
unsigned int size = bigint_required_size ( len );
unsigned int private_size = bigint_required_size ( private_len );
bigint_t ( size ) *mod;
bigint_t ( private_size ) *exp;
size_t tmp_len = bigint_mod_exp_tmp_len ( mod, exp );
size_t tmp_len = bigint_mod_exp_tmp_len ( mod );
struct {
bigint_t ( size ) modulus;
bigint_t ( size ) generator;

View File

@ -109,8 +109,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
unsigned int size = bigint_required_size ( modulus_len );
unsigned int exponent_size = bigint_required_size ( exponent_len );
bigint_t ( size ) *modulus;
bigint_t ( exponent_size ) *exponent;
size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent );
size_t tmp_len = bigint_mod_exp_tmp_len ( modulus );
struct {
bigint_t ( size ) modulus;
bigint_t ( exponent_size ) exponent;

View File

@ -322,18 +322,12 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
* Calculate temporary working space required for moduluar exponentiation
*
* @v modulus Big integer modulus
* @v exponent Big integer exponent
* @ret len Length of temporary working space
*/
#define bigint_mod_exp_tmp_len( modulus, exponent ) ( { \
#define bigint_mod_exp_tmp_len( modulus ) ( { \
unsigned int size = bigint_size (modulus); \
unsigned int exponent_size = bigint_size (exponent); \
size_t mod_multiply_len = \
bigint_mod_multiply_tmp_len (modulus); \
sizeof ( struct { \
bigint_t ( size ) temp_base; \
bigint_t ( exponent_size ) temp_exponent; \
uint8_t mod_multiply[mod_multiply_len]; \
bigint_t ( size ) temp[4]; \
} ); } )
#include <bits/bigint.h>

View File

@ -746,8 +746,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0,
bigint_t ( size ) modulus_temp; \
bigint_t ( exponent_size ) exponent_temp; \
bigint_t ( size ) result_temp; \
size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp, \
&exponent_temp ); \
size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp ); \
uint8_t tmp[tmp_len]; \
{} /* Fix emacs alignment */ \
\
@ -2070,6 +2069,14 @@ static void bigint_test_exec ( void ) {
BIGINT ( 0xb9 ),
BIGINT ( 0x39, 0x68, 0xba, 0x7d ),
BIGINT ( 0x17 ) );
bigint_mod_exp_ok ( BIGINT ( 0x71, 0x4d, 0x02, 0xe9 ),
BIGINT ( 0x00, 0x00, 0x00, 0x00 ),
BIGINT ( 0x91, 0x7f, 0x4e, 0x3a, 0x5d, 0x5c ),
BIGINT ( 0x00, 0x00, 0x00, 0x00 ) );
bigint_mod_exp_ok ( BIGINT ( 0x2b, 0xf5, 0x07, 0xaf ),
BIGINT ( 0x6e, 0xb5, 0xda, 0x5a ),
BIGINT ( 0x00, 0x00, 0x00, 0x00, 0x00 ),
BIGINT ( 0x00, 0x00, 0x00, 0x01 ) );
bigint_mod_exp_ok ( BIGINT ( 0x2e ),
BIGINT ( 0xb7 ),
BIGINT ( 0x39, 0x07, 0x1b, 0x49, 0x5b, 0xea,
@ -2774,6 +2781,25 @@ static void bigint_test_exec ( void ) {
0xfa, 0x83, 0xd4, 0x7c, 0xe9, 0x77,
0x46, 0x91, 0x3a, 0x50, 0x0d, 0x6a,
0x25, 0xd0 ) );
bigint_mod_exp_ok ( BIGINT ( 0x5b, 0x80, 0xc5, 0x03, 0xb3, 0x1e,
0x46, 0x9b, 0xa3, 0x0a, 0x70, 0x43,
0x51, 0x2a, 0x4a, 0x44, 0xcb, 0x87,
0x3e, 0x00, 0x2a, 0x48, 0x46, 0xf5,
0xb3, 0xb9, 0x73, 0xa7, 0x77, 0xfc,
0x2a, 0x1d ),
BIGINT ( 0x5e, 0x8c, 0x80, 0x03, 0xe7, 0xb0,
0x45, 0x23, 0x8f, 0xe0, 0x77, 0x02,
0xc0, 0x7e, 0xfb, 0xc4, 0xbe, 0x7b,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00 ),
BIGINT ( 0x71, 0xd9, 0x38, 0xb6 ),
BIGINT ( 0x52, 0xfc, 0x73, 0x55, 0x2f, 0x86,
0x0f, 0xde, 0x04, 0xbc, 0x6d, 0xb8,
0xfd, 0x48, 0xf8, 0x8c, 0x91, 0x1c,
0xa0, 0x8a, 0x70, 0xa8, 0xc6, 0x20,
0x0a, 0x0d, 0x3b, 0x2a, 0x92, 0x65,
0x9c, 0x59 ) );
}
/** Big integer self-test */