LCOV - code coverage report
Current view: top level - common/stream - socket_stream.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 123 160 76.9 %
Date: 2024-04-26 00:35:57 Functions: 9 10 90.0 %

          Line data    Source code
       1             : /*
       2             :  * SPDX-License-Identifier: MPL-2.0
       3             :  *
       4             :  * This Source Code Form is subject to the terms of the Mozilla Public
       5             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       6             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       7             :  *
       8             :  * Copyright 2024 MonetDB Foundation;
       9             :  * Copyright August 2008 - 2023 MonetDB B.V.;
      10             :  * Copyright 1997 - July 2008 CWI.
      11             :  */
      12             : 
      13             : /* Generic stream handling code such as init and close */
      14             : 
      15             : #include "monetdb_config.h"
      16             : #include "stream.h"
      17             : #include "stream_internal.h"
      18             : #ifdef HAVE_SYS_TIME_H
      19             : #include <sys/time.h>
      20             : #endif
      21             : 
      22             : 
      23             : /* ------------------------------------------------------------------ */
      24             : /* streams working on a socket */
      25             : 
      26             : static ssize_t
      27     1897038 : socket_write(stream *restrict s, const void *restrict buf, size_t elmsize, size_t cnt)
      28             : {
      29     1897038 :         size_t size = elmsize * cnt, res = 0;
      30             : #ifdef NATIVE_WIN32
      31             :         int nr = 0;
      32             : #else
      33     1897038 :         ssize_t nr = 0;
      34             : #endif
      35             : 
      36     1897038 :         if (s->errkind != MNSTR_NO__ERROR)
      37             :                 return -1;
      38             : 
      39     1897038 :         if (size == 0 || elmsize == 0)
      40           0 :                 return (ssize_t) cnt;
      41             : 
      42     1897038 :         errno = 0;
      43     3793766 :         while (res < size &&
      44             :                (
      45             : #ifdef NATIVE_WIN32
      46             :                        /* send works on int, make sure the argument fits */
      47             :                        ((nr = send(s->stream_data.s, (const char *) buf + res, (int) min(size - res, 1 << 16), 0)) > 0)
      48             : #else
      49     1897035 :                        ((nr = write(s->stream_data.s, (const char *) buf + res, size - res)) > 0)
      50             : #endif
      51         176 :                        || (nr < 0 && /* syscall failed */
      52         176 :                            s->timeout > 0 &&      /* potentially timeout */
      53             : #ifdef _MSC_VER
      54             :                            WSAGetLastError() == WSAEWOULDBLOCK &&
      55             : #else
      56           0 :                            (errno == EAGAIN
      57             : #if EAGAIN != EWOULDBLOCK
      58             :                             || errno == EWOULDBLOCK
      59             : #endif
      60           0 :                                    ) && /* it was! */
      61             : #endif
      62           0 :                            s->timeout_func != NULL &&        /* callback function exists */
      63           0 :                            !s->timeout_func(s->timeout_data))     /* callback says don't stop */
      64         176 :                        ||(nr < 0 &&
      65             : #ifdef _MSC_VER
      66             :                           WSAGetLastError() == WSAEINTR
      67             : #else
      68         176 :                           errno == EINTR
      69             : #endif
      70             :                                ))       /* interrupted */
      71             :                 ) {
      72     1896728 :                 errno = 0;
      73             : #ifdef _MSC_VER
      74             :                 WSASetLastError(0);
      75             : #endif
      76     1896728 :                 if (nr > 0)
      77     1896728 :                         res += (size_t) nr;
      78             :         }
      79     1896907 :         if (res >= elmsize)
      80     1896731 :                 return (ssize_t) (res / elmsize);
      81         176 :         if (nr < 0) {
      82         176 :                 if (s->timeout > 0 &&
      83             : #ifdef _MSC_VER
      84             :                     WSAGetLastError() == WSAEWOULDBLOCK
      85             : #else
      86           0 :                     (errno == EAGAIN
      87             : #if EAGAIN != EWOULDBLOCK
      88             :                      || errno == EWOULDBLOCK
      89             : #endif
      90             :                             )
      91             : #endif
      92             :                         )
      93           0 :                         mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
      94             :                 else
      95         176 :                         mnstr_set_error_errno(s, MNSTR_WRITE_ERROR, "socket write");
      96         176 :                 return -1;
      97             :         }
      98             :         return 0;
      99             : }
     100             : 
     101             : static ssize_t
     102     7158506 : socket_read(stream *restrict s, void *restrict buf, size_t elmsize, size_t cnt)
     103             : {
     104             : #ifdef _MSC_VER
     105             :         int nr = 0;
     106             : #else
     107     7158506 :         ssize_t nr = 0;
     108             : #endif
     109     7158506 :         size_t size = elmsize * cnt;
     110             : 
     111     7158506 :         if (s->errkind != MNSTR_NO__ERROR)
     112             :                 return -1;
     113     7158506 :         if (size == 0)
     114             :                 return 0;
     115             : 
     116             : #ifdef _MSC_VER
     117             :         /* recv only takes an int parameter, and read does not accept
     118             :          * sockets */
     119             :         if (size > INT_MAX)
     120             :                 size = elmsize * (INT_MAX / elmsize);
     121             : #endif
     122     7161762 :         for (;;) {
     123     7160135 :                 if (s->timeout) {
     124     6140830 :                         int ret;
     125             : #ifdef HAVE_POLL
     126     6140830 :                         struct pollfd pfd;
     127             : 
     128     6140830 :                         pfd = (struct pollfd) {.fd = s->stream_data.s,
     129             :                                                .events = POLLIN};
     130             : 
     131     6140830 :                         ret = poll(&pfd, 1, (int) s->timeout);
     132     6141662 :                         if (ret == -1 && errno == EINTR)
     133        1627 :                                 continue;
     134     6141662 :                         if (ret == -1 || (pfd.revents & POLLERR)) {
     135          37 :                                 mnstr_set_error_errno(s, MNSTR_READ_ERROR, "poll error");
     136          55 :                                 return -1;
     137             :                         }
     138             : #else
     139             :                         struct timeval tv;
     140             :                         fd_set fds;
     141             : 
     142             :                         errno = 0;
     143             : #ifdef _MSC_VER
     144             :                         WSASetLastError(0);
     145             : #endif
     146             :                         FD_ZERO(&fds);
     147             :                         FD_SET(s->stream_data.s, &fds);
     148             :                         tv.tv_sec = s->timeout / 1000;
     149             :                         tv.tv_usec = (s->timeout % 1000) * 1000;
     150             :                         ret = select(
     151             : #ifdef _MSC_VER
     152             :                                 0,      /* ignored on Windows */
     153             : #else
     154             :                                 s->stream_data.s + 1,
     155             : #endif
     156             :                                 &fds, NULL, NULL, &tv);
     157             :                         if (ret == SOCKET_ERROR) {
     158             :                                 mnstr_set_error_errno(s, MNSTR_READ_ERROR, "select");
     159             :                                 return -1;
     160             :                         }
     161             : #endif
     162     6141625 :                         if (ret == 0) {
     163        1645 :                                 if (s->timeout_func == NULL || s->timeout_func(s->timeout_data)) {
     164          18 :                                         mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
     165          18 :                                         return -1;
     166             :                                 }
     167        1627 :                                 continue;
     168             :                         }
     169     6139980 :                         assert(ret == 1);
     170             : #ifdef HAVE_POLL
     171     6139980 :                         assert(pfd.revents & (POLLIN|POLLHUP));
     172             : #else
     173             :                         assert(FD_ISSET(s->stream_data.s, &fds));
     174             : #endif
     175             :                 }
     176             : #ifdef _MSC_VER
     177             :                 nr = recv(s->stream_data.s, buf, (int) size, 0);
     178             :                 if (nr == SOCKET_ERROR) {
     179             :                         mnstr_set_error_errno(s, MNSTR_READ_ERROR, "recv");
     180             :                         return -1;
     181             :                 }
     182             : #else
     183     7159285 :                 nr = read(s->stream_data.s, buf, size);
     184     7158506 :                 if (nr == -1) {
     185           1 :                         mnstr_set_error_errno(s, errno == EINTR ? MNSTR_INTERRUPT : MNSTR_READ_ERROR, NULL);
     186           1 :                         return -1;
     187             :                 }
     188             : #endif
     189     7158505 :                 break;
     190             :         }
     191     7158505 :         if (nr == 0) {
     192       38232 :                 s->eof = true;
     193       38232 :                 return 0;       /* end of file */
     194             :         }
     195     7120273 :         if (elmsize > 1) {
     196      957442 :                 while ((size_t) nr % elmsize != 0) {
     197             :                         /* if elmsize > 1, we really expect that "the
     198             :                          * other side" wrote complete items in a
     199             :                          * single system call, so we expect to at
     200             :                          * least receive complete items, and hence we
     201             :                          * continue reading until we did in fact
     202             :                          * receive an integral number of complete
     203             :                          * items, ignoring any timeouts (but not real
     204             :                          * errors) (note that recursion is limited
     205             :                          * since we don't propagate the element size
     206             :                          * to the recursive call) */
     207           0 :                         ssize_t n;
     208           0 :                         n = socket_read(s, (char *) buf + nr, 1, size - (size_t) nr);
     209           0 :                         if (n < 0) {
     210           0 :                                 if (s->errkind == MNSTR_NO__ERROR)
     211           0 :                                         mnstr_set_error(s, MNSTR_READ_ERROR, "socket_read failed");
     212           0 :                                 return -1;
     213             :                         }
     214           0 :                         if (n == 0)     /* unexpected end of file */
     215             :                                 break;
     216           0 :                         nr +=
     217             : #ifdef _MSC_VER
     218             :                                 (int)
     219             : #endif
     220             :                                 n;
     221             :                 }
     222             :         }
     223     7120735 :         return nr / (ssize_t) elmsize;
     224             : }
     225             : 
     226             : static void
     227       79043 : socket_close(stream *s)
     228             : {
     229       79043 :         SOCKET fd = s->stream_data.s;
     230             : 
     231       79043 :         if (fd != INVALID_SOCKET) {
     232             :                 /* Related read/write (in/out, from/to) streams
     233             :                  * share a single socket which is not dup'ed (anymore)
     234             :                  * as Windows' dup doesn't work on sockets;
     235             :                  * hence, only one of the streams must/may close that
     236             :                  * socket; we choose to let the read socket do the
     237             :                  * job, since in mapi.c it may happen that the read
     238             :                  * stream is closed before the write stream was even
     239             :                  * created.
     240             :                  */
     241       79043 :                 if (s->readonly) {
     242             : #ifdef HAVE_SHUTDOWN
     243       39523 :                         shutdown(fd, SHUT_RDWR);
     244             : #endif
     245       39523 :                         closesocket(fd);
     246             :                 }
     247             :         }
     248       79043 :         s->stream_data.s = INVALID_SOCKET;
     249       79043 : }
     250             : 
     251             : static void
     252       37977 : socket_update_timeout(stream *s)
     253             : {
     254       37977 :         SOCKET fd = s->stream_data.s;
     255       37977 :         struct timeval tv;
     256             : 
     257       37977 :         if (fd == INVALID_SOCKET)
     258           0 :                 return;
     259       37977 :         tv.tv_sec = s->timeout / 1000;
     260       37977 :         tv.tv_usec = (s->timeout % 1000) * 1000;
     261             :         /* cast to char * for Windows, no harm on "normal" systems */
     262       37977 :         if (!s->readonly)
     263           6 :                 (void) setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, (socklen_t) sizeof(tv));
     264             : }
     265             : 
     266             : #ifndef MSG_DONTWAIT
     267             : #define MSG_DONTWAIT 0
     268             : #endif
     269             : 
     270             : static int
     271    11318204 : socket_isalive(const stream *s)
     272             : {
     273    11318204 :         SOCKET fd = s->stream_data.s;
     274             : #ifdef HAVE_POLL
     275    11318204 :         struct pollfd pfd;
     276    11318204 :         int ret;
     277    11318204 :         pfd = (struct pollfd){.fd = fd};
     278    11318204 :         if ((ret = poll(&pfd, 1, 0)) == 0)
     279             :                 return 1;
     280           0 :         if (ret == -1 && errno == EINTR)
     281           0 :                 return socket_isalive(s);
     282           0 :         if (ret < 0 || pfd.revents & (POLLERR | POLLHUP))
     283             :                 return 0;
     284           0 :         assert(0);              /* unexpected revents value */
     285             :         return 0;
     286             : #else
     287             :         fd_set fds;
     288             :         struct timeval t;
     289             :         char buffer[32];
     290             : 
     291             :         t.tv_sec = 0;
     292             :         t.tv_usec = 0;
     293             :         FD_ZERO(&fds);
     294             :         FD_SET(fd, &fds);
     295             :         return select(
     296             : #ifdef _MSC_VER
     297             :                 0,      /* ignored on Windows */
     298             : #else
     299             :                 fd + 1,
     300             : #endif
     301             :                 &fds, NULL, NULL, &t) <= 0 ||
     302             :                 recv(fd, buffer, sizeof(buffer), MSG_PEEK | MSG_DONTWAIT) != 0;
     303             : #endif
     304             : }
     305             : 
     306             : static int
     307     9393682 : socket_getoob(const stream *s)
     308             : {
     309     9393682 :         SOCKET fd = s->stream_data.s;
     310             : #ifdef HAVE_POLL
     311     9393682 :         struct pollfd pfd = (struct pollfd) {
     312             :                 .fd = fd,
     313             :                 .events = POLLPRI,
     314             :         };
     315     9393682 :         if (poll(&pfd, 1, 0) > 0)
     316             : #else
     317             :         fd_set fds;
     318             :         struct timeval t = (struct timeval) {
     319             :                 .tv_sec = 0,
     320             :                 .tv_usec = 0,
     321             :         };
     322             : #ifdef FD_SETSIZE
     323             :         if (fd >= FD_SETSIZE)
     324             :                 return 0;
     325             : #endif
     326             :         FD_ZERO(&fds);
     327             :         FD_SET(fd, &fds);
     328             :         if (select(
     329             : #ifdef _MSC_VER
     330             :                         0,      /* ignored on Windows */
     331             : #else
     332             :                         fd + 1,
     333             : #endif
     334             :                         NULL, NULL, &fds, &t) > 0)
     335             : #endif
     336             :         {
     337             : #ifdef HAVE_POLL
     338         856 :                 if (pfd.revents & (POLLHUP | POLLNVAL))
     339         856 :                         return -1;
     340           2 :                 if ((pfd.revents & POLLPRI) == 0)
     341             :                         return -1;
     342             : #else
     343             :                 if (!FD_ISSET(fd, &fds))
     344             :                         return 0;
     345             : #endif
     346           0 :                 char b = 0;
     347           0 :                 switch (recv(fd, &b, 1, MSG_OOB)) {
     348             :                 case 0:
     349             :                         /* unexpectedly didn't receive a byte */
     350             :                         break;
     351           0 :                 case 1:
     352           0 :                         return b;
     353           0 :                 case -1:
     354           0 :                         perror("recv OOB");
     355           0 :                         return -1;
     356             :                 }
     357             :         }
     358             :         return 0;
     359             : }
     360             : 
     361             : static int
     362           0 : socket_putoob(const stream *s, char val)
     363             : {
     364           0 :         SOCKET fd = s->stream_data.s;
     365           0 :         if (send(fd, &val, 1, MSG_OOB) == -1) {
     366           0 :                 perror("send OOB");
     367           0 :                 return -1;
     368             :         }
     369             :         return 0;
     370             : }
     371             : 
     372             : static stream *
     373       79060 : socket_open(SOCKET sock, const char *name)
     374             : {
     375       79060 :         stream *s;
     376       79060 :         int domain = 0;
     377             : 
     378       79060 :         if (sock == INVALID_SOCKET) {
     379           0 :                 mnstr_set_open_error(name, 0, "invalid socket");
     380           0 :                 return NULL;
     381             :         }
     382       79060 :         if ((s = create_stream(name)) == NULL)
     383             :                 return NULL;
     384       79060 :         s->read = socket_read;
     385       79060 :         s->write = socket_write;
     386       79060 :         s->close = socket_close;
     387       79060 :         s->stream_data.s = sock;
     388       79060 :         s->update_timeout = socket_update_timeout;
     389       79060 :         s->isalive = socket_isalive;
     390       79060 :         s->getoob = socket_getoob;
     391       79060 :         s->putoob = socket_putoob;
     392             : 
     393       79060 :         errno = 0;
     394             : #ifdef _MSC_VER
     395             :         WSASetLastError(0);
     396             : #endif
     397             : #if defined(SO_DOMAIN)
     398             :         {
     399       79060 :                 socklen_t len = (socklen_t) sizeof(domain);
     400       79060 :                 if (getsockopt(sock, SOL_SOCKET, SO_DOMAIN, (void *) &domain, &len) == SOCKET_ERROR)
     401           0 :                         domain = AF_INET;       /* give it a value if call fails */
     402             :         }
     403             : #endif
     404             : #if defined(SO_KEEPALIVE) && !defined(WIN32)
     405       79060 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     406        8580 :                 int opt = 1;
     407        8580 :                 (void) setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (void *) &opt, sizeof(opt));
     408             :         }
     409             : #endif
     410             : #if defined(IPTOS_THROUGHPUT) && !defined(WIN32)
     411       79060 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     412        8580 :                 int tos = IPTOS_THROUGHPUT;
     413             : 
     414        8580 :                 (void) setsockopt(sock, IPPROTO_IP, IP_TOS, (void *) &tos, sizeof(tos));
     415             :         }
     416             : #endif
     417             : #ifdef TCP_NODELAY
     418       79060 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     419        8580 :                 int nodelay = 1;
     420             : 
     421        8580 :                 (void) setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (void *) &nodelay, sizeof(nodelay));
     422             :         }
     423             : #endif
     424             : #ifdef HAVE_FCNTL
     425             :         {
     426       79060 :                 int fl = fcntl(sock, F_GETFL);
     427             : 
     428       79060 :                 fl &= ~O_NONBLOCK;
     429       79060 :                 if (fcntl(sock, F_SETFL, fl) < 0) {
     430           0 :                         mnstr_set_error_errno(s, MNSTR_OPEN_ERROR, "fcntl unset O_NONBLOCK failed");
     431           0 :                         return s;
     432             :                 }
     433             :         }
     434             : #endif
     435             : 
     436             :         return s;
     437             : }
     438             : 
     439             : stream *
     440       39530 : socket_rstream(SOCKET sock, const char *name)
     441             : {
     442       39530 :         stream *s = NULL;
     443             : 
     444             : #ifdef STREAM_DEBUG
     445             :         fprintf(stderr, "socket_rstream %zd %s\n", (ssize_t) sock, name);
     446             : #endif
     447       39530 :         if ((s = socket_open(sock, name)) != NULL)
     448       39530 :                 s->binary = true;
     449       39530 :         return s;
     450             : }
     451             : 
     452             : stream *
     453       39530 : socket_wstream(SOCKET sock, const char *name)
     454             : {
     455       39530 :         stream *s;
     456             : 
     457             : #ifdef STREAM_DEBUG
     458             :         fprintf(stderr, "socket_wstream %zd %s\n", (ssize_t) sock, name);
     459             : #endif
     460       39530 :         if ((s = socket_open(sock, name)) == NULL)
     461             :                 return NULL;
     462       39530 :         s->readonly = false;
     463       39530 :         s->binary = true;
     464       39530 :         return s;
     465             : }

Generated by: LCOV version 1.14