[crypto] Separate out bigint_reduce() from bigint_mod_multiply()

Faster modular multiplication algorithms such as Montgomery
multiplication will still require the ability to perform a single
direct modular reduction.

Neaten up the implementation of direct reduction and split it out into
a separate bigint_reduce() function, complete with its own unit tests.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
pull/875/merge
Michael Brown 2024-10-15 13:50:51 +01:00
parent f78c5a763c
commit 2bf16c6ffc
3 changed files with 296 additions and 37 deletions

View File

@ -34,22 +34,14 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
* Big integer support
*/
/** Modular direct reduction profiler */
static struct profiler bigint_mod_profiler __profiler =
{ .name = "bigint_mod" };
/** Modular multiplication overall profiler */
static struct profiler bigint_mod_multiply_profiler __profiler =
{ .name = "bigint_mod_multiply" };
/** Modular multiplication multiply step profiler */
static struct profiler bigint_mod_multiply_multiply_profiler __profiler =
{ .name = "bigint_mod_multiply.multiply" };
/** Modular multiplication rescale step profiler */
static struct profiler bigint_mod_multiply_rescale_profiler __profiler =
{ .name = "bigint_mod_multiply.rescale" };
/** Modular multiplication subtract step profiler */
static struct profiler bigint_mod_multiply_subtract_profiler __profiler =
{ .name = "bigint_mod_multiply.subtract" };
/**
* Conditionally swap big integers (in constant time)
*
@ -144,6 +136,175 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0,
}
}
/**
* Reduce big integer
*
* @v minuend0 Element 0 of big integer to be reduced
* @v minuend_size Number of elements in minuend
* @v modulus0 Element 0 of big integer modulus
* @v modulus_size Number of elements in modulus and result
* @v result0 Element 0 of big integer to hold result
* @v tmp Temporary working space
*/
void bigint_reduce_raw ( const bigint_element_t *minuend0,
unsigned int minuend_size,
const bigint_element_t *modulus0,
unsigned int modulus_size,
bigint_element_t *result0, void *tmp ) {
const bigint_t ( minuend_size ) __attribute__ (( may_alias ))
*minuend = ( ( const void * ) minuend0 );
const bigint_t ( modulus_size ) __attribute__ (( may_alias ))
*modulus = ( ( const void * ) modulus0 );
bigint_t ( modulus_size ) __attribute__ (( may_alias ))
*result = ( ( void * ) result0 );
struct {
bigint_t ( minuend_size ) minuend;
bigint_t ( minuend_size ) modulus;
} *temp = tmp;
const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
const bigint_element_t msb_mask = ( 1UL << ( width - 1 ) );
bigint_element_t *element;
unsigned int minuend_max;
unsigned int modulus_max;
unsigned int subshift;
bigint_element_t msb;
int offset;
int shift;
int i;
/* Start profiling */
profile_start ( &bigint_mod_profiler );
/* Sanity check */
assert ( minuend_size >= modulus_size );
assert ( sizeof ( *temp ) == bigint_reduce_tmp_len ( minuend ) );
/* Copy minuend and modulus to temporary working space */
bigint_shrink ( minuend, &temp->minuend );
bigint_grow ( modulus, &temp->modulus );
/* Normalise the modulus
*
* Scale the modulus by shifting left such that both modulus
* "m" and minuend "x" have the same most significant set bit.
* (If this is not possible, then the minuend is already less
* than the modulus, and we may therefore skip reduction
* completely.)
*/
minuend_max = bigint_max_set_bit ( minuend );
modulus_max = bigint_max_set_bit ( modulus );
shift = ( minuend_max - modulus_max );
if ( shift < 0 )
goto skip;
subshift = ( shift & ( width - 1 ) );
offset = ( shift / width );
element = temp->modulus.element;
for ( i = ( ( minuend_max - 1 ) / width ) ; ; i-- ) {
element[i] = ( element[ i - offset ] << subshift );
if ( i <= offset )
break;
if ( subshift ) {
element[i] |= ( element[ i - offset - 1 ]
>> ( width - subshift ) );
}
}
for ( i-- ; i >= 0 ; i-- )
element[i] = 0;
/* Reduce the minuend "x" by iteratively adding or subtracting
* the scaled modulus "m".
*
* On each loop iteration, we maintain the invariant:
*
* -2m <= x < 2m
*
* If x is positive, we obtain the new minuend x' by
* subtracting m, otherwise we add m:
*
* 0 <= x < 2m => x' := x - m => -m <= x' < m
* -2m <= x < 0 => x' := x + m => -m <= x' < m
*
* and then halve the modulus (by shifting right):
*
* m' = m/2
*
* We therefore end up with:
*
* -m <= x' < m => -2m' <= x' < 2m'
*
* i.e. we have preseved the invariant while reducing the
* bounds on x' by one power of two.
*
* The issue remains of how to determine on each iteration
* whether or not x is currently positive, given that both
* input values are unsigned big integers that may use all
* available bits (including the MSB).
*
* On the first loop iteration, we may simply assume that x is
* positive, since it is unmodified from the input value and
* so is positive by definition (even if the MSB is set). We
* therefore unconditionally perform a subtraction on the
* first loop iteration.
*
* Let k be the MSB after normalisation. We then have:
*
* 2^k <= m < 2^(k+1)
* 2^k <= x < 2^(k+1)
*
* On the first loop iteration, we therefore have:
*
* x' = (x - m)
* < 2^(k+1) - 2^k
* < 2^k
*
* Any positive value of x' therefore has its MSB set to zero,
* and so we may validly treat the MSB of x' as a sign bit at
* the end of the first loop iteration.
*
* On all subsequent loop iterations, the starting value m is
* guaranteed to have its MSB set to zero (since it has
* already been shifted right at least once). Since we know
* from above that we preserve the loop invariant:
*
* -m <= x' < m
*
* we immediately know that any positive value of x' also has
* its MSB set to zero, and so we may validly treat the MSB of
* x' as a sign bit at the end of all subsequent loop
* iterations.
*
* After the last loop iteration (when m' has been shifted
* back down to the original value of the modulus), we may
* need to add a single multiple of m' to ensure that x' is
* positive, i.e. lies within the range 0 <= x' < m'. To
* allow for reusing the (inlined) expansion of
* bigint_subtract(), we achieve this via a potential
* additional loop iteration that performs the addition and is
* then guaranteed to terminate (since the result will be
* positive).
*/
for ( msb = 0 ; ( msb || ( shift >= 0 ) ) ; shift-- ) {
if ( msb ) {
bigint_add ( &temp->modulus, &temp->minuend );
} else {
bigint_subtract ( &temp->modulus, &temp->minuend );
}
msb = ( temp->minuend.element[ minuend_size - 1 ] & msb_mask );
if ( shift > 0 )
bigint_shr ( &temp->modulus );
}
skip:
/* Sanity check */
assert ( ! bigint_is_geq ( &temp->minuend, &temp->modulus ) );
/* Copy result */
bigint_shrink ( &temp->minuend, result );
/* Stop profiling */
profile_stop ( &bigint_mod_profiler );
}
/**
* Perform modular multiplication of big integers
*
@ -171,8 +332,6 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
bigint_t ( size * 2 ) result;
bigint_t ( size * 2 ) modulus;
} *temp = tmp;
int shift;
int i;
/* Start profiling */
profile_start ( &bigint_mod_multiply_profiler );
@ -181,33 +340,13 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
assert ( sizeof ( *temp ) == bigint_mod_multiply_tmp_len ( modulus ) );
/* Perform multiplication */
profile_start ( &bigint_mod_multiply_multiply_profiler );
bigint_multiply ( multiplicand, multiplier, &temp->result );
profile_stop ( &bigint_mod_multiply_multiply_profiler );
/* Rescale modulus to match result */
profile_start ( &bigint_mod_multiply_rescale_profiler );
bigint_grow ( modulus, &temp->modulus );
shift = ( bigint_max_set_bit ( &temp->result ) -
bigint_max_set_bit ( &temp->modulus ) );
for ( i = 0 ; i < shift ; i++ )
bigint_shl ( &temp->modulus );
profile_stop ( &bigint_mod_multiply_rescale_profiler );
/* Subtract multiples of modulus */
profile_start ( &bigint_mod_multiply_subtract_profiler );
for ( i = 0 ; i <= shift ; i++ ) {
if ( bigint_is_geq ( &temp->result, &temp->modulus ) )
bigint_subtract ( &temp->modulus, &temp->result );
bigint_shr ( &temp->modulus );
}
profile_stop ( &bigint_mod_multiply_subtract_profiler );
/* Resize result */
bigint_shrink ( &temp->result, result );
/* Reduce result */
bigint_reduce ( &temp->result, modulus, result, temp );
/* Sanity check */
assert ( bigint_is_geq ( modulus, result ) );
assert ( ! bigint_is_geq ( result, modulus ) );
/* Stop profiling */
profile_stop ( &bigint_mod_multiply_profiler );

