diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c index c3edec9d5..b13b0ac60 100644 --- a/src/crypto/bigint.c +++ b/src/crypto/bigint.c @@ -35,12 +35,14 @@ FILE_LICENCE ( GPL2_OR_LATER ); * @v multiplier0 Element 0 of big integer to be multiplied * @v modulus0 Element 0 of big integer modulus * @v result0 Element 0 of big integer to hold result + * @v size Number of elements in base, modulus, and result + * @v tmp Temporary working space */ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0, const bigint_element_t *multiplier0, const bigint_element_t *modulus0, bigint_element_t *result0, - unsigned int size ) { + unsigned int size, void *tmp ) { const bigint_t ( size ) __attribute__ (( may_alias )) *multiplicand = ( ( const void * ) multiplicand0 ); const bigint_t ( size ) __attribute__ (( may_alias )) *multiplier = @@ -49,30 +51,35 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0, ( ( const void * ) modulus0 ); bigint_t ( size ) __attribute__ (( may_alias )) *result = ( ( void * ) result0 ); - bigint_t ( size * 2 ) temp_result; - bigint_t ( size * 2 ) temp_modulus; + struct { + bigint_t ( size * 2 ) result; + bigint_t ( size * 2 ) modulus; + } *temp = tmp; int rotation; int i; + /* Sanity check */ + assert ( sizeof ( *temp ) == bigint_mod_multiply_tmp_len ( modulus ) ); + /* Perform multiplication */ - bigint_multiply ( multiplicand, multiplier, &temp_result ); + bigint_multiply ( multiplicand, multiplier, &temp->result ); /* Rescale modulus to match result */ - bigint_grow ( modulus, &temp_modulus ); - rotation = ( bigint_max_set_bit ( &temp_result ) - - bigint_max_set_bit ( &temp_modulus ) ); + bigint_grow ( modulus, &temp->modulus ); + rotation = ( bigint_max_set_bit ( &temp->result ) - + bigint_max_set_bit ( &temp->modulus ) ); for ( i = 0 ; i < rotation ; i++ ) - bigint_rol ( &temp_modulus ); + bigint_rol ( &temp->modulus ); /* Subtract multiples of modulus */ for ( i = 0 ; i <= rotation ; i++ ) { - if ( bigint_is_geq ( &temp_result, &temp_modulus ) ) - bigint_subtract ( &temp_modulus, &temp_result ); - bigint_ror ( &temp_modulus ); + if ( bigint_is_geq ( &temp->result, &temp->modulus ) ) + bigint_subtract ( &temp->modulus, &temp->result ); + bigint_ror ( &temp->modulus ); } /* Resize result */ - bigint_shrink ( &temp_result, result ); + bigint_shrink ( &temp->result, result ); /* Sanity check */ assert ( bigint_is_geq ( modulus, result ) ); @@ -87,13 +94,14 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0, * @v result0 Element 0 of big integer to hold result * @v size Number of elements in base, modulus, and result * @v exponent_size Number of elements in exponent + * @v tmp Temporary working space */ void bigint_mod_exp_raw ( const bigint_element_t *base0, const bigint_element_t *modulus0, const bigint_element_t *exponent0, bigint_element_t *result0, - unsigned int size, - unsigned int exponent_size ) { + unsigned int size, unsigned int exponent_size, + void *tmp ) { const bigint_t ( size ) __attribute__ (( may_alias )) *base = ( ( const void * ) base0 ); const bigint_t ( size ) __attribute__ (( may_alias )) *modulus = @@ -102,21 +110,25 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, *exponent = ( ( const void * ) exponent0 ); bigint_t ( size ) __attribute__ (( may_alias )) *result = ( ( void * ) result0 ); - bigint_t ( size ) temp_base; - bigint_t ( exponent_size ) temp_exponent; + size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus ); + struct { + bigint_t ( size ) base; + bigint_t ( exponent_size ) exponent; + uint8_t mod_multiply[mod_multiply_len]; + } *temp = tmp; static const uint8_t start[1] = { 0x01 }; - memcpy ( &temp_base, base, sizeof ( temp_base ) ); - memcpy ( &temp_exponent, exponent, sizeof ( temp_exponent ) ); + memcpy ( &temp->base, base, sizeof ( temp->base ) ); + memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) ); bigint_init ( result, start, sizeof ( start ) ); - while ( ! bigint_is_zero ( &temp_exponent ) ) { - if ( bigint_bit_is_set ( &temp_exponent, 0 ) ) { - bigint_mod_multiply ( result, &temp_base, - modulus, result ); + while ( ! bigint_is_zero ( &temp->exponent ) ) { + if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) { + bigint_mod_multiply ( result, &temp->base, modulus, + result, temp->mod_multiply ); } - bigint_ror ( &temp_exponent ); - bigint_mod_multiply ( &temp_base, &temp_base, modulus, - &temp_base ); + bigint_ror ( &temp->exponent ); + bigint_mod_multiply ( &temp->base, &temp->base, modulus, + &temp->base, temp->mod_multiply ); } } diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c index 4aba5cc30..a0bf39eb8 100644 --- a/src/crypto/rsa.c +++ b/src/crypto/rsa.c @@ -123,11 +123,15 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, size_t exponent_len ) { unsigned int size = bigint_required_size ( modulus_len ); unsigned int exponent_size = bigint_required_size ( exponent_len ); + bigint_t ( size ) *modulus; + bigint_t ( exponent_size ) *exponent; + size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent ); struct { bigint_t ( size ) modulus; bigint_t ( exponent_size ) exponent; bigint_t ( size ) input; bigint_t ( size ) output; + uint8_t tmp[tmp_len]; } __attribute__ (( packed )) *dynamic; /* Free any existing dynamic storage */ @@ -147,6 +151,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, context->exponent_size = exponent_size; context->input0 = &dynamic->input.element[0]; context->output0 = &dynamic->output.element[0]; + context->tmp = &dynamic->tmp; return 0; } @@ -309,7 +314,7 @@ static void rsa_cipher ( struct rsa_context *context, bigint_init ( input, in, context->max_len ); /* Perform modular exponentiation */ - bigint_mod_exp ( input, modulus, exponent, output ); + bigint_mod_exp ( input, modulus, exponent, output, context->tmp ); /* Copy out result */ bigint_done ( output, out, context->max_len ); diff --git a/src/include/ipxe/bigint.h b/src/include/ipxe/bigint.h index a21b3e5d6..97fbce245 100644 --- a/src/include/ipxe/bigint.h +++ b/src/include/ipxe/bigint.h @@ -197,16 +197,30 @@ FILE_LICENCE ( GPL2_OR_LATER ); * @v multiplier Big integer to be multiplied * @v modulus Big integer modulus * @v result Big integer to hold result + * @v tmp Temporary working space */ #define bigint_mod_multiply( multiplicand, multiplier, modulus, \ - result ) do { \ + result, tmp ) do { \ unsigned int size = bigint_size (multiplicand); \ bigint_mod_multiply_raw ( (multiplicand)->element, \ (multiplier)->element, \ (modulus)->element, \ - (result)->element, size ); \ + (result)->element, size, tmp ); \ } while ( 0 ) +/** + * Calculate temporary working space required for moduluar multiplication + * + * @v modulus Big integer modulus + * @ret len Length of temporary working space + */ +#define bigint_mod_multiply_tmp_len( modulus ) ( { \ + unsigned int size = bigint_size (modulus); \ + sizeof ( struct { \ + bigint_t ( size * 2 ) temp_result; \ + bigint_t ( size * 2 ) temp_modulus; \ + } ); } ) + /** * Perform modular exponentiation of big integers * @@ -214,15 +228,34 @@ FILE_LICENCE ( GPL2_OR_LATER ); * @v modulus Big integer modulus * @v exponent Big integer exponent * @v result Big integer to hold result + * @v tmp Temporary working space */ -#define bigint_mod_exp( base, modulus, exponent, result ) do { \ +#define bigint_mod_exp( base, modulus, exponent, result, tmp ) do { \ unsigned int size = bigint_size (base); \ unsigned int exponent_size = bigint_size (exponent); \ bigint_mod_exp_raw ( (base)->element, (modulus)->element, \ - (exponent)->element, (result)->element, \ - size, exponent_size ); \ + (exponent)->element, (result)->element, \ + size, exponent_size, tmp ); \ } while ( 0 ) +/** + * Calculate temporary working space required for moduluar exponentiation + * + * @v modulus Big integer modulus + * @v exponent Big integer exponent + * @ret len Length of temporary working space + */ +#define bigint_mod_exp_tmp_len( modulus, exponent ) ( { \ + unsigned int size = bigint_size (modulus); \ + unsigned int exponent_size = bigint_size (exponent); \ + size_t mod_multiply_len = \ + bigint_mod_multiply_tmp_len (modulus); \ + sizeof ( struct { \ + bigint_t ( size ) temp_base; \ + bigint_t ( exponent_size ) temp_exponent; \ + uint8_t mod_multiply[mod_multiply_len]; \ + } ); } ) + #include void bigint_init_raw ( bigint_element_t *value0, unsigned int size, @@ -257,12 +290,12 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0, const bigint_element_t *multiplier0, const bigint_element_t *modulus0, bigint_element_t *result0, - unsigned int size ); + unsigned int size, void *tmp ); void bigint_mod_exp_raw ( const bigint_element_t *base0, const bigint_element_t *modulus0, const bigint_element_t *exponent0, bigint_element_t *result0, - unsigned int size, - unsigned int exponent_size ); + unsigned int size, unsigned int exponent_size, + void *tmp ); #endif /* _IPXE_BIGINT_H */ diff --git a/src/include/ipxe/rsa.h b/src/include/ipxe/rsa.h index e70362ce7..87e75a82f 100644 --- a/src/include/ipxe/rsa.h +++ b/src/include/ipxe/rsa.h @@ -129,6 +129,8 @@ struct rsa_context { bigint_element_t *input0; /** Output buffer */ bigint_element_t *output0; + /** Temporary working space for modular exponentiation */ + void *tmp; }; extern struct pubkey_algorithm rsa_algorithm; diff --git a/src/tests/bigint_test.c b/src/tests/bigint_test.c index 8c9f188ed..4052131fd 100644 --- a/src/tests/bigint_test.c +++ b/src/tests/bigint_test.c @@ -162,7 +162,8 @@ void bigint_mod_multiply_sample ( const bigint_element_t *multiplicand0, const bigint_element_t *multiplier0, const bigint_element_t *modulus0, bigint_element_t *result0, - unsigned int size ) { + unsigned int size, + void *tmp ) { const bigint_t ( size ) *multiplicand __attribute__ (( may_alias )) = ( ( const void * ) multiplicand0 ); const bigint_t ( size ) *multiplier __attribute__ (( may_alias )) @@ -172,14 +173,15 @@ void bigint_mod_multiply_sample ( const bigint_element_t *multiplicand0, bigint_t ( size ) *result __attribute__ (( may_alias )) = ( ( void * ) result0 ); - bigint_mod_multiply ( multiplicand, multiplier, modulus, result ); + bigint_mod_multiply ( multiplicand, multiplier, modulus, result, tmp ); } void bigint_mod_exp_sample ( const bigint_element_t *base0, const bigint_element_t *modulus0, const bigint_element_t *exponent0, bigint_element_t *result0, - unsigned int size, unsigned int exponent_size ) { + unsigned int size, unsigned int exponent_size, + void *tmp ) { const bigint_t ( size ) *base __attribute__ (( may_alias )) = ( ( const void * ) base0 ); const bigint_t ( size ) *modulus __attribute__ (( may_alias )) @@ -189,7 +191,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, bigint_t ( size ) *result __attribute__ (( may_alias )) = ( ( void * ) result0 ); - bigint_mod_exp ( base, modulus, exponent, result ); + bigint_mod_exp ( base, modulus, exponent, result, tmp ); } /** @@ -471,6 +473,8 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, bigint_t ( size ) multiplier_temp; \ bigint_t ( size ) modulus_temp; \ bigint_t ( size ) result_temp; \ + size_t tmp_len = bigint_mod_multiply_tmp_len ( &modulus_temp ); \ + uint8_t tmp[tmp_len]; \ {} /* Fix emacs alignment */ \ \ assert ( bigint_size ( &multiplier_temp ) == \ @@ -490,7 +494,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, DBG_HDA ( 0, &multiplier_temp, sizeof ( multiplier_temp ) ); \ DBG_HDA ( 0, &modulus_temp, sizeof ( modulus_temp ) ); \ bigint_mod_multiply ( &multiplicand_temp, &multiplier_temp, \ - &modulus_temp, &result_temp ); \ + &modulus_temp, &result_temp, tmp ); \ DBG_HDA ( 0, &result_temp, sizeof ( result_temp ) ); \ bigint_done ( &result_temp, result_raw, sizeof ( result_raw ) );\ \ @@ -520,6 +524,9 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, bigint_t ( size ) modulus_temp; \ bigint_t ( exponent_size ) exponent_temp; \ bigint_t ( size ) result_temp; \ + size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp, \ + &exponent_temp ); \ + uint8_t tmp[tmp_len]; \ {} /* Fix emacs alignment */ \ \ assert ( bigint_size ( &modulus_temp ) == \ @@ -536,7 +543,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, DBG_HDA ( 0, &modulus_temp, sizeof ( modulus_temp ) ); \ DBG_HDA ( 0, &exponent_temp, sizeof ( exponent_temp ) ); \ bigint_mod_exp ( &base_temp, &modulus_temp, &exponent_temp, \ - &result_temp ); \ + &result_temp, tmp ); \ DBG_HDA ( 0, &result_temp, sizeof ( result_temp ) ); \ bigint_done ( &result_temp, result_raw, sizeof ( result_raw ) );\ \