From bb392e46353010e9b8df230389fa7a849f4ec42e Mon Sep 17 00:00:00 2001 From: beck <> Date: Wed, 13 Nov 2019 04:10:38 +0000 Subject: refactor the nc pool loop to not shut down the socket early, and to handle tls_shutdown correctly if using TLS, doing tls_shutdown correctly if we are using the -N flag ok sthen@ --- src/usr.bin/nc/netcat.c | 100 +++++++++++++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 36 deletions(-) (limited to 'src') diff --git a/src/usr.bin/nc/netcat.c b/src/usr.bin/nc/netcat.c index a53fe1c4cd..1dc95e9f36 100644 --- a/src/usr.bin/nc/netcat.c +++ b/src/usr.bin/nc/netcat.c @@ -1,4 +1,4 @@ -/* $OpenBSD: netcat.c,v 1.210 2019/11/04 17:33:28 millert Exp $ */ +/* $OpenBSD: netcat.c,v 1.211 2019/11/13 04:10:38 beck Exp $ */ /* * Copyright (c) 2001 Eric Jackson * Copyright (c) 2015 Bob Beck. All rights reserved. @@ -1103,13 +1103,14 @@ void readwrite(int net_fd, struct tls *tls_ctx) { struct pollfd pfd[4]; + int gone[4] = { 0 }; int stdin_fd = STDIN_FILENO; int stdout_fd = STDOUT_FILENO; unsigned char netinbuf[BUFSIZE]; size_t netinbufpos = 0; unsigned char stdinbuf[BUFSIZE]; size_t stdinbufpos = 0; - int n, num_fds; + int n, num_fds, shutdown_netin, shutdown_netout; ssize_t ret; /* don't read from stdin if requested */ @@ -1132,17 +1133,20 @@ readwrite(int net_fd, struct tls *tls_ctx) pfd[POLL_STDOUT].fd = stdout_fd; pfd[POLL_STDOUT].events = 0; + /* used to indicate we wish to shut down the network socket */ + shutdown_netin = shutdown_netout = 0; + while (1) { /* both inputs are gone, buffers are empty, we are done */ - if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1 && + if (gone[POLL_STDIN] && gone[POLL_NETIN] && stdinbufpos == 0 && netinbufpos == 0) return; /* both outputs are gone, we can't continue */ - if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1) + if (gone[POLL_NETOUT] && gone[POLL_STDOUT]) return; /* listen and net in gone, queues empty, done */ - if (lflag && pfd[POLL_NETIN].fd == -1 && - stdinbufpos == 0 && netinbufpos == 0) + if (lflag && gone[POLL_NETIN] && stdinbufpos == 0 + && netinbufpos == 0) return; /* help says -i is for "wait between lines sent". We read and @@ -1151,6 +1155,12 @@ readwrite(int net_fd, struct tls *tls_ctx) if (iflag) sleep(iflag); + /* If it's gone, take it away from poll */ + for (n = 0; n < 4; n++) { + if (gone[n]) + pfd[n].events = pfd[n].revents = 0; + } + /* poll */ num_fds = poll(pfd, 4, timeout); @@ -1165,36 +1175,36 @@ readwrite(int net_fd, struct tls *tls_ctx) /* treat socket error conditions */ for (n = 0; n < 4; n++) { if (pfd[n].revents & (POLLERR|POLLNVAL)) { - pfd[n].fd = -1; + gone[n] = 1; } } /* reading is possible after HUP */ if (pfd[POLL_STDIN].events & POLLIN && pfd[POLL_STDIN].revents & POLLHUP && !(pfd[POLL_STDIN].revents & POLLIN)) - pfd[POLL_STDIN].fd = -1; + gone[POLL_STDIN] = 1; if (pfd[POLL_NETIN].events & POLLIN && pfd[POLL_NETIN].revents & POLLHUP && !(pfd[POLL_NETIN].revents & POLLIN)) - pfd[POLL_NETIN].fd = -1; + gone[POLL_NETIN] = 1; if (pfd[POLL_NETOUT].revents & POLLHUP) { if (Nflag) - shutdown(pfd[POLL_NETOUT].fd, SHUT_WR); - pfd[POLL_NETOUT].fd = -1; + shutdown_netout = 1; + gone[POLL_NETOUT] = 1; } - /* if HUP, stop watching stdout */ - if (pfd[POLL_STDOUT].revents & POLLHUP) - pfd[POLL_STDOUT].fd = -1; /* if no net out, stop watching stdin */ - if (pfd[POLL_NETOUT].fd == -1) - pfd[POLL_STDIN].fd = -1; + if (gone[POLL_NETOUT]) + gone[POLL_STDIN] = 1; + + /* if stdout HUP's, stop watching stdout */ + if (pfd[POLL_STDOUT].revents & POLLHUP) + gone[POLL_STDOUT] = 1; /* if no stdout, stop watching net in */ - if (pfd[POLL_STDOUT].fd == -1) { - if (pfd[POLL_NETIN].fd != -1) - shutdown(pfd[POLL_NETIN].fd, SHUT_RD); - pfd[POLL_NETIN].fd = -1; + if (gone[POLL_STDOUT]) { + shutdown_netin = 1; + gone[POLL_NETIN] = 1; } /* try to read from stdin */ @@ -1206,7 +1216,7 @@ readwrite(int net_fd, struct tls *tls_ctx) else if (ret == TLS_WANT_POLLOUT) pfd[POLL_STDIN].events = POLLOUT; else if (ret == 0 || ret == -1) - pfd[POLL_STDIN].fd = -1; + gone[POLL_STDIN] = 1; /* read something - poll net out */ if (stdinbufpos > 0) pfd[POLL_NETOUT].events = POLLOUT; @@ -1223,7 +1233,7 @@ readwrite(int net_fd, struct tls *tls_ctx) else if (ret == TLS_WANT_POLLOUT) pfd[POLL_NETOUT].events = POLLOUT; else if (ret == -1) - pfd[POLL_NETOUT].fd = -1; + gone[POLL_NETOUT] = 1; /* buffer empty - remove self from polling */ if (stdinbufpos == 0) pfd[POLL_NETOUT].events = 0; @@ -1240,17 +1250,15 @@ readwrite(int net_fd, struct tls *tls_ctx) else if (ret == TLS_WANT_POLLOUT) pfd[POLL_NETIN].events = POLLOUT; else if (ret == -1) - pfd[POLL_NETIN].fd = -1; + gone[POLL_NETIN] = 1; /* eof on net in - remove from pfd */ if (ret == 0) { - shutdown(pfd[POLL_NETIN].fd, SHUT_RD); - pfd[POLL_NETIN].fd = -1; + gone[POLL_NETIN] = 1; } if (recvlimit > 0 && ++recvcount >= recvlimit) { - if (pfd[POLL_NETIN].fd != -1) - shutdown(pfd[POLL_NETIN].fd, SHUT_RD); - pfd[POLL_NETIN].fd = -1; - pfd[POLL_STDIN].fd = -1; + shutdown_netin = 1; + gone[POLL_NETIN] = 1; + gone[POLL_STDIN] = 1; } /* read something - poll stdout */ if (netinbufpos > 0) @@ -1272,7 +1280,7 @@ readwrite(int net_fd, struct tls *tls_ctx) else if (ret == TLS_WANT_POLLOUT) pfd[POLL_STDOUT].events = POLLOUT; else if (ret == -1) - pfd[POLL_STDOUT].fd = -1; + gone[POLL_STDOUT] = 1; /* buffer empty - remove self from polling */ if (netinbufpos == 0) pfd[POLL_STDOUT].events = 0; @@ -1282,14 +1290,34 @@ readwrite(int net_fd, struct tls *tls_ctx) } /* stdin gone and queue empty? */ - if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) { - if (pfd[POLL_NETOUT].fd != -1 && Nflag) - shutdown(pfd[POLL_NETOUT].fd, SHUT_WR); - pfd[POLL_NETOUT].fd = -1; + if (gone[POLL_STDIN] && stdinbufpos == 0) { + if (Nflag) { + shutdown_netin = 1; + shutdown_netout = 1; + } + gone[POLL_NETOUT] = 1; } /* net in gone and queue empty? */ - if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0) { - pfd[POLL_STDOUT].fd = -1; + if (gone[POLL_NETIN] && netinbufpos == 0) { + if (Nflag) { + shutdown_netin = 1; + shutdown_netout = 1; + } + gone[POLL_STDOUT] = 1; + } + + /* call tls_close if any part of the network socket is closing */ + if ((shutdown_netin || shutdown_netout) && usetls) { + timeout_tls(pfd[POLL_NETIN].fd, tls_ctx, tls_close); + shutdown_netout = shutdown_netin = 1; + } + if (shutdown_netin) { + shutdown(pfd[POLL_NETIN].fd, SHUT_RD); + gone[POLL_NETIN] = 1; + } + if (shutdown_netout) { + shutdown(pfd[POLL_NETOUT].fd, SHUT_WR); + gone[POLL_NETOUT] = 1; } } } -- cgit v1.2.3-55-g6feb