View File

@ -217,6 +217,35 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
multiplier_size, (result)->element ); \
} while ( 0 )
/**
* Reduce big integer
*
* @v minuend Big integer to be reduced
* @v modulus Big integer modulus
* @v result Big integer to hold result
* @v tmp Temporary working space
*/
#define bigint_reduce( minuend, modulus, result, tmp ) do { \
unsigned int minuend_size = bigint_size (minuend); \
unsigned int modulus_size = bigint_size (modulus); \
bigint_reduce_raw ( (minuend)->element, minuend_size, \
(modulus)->element, modulus_size, \
(result)->element, tmp ); \
} while ( 0 )
/**
* Calculate temporary working space required for reduction
*
* @v minuend Big integer to be reduced
* @ret len Length of temporary working space
*/
#define bigint_reduce_tmp_len( minuend ) ( { \
unsigned int size = bigint_size (minuend); \
sizeof ( struct { \
bigint_t ( size ) temp_minuend; \
bigint_t ( size ) temp_modulus; \
} ); } )
/**
* Perform modular multiplication of big integers
*
@ -339,6 +368,11 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0,
const bigint_element_t *multiplier0,
unsigned int multiplier_size,
bigint_element_t *result0 );
void bigint_reduce_raw ( const bigint_element_t *minuend0,
unsigned int minuend_size,
const bigint_element_t *modulus0,
unsigned int modulus_size,
bigint_element_t *result0, void *tmp );
void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
const bigint_element_t *multiplier0,
const bigint_element_t *modulus0,

View File

@ -185,6 +185,21 @@ void bigint_multiply_sample ( const bigint_element_t *multiplicand0,
bigint_multiply ( multiplicand, multiplier, result );
}
void bigint_reduce_sample ( const bigint_element_t *minuend0,
unsigned int minuend_size,
const bigint_element_t *modulus0,
unsigned int modulus_size,
bigint_element_t *result0, void *tmp ) {
const bigint_t ( minuend_size ) __attribute__ (( may_alias ))
*minuend = ( ( const void * ) minuend0 );
const bigint_t ( modulus_size ) __attribute__ (( may_alias ))
*modulus = ( ( const void * ) modulus0 );
bigint_t ( modulus_size ) __attribute__ (( may_alias ))
*result = ( ( void * ) result0 );
bigint_reduce ( minuend, modulus, result, tmp );
}
void bigint_mod_multiply_sample ( const bigint_element_t *multiplicand0,
const bigint_element_t *multiplier0,
const bigint_element_t *modulus0,
@ -516,6 +531,48 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0,
sizeof ( result_raw ) ) == 0 ); \
} while ( 0 )
/**
* Report result of big integer modular direct reduction test
*
* @v minuend Big integer to be reduced
* @v modulus Big integer modulus
* @v expected Big integer expected result
*/
#define bigint_reduce_ok( minuend, modulus, expected ) do { \
static const uint8_t minuend_raw[] = minuend; \
static const uint8_t modulus_raw[] = modulus; \
static const uint8_t expected_raw[] = expected; \
uint8_t result_raw[ sizeof ( expected_raw ) ]; \
unsigned int minuend_size = \
bigint_required_size ( sizeof ( minuend_raw ) ); \
unsigned int modulus_size = \
bigint_required_size ( sizeof ( modulus_raw ) ); \
bigint_t ( minuend_size ) minuend_temp; \
bigint_t ( modulus_size ) modulus_temp; \
bigint_t ( modulus_size ) result_temp; \
size_t tmp_len = bigint_reduce_tmp_len ( &minuend_temp ); \
uint8_t tmp[tmp_len]; \
{} /* Fix emacs alignment */ \
\
assert ( bigint_size ( &result_temp ) == \
bigint_size ( &modulus_temp ) ); \
bigint_init ( &minuend_temp, minuend_raw, \
sizeof ( minuend_raw ) ); \
bigint_init ( &modulus_temp, modulus_raw, \
sizeof ( modulus_raw ) ); \
DBG ( "Modular reduce:\n" ); \
DBG_HDA ( 0, &minuend_temp, sizeof ( minuend_temp ) ); \
DBG_HDA ( 0, &modulus_temp, sizeof ( modulus_temp ) ); \
bigint_reduce ( &minuend_temp, &modulus_temp, &result_temp, \
tmp ); \
DBG_HDA ( 0, &result_temp, sizeof ( result_temp ) ); \
bigint_done ( &result_temp, result_raw, \
sizeof ( result_raw ) ); \
\
ok ( memcmp ( result_raw, expected_raw, \
sizeof ( result_raw ) ) == 0 ); \
} while ( 0 )
/**
* Report result of big integer modular multiplication test
*
@ -1674,6 +1731,35 @@ static void bigint_test_exec ( void ) {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01 ) );
bigint_reduce_ok ( BIGINT ( 0x00 ),
BIGINT ( 0xaf ),
BIGINT ( 0x00 ) );
bigint_reduce_ok ( BIGINT ( 0xab ),
BIGINT ( 0xab ),
BIGINT ( 0x00 ) );
bigint_reduce_ok ( BIGINT ( 0x1d, 0x97, 0x63, 0xc9, 0x97, 0xcd, 0x43,
0xcb, 0x8e, 0x71, 0xac, 0x41, 0xdd ),
BIGINT ( 0xcc, 0x9d, 0xa0, 0x79, 0x96, 0x6a, 0x46,
0xd5, 0xb4, 0x30, 0xd2, 0x2b, 0xbf ),
BIGINT ( 0x1d, 0x97, 0x63, 0xc9, 0x97, 0xcd, 0x43,
0xcb, 0x8e, 0x71, 0xac, 0x41, 0xdd ) );
bigint_reduce_ok ( BIGINT ( 0x21, 0xfa, 0x4f, 0xce, 0x0f, 0x0f, 0x4d,
0x43, 0xaa, 0xad, 0x21, 0x30, 0xe5 ),
BIGINT ( 0x21, 0xfa, 0x4f, 0xce, 0x0f, 0x0f, 0x4d,
0x43, 0xaa, 0xad, 0x21, 0x30, 0xe5 ),
BIGINT ( 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ) );
bigint_reduce_ok ( BIGINT ( 0xf9, 0x78, 0x96, 0x39, 0xee, 0x98, 0x42,
0x6a, 0xb8, 0x74, 0x0b, 0xe8, 0x5c, 0x76,
0x34, 0xaf ),
BIGINT ( 0xf3, 0x65, 0x35, 0x41, 0x66, 0x65 ),
BIGINT ( 0xb3, 0x07, 0xe8, 0xb7, 0x01, 0xf6 ) );
bigint_reduce_ok ( BIGINT ( 0xfe, 0x30, 0xe1, 0xc6, 0x65, 0x97, 0x48,
0x2e, 0x94, 0xd4 ),
BIGINT ( 0x47, 0xaa, 0x88, 0x00, 0xd0, 0x30, 0x62,
0xfb, 0x5d, 0x55 ),
BIGINT ( 0x27, 0x31, 0x49, 0xc3, 0xf5, 0x06, 0x1f,
0x3c, 0x7c, 0xd5 ) );
bigint_mod_multiply_ok ( BIGINT ( 0x37 ),
BIGINT ( 0x67 ),
BIGINT ( 0x3f ),