diff --git a/t/tests-preload.c b/t/tests-preload.c index f2324f6a5..101a5625f 100644 --- a/t/tests-preload.c +++ b/t/tests-preload.c @@ -26,7 +26,8 @@ typedef struct { peername; unsigned int open:1, bound:1, - connected:1; + connected:1, + pktinfo:1; } socket_t; typedef struct { @@ -90,7 +91,7 @@ int socket(int domain, int type, int protocol) { real_sockets[fd] = (socket_t) { .used_domain = use_domain, .wanted_domain = domain, - .type = type, + .type = type & 0xff, .used_protocol = use_protocol, .wanted_protocol = protocol, .open = 1, @@ -101,12 +102,23 @@ int socket(int domain, int type, int protocol) { static const char *addr_translate(struct sockaddr_un *sun, const struct sockaddr *addr, socklen_t addrlen, + int type, int allow_anon) { const char *err; char sockname[64]; const char *any_name; unsigned int port; + const char *prefix = "unk"; + + switch (type) { + case SOCK_STREAM: + prefix = "tcp"; + break; + case SOCK_DGRAM: + prefix = "udp"; + break; + } switch (addr->sa_family) { case AF_INET:; @@ -139,7 +151,8 @@ static const char *addr_translate(struct sockaddr_un *sun, const struct sockaddr if (allow_anon) { err = "Unix socket path truncated"; - if (snprintf(sun->sun_path, sizeof(sun->sun_path), "%s/[%s]:%u", path_prefix(), any_name, port) + if (snprintf(sun->sun_path, sizeof(sun->sun_path), "%s/%s:[%s]:%u", path_prefix(), + prefix, any_name, port) >= sizeof(sun->sun_path)) goto err; @@ -151,7 +164,8 @@ static const char *addr_translate(struct sockaddr_un *sun, const struct sockaddr if (do_specific) { err = "Unix socket path truncated"; - if (snprintf(sun->sun_path, sizeof(sun->sun_path), "%s/[%s]:%u", path_prefix(), sockname, port) + if (snprintf(sun->sun_path, sizeof(sun->sun_path), "%s/%s:[%s]:%u", path_prefix(), + prefix, sockname, port) >= sizeof(sun->sun_path)) goto err; } @@ -182,6 +196,12 @@ void addr_translate_reverse(struct sockaddr_storage *sst, socklen_t *socklen, in } path++; + if (path[0] == '\0' || path[1] == '\0' || path[2] == '\0' || path[3] != ':') { + fprintf(stderr, "preload addr_translate_reverse(): missing prefix '%s'\n", path); + return; + } + path += 4; + struct sockaddr_in sin = {0,}; struct sockaddr_in6 sin6 = {0,}; socklen_t addrlen; @@ -258,7 +278,7 @@ int bind(int fd, const struct sockaddr *addr, socklen_t addrlen) { assert(s->wanted_domain == addr->sa_family); struct sockaddr_un sun; - err = addr_translate(&sun, addr, addrlen, 0); + err = addr_translate(&sun, addr, addrlen, s->type, 0); if (err) { if (!err[0]) goto do_bind; @@ -315,6 +335,7 @@ static socklen_t anon_addr(int domain, struct sockaddr_storage *sst, unsigned in } static void check_bind(int fd) { + const char *prefix = "unk"; // to make inspecting the peer address on the receiving end possible, we must bind // to some unix path name @@ -328,13 +349,23 @@ static void check_bind(int fd) { if (s->wanted_domain == AF_UNIX || s->used_domain != AF_UNIX) return; + switch (s->type) { + case SOCK_STREAM: + prefix = "tcp"; + break; + case SOCK_DGRAM: + prefix = "udp"; + break; + } + struct sockaddr_storage sst; unsigned int auto_inc = __sync_fetch_and_add(&anon_sock_inc, 1); anon_addr(s->wanted_domain, &sst, auto_inc, getpid()); struct sockaddr_un sun; sun.sun_family = AF_UNIX; - if (snprintf(sun.sun_path, sizeof(sun.sun_path), "%s/ANON.%u.%u", path_prefix(), getpid(), + if (snprintf(sun.sun_path, sizeof(sun.sun_path), "%s/%s:ANON.%u.%u", path_prefix(), prefix, + getpid(), auto_inc) >= sizeof(sun.sun_path)) fprintf(stderr, "preload socket(): failed to print anon (fd %i)\n", fd); @@ -475,7 +506,7 @@ int connect(int fd, const struct sockaddr *addr, socklen_t addrlen) { assert(s->wanted_domain == addr->sa_family); struct sockaddr_un sun; - err = addr_translate(&sun, addr, addrlen, 1); + err = addr_translate(&sun, addr, addrlen, s->type, 1); if (err) { if (!err[0]) goto do_connect; @@ -721,13 +752,13 @@ static const struct sockaddr *addr_find(const struct sockaddr *addr, socklen_t * return NULL; } -static const struct sockaddr *addr_send_translate(const struct sockaddr *addr, socklen_t *addrlen) { +static const struct sockaddr *addr_send_translate(const struct sockaddr *addr, int type, socklen_t *addrlen) { const struct sockaddr *ret = addr_find(addr, addrlen); if (ret) return ret; static __thread struct sockaddr_un sun; - const char *err = addr_translate(&sun, addr, *addrlen, 0); + const char *err = addr_translate(&sun, addr, *addrlen, type, 0); if (!err) { *addrlen = sizeof(sun); return (void *) &sun; @@ -740,19 +771,43 @@ static const struct sockaddr *addr_send_translate(const struct sockaddr *addr, s } ssize_t sendto(int fd, const void *buf, size_t len, int flags, const struct sockaddr *addr, socklen_t addrlen) { + const char *err; check_bind(fd); ssize_t (*real_sendto)(int, const void *, size_t, int, const struct sockaddr *, socklen_t) = dlsym(RTLD_NEXT, "sendto"); - addr = addr_send_translate(addr, &addrlen); + err = "fd out of bounds"; + if (fd < 0 || fd >= MAX_SOCKETS) + goto do_send_warn; + socket_t *s = &real_sockets[fd]; + err = "fd not open"; + if (!s->open) + goto do_send_warn; + addr = addr_send_translate(addr, s->type, &addrlen); + goto do_send; +do_send_warn: + fprintf(stderr, "preload sendto(): %s (fd %i)\n", err, fd); +do_send: return real_sendto(fd, buf, len, flags, addr, addrlen); } ssize_t sendmsg(int fd, const struct msghdr *msg, int flags) { + const char *err; check_bind(fd); ssize_t (*real_sendmsg)(int, const struct msghdr *, int) = dlsym(RTLD_NEXT, "sendmsg"); + err = "fd out of bounds"; + if (fd < 0 || fd >= MAX_SOCKETS) + goto do_send_warn; + socket_t *s = &real_sockets[fd]; + err = "fd not open"; + if (!s->open) + goto do_send_warn; struct msghdr msg2 = *msg; if (msg2.msg_name) - msg2.msg_name = (void *) addr_send_translate(msg2.msg_name, &msg2.msg_namelen); + msg2.msg_name = (void *) addr_send_translate(msg2.msg_name, s->type, &msg2.msg_namelen); + goto do_send; +do_send_warn: + fprintf(stderr, "preload sendmsg(): %s (fd %i)\n", err, fd); +do_send: return real_sendmsg(fd, &msg2, flags); } @@ -775,6 +830,10 @@ int setsockopt(int fd, int level, int optname, const void *optval, socklen_t opt return 0; if (level == IPPROTO_TCP && optname == TCP_NODELAY) return 0; + if (level == SOL_IP && optname == IP_PKTINFO) { + s->pktinfo = 1; + return 0; + } break; case AF_INET6: @@ -784,6 +843,10 @@ int setsockopt(int fd, int level, int optname, const void *optval, socklen_t opt return 0; if (level == IPPROTO_TCP && optname == TCP_NODELAY) return 0; + if (level == SOL_IPV6 && optname == IPV6_RECVPKTINFO) { + s->pktinfo = 1; + return 0; + } break; }