diff --git a/src/include/gpxe/in.h b/src/include/gpxe/in.h index 89530a559..40e4d4073 100644 --- a/src/include/gpxe/in.h +++ b/src/include/gpxe/in.h @@ -62,6 +62,15 @@ struct sockaddr_in { uint16_t sin_port; /** IPv4 address */ struct in_addr sin_addr; + /** Padding + * + * This ensures that a struct @c sockaddr_tcpip is large + * enough to hold a socket address for any TCP/IP address + * family. + */ + char pad[ sizeof ( struct sockaddr ) - sizeof ( sa_family_t ) + - sizeof ( uint16_t ) + - sizeof ( struct in_addr ) ]; }; /** diff --git a/src/net/udp.c b/src/net/udp.c index 89a5b8682..8df76a445 100644 --- a/src/net/udp.c +++ b/src/net/udp.c @@ -29,10 +29,10 @@ struct udp_connection { /** Data transfer interface */ struct xfer_interface xfer; + /** Local socket address */ + struct sockaddr_tcpip local; /** Remote socket address */ struct sockaddr_tcpip peer; - /** Local port on which the connection receives packets */ - unsigned int local_port; }; /** @@ -48,22 +48,22 @@ struct tcpip_protocol udp_protocol; * Bind UDP connection to local port * * @v udp UDP connection - * @v port Local port, in network byte order, or zero * @ret rc Return status code * - * Opens the UDP connection and binds to a local port. If no local - * port is specified, the first available port will be used. + * Opens the UDP connection and binds to the specified local port. If + * no local port is specified, the first available port will be used. */ -static int udp_bind ( struct udp_connection *udp, unsigned int port ) { +static int udp_bind ( struct udp_connection *udp ) { struct udp_connection *existing; static uint16_t try_port = 1024; /* If no port specified, find the first available port */ - if ( ! port ) { + if ( ! udp->local.st_port ) { for ( ; try_port ; try_port++ ) { if ( try_port < 1024 ) continue; - if ( udp_bind ( udp, htons ( try_port ) ) == 0 ) + udp->local.st_port = htons ( try_port ); + if ( udp_bind ( udp ) == 0 ) return 0; } return -EADDRINUSE; @@ -71,16 +71,16 @@ static int udp_bind ( struct udp_connection *udp, unsigned int port ) { /* Attempt bind to local port */ list_for_each_entry ( existing, &udp_conns, list ) { - if ( existing->local_port == port ) { + if ( existing->local.st_port == udp->local.st_port ) { DBGC ( udp, "UDP %p could not bind: port %d in use\n", - udp, ntohs ( port ) ); + udp, ntohs ( udp->local.st_port ) ); return -EADDRINUSE; } } - udp->local_port = port; /* Add to UDP connection list */ - DBGC ( udp, "UDP %p bound to port %d\n", udp, ntohs ( port ) ); + DBGC ( udp, "UDP %p bound to port %d\n", + udp, ntohs ( udp->local.st_port ) ); return 0; } @@ -100,7 +100,6 @@ static int udp_open_common ( struct xfer_interface *xfer, struct sockaddr_tcpip *st_peer = ( struct sockaddr_tcpip * ) peer; struct sockaddr_tcpip *st_local = ( struct sockaddr_tcpip * ) local; struct udp_connection *udp; - unsigned int bind_port; int rc; /* Allocate and initialise structure */ @@ -111,11 +110,12 @@ static int udp_open_common ( struct xfer_interface *xfer, xfer_init ( &udp->xfer, &udp_xfer_operations, &udp->refcnt ); if ( st_peer ) memcpy ( &udp->peer, st_peer, sizeof ( udp->peer ) ); + if ( st_local ) + memcpy ( &udp->local, st_local, sizeof ( udp->local ) ); /* Bind to local port */ if ( ! promisc ) { - bind_port = ( st_local ? st_local->st_port : 0 ); - if ( ( rc = udp_bind ( udp, bind_port ) ) != 0 ) + if ( ( rc = udp_bind ( udp ) ) != 0 ) goto err; } @@ -201,7 +201,7 @@ static int udp_tx ( struct udp_connection *udp, struct io_buffer *iobuf, /* Fill in default values if not explicitly provided */ if ( ! src_port ) - src_port = udp->local_port; + src_port = udp->local.st_port; if ( ! dest ) dest = &udp->peer; @@ -231,17 +231,24 @@ static int udp_tx ( struct udp_connection *udp, struct io_buffer *iobuf, } /** - * Identify UDP connection by local port number + * Identify UDP connection by local address * - * @v local_port Local port (in network-endian order) + * @v local Local address * @ret udp UDP connection, or NULL */ -static struct udp_connection * udp_demux ( unsigned int local_port ) { +static struct udp_connection * udp_demux ( struct sockaddr_tcpip *local ) { + static const struct sockaddr_tcpip empty_sockaddr; struct udp_connection *udp; list_for_each_entry ( udp, &udp_conns, list ) { - if ( ( udp->local_port == local_port ) || - ( udp->local_port == 0 ) ) { + if ( ( ( udp->local.st_family == local->st_family ) || + ( udp->local.st_family == 0 ) ) && + ( ( udp->local.st_port == local->st_port ) || + ( udp->local.st_port == 0 ) ) && + ( ( memcmp ( udp->local.pad, local->pad, + sizeof ( udp->local.pad ) ) == 0 ) || + ( memcmp ( udp->local.pad, empty_sockaddr.pad, + sizeof ( udp->local.pad ) ) == 0 ) ) ) { return udp; } } @@ -300,7 +307,7 @@ static int udp_rx ( struct io_buffer *iobuf, struct sockaddr_tcpip *st_src, /* Parse parameters from header and strip header */ st_src->st_port = udphdr->src; st_dest->st_port = udphdr->dest; - udp = udp_demux ( udphdr->dest ); + udp = udp_demux ( st_dest ); iob_unput ( iobuf, ( iob_len ( iobuf ) - ulen ) ); iob_pull ( iobuf, sizeof ( *udphdr ) );