diff --git a/src/platforms/generic_unix/lib/socket_driver.c b/src/platforms/generic_unix/lib/socket_driver.c index 7e9439622d..f7f35ff2e3 100644 --- a/src/platforms/generic_unix/lib/socket_driver.c +++ b/src/platforms/generic_unix/lib/socket_driver.c @@ -47,6 +47,8 @@ #include #include +void sys_unregister_listener_nolock(GlobalContext *global, struct EventListener *listener); + // #define ENABLE_TRACE #include "trace.h" @@ -253,8 +255,8 @@ static term init_udp_socket(Context *ctx, SocketDriverData *socket_data, term pa listener->base.handler = active_recvfrom_callback; listener->buf_size = socket_data->buffer; listener->process_id = ctx->process_id; - sys_register_listener(glb, &listener->base); socket_data->active_listener = listener; + sys_register_listener(glb, &listener->base); } } return ret; @@ -338,8 +340,8 @@ static term init_client_tcp_socket(Context *ctx, SocketDriverData *socket_data, listener->base.handler = active_recv_callback; listener->buf_size = socket_data->buffer; listener->process_id = ctx->process_id; - sys_register_listener(glb, &listener->base); socket_data->active_listener = listener; + sys_register_listener(glb, &listener->base); } } return ret; @@ -1015,8 +1017,8 @@ static void do_recv(Context *ctx, term pid, term ref, term length, term timeout, listener->length = term_to_int(length); listener->buffer = socket_data->buffer; listener->ref_ticks = term_to_ref_ticks(ref); - sys_register_listener(glb, &listener->base); socket_data->passive_listener = listener; + sys_register_listener(glb, &listener->base); } void socket_driver_do_recvfrom(Context *ctx, term pid, term ref, term length, term timeout) @@ -1044,6 +1046,13 @@ static EventListener *accept_callback(GlobalContext *glb, EventListener *base_li socklen_t clientlen = sizeof(clientaddr); int fd = accept(listener->base.fd, (struct sockaddr *) &clientaddr, &clientlen); Context *ctx = globalcontext_get_process_lock(glb, listener->process_id); + if (UNLIKELY(ctx == NULL)) { + if (fd != -1) { + close(fd); + } + free(listener); + return NULL; + } SocketDriverData *socket_data = (SocketDriverData *) ctx->platform_data; EventListener *result = NULL; if (fd == -1) { @@ -1058,6 +1067,21 @@ static EventListener *accept_callback(GlobalContext *glb, EventListener *base_li TRACE("socket_driver|accept_callback: accepted connection. fd: %i\n", fd); term pid = listener->pid; + if (UNLIKELY(fcntl(fd, F_SETFL, O_NONBLOCK) == -1)) { + int err = errno; + close(fd); + BEGIN_WITH_STACK_HEAP(12, heap); + term ref = term_from_ref_ticks(listener->ref_ticks, &heap); + term reply = port_heap_create_reply(&heap, ref, port_heap_create_sys_error_tuple(&heap, FCNTL_ATOM, err)); + port_send_message_nolock(glb, pid, reply); + END_WITH_STACK_HEAP(heap, glb); + globalcontext_get_process_unlock(glb, ctx); + if (socket_data->passive_listener) { + socket_data->passive_listener = NULL; + free(listener); + } + return NULL; + } SocketDriverData *new_socket_data = socket_driver_create_data(); new_socket_data->sockfd = fd; new_socket_data->proto = socket_data->proto; @@ -1117,8 +1141,8 @@ void socket_driver_do_accept(Context *ctx, term pid, term ref, term timeout) listener->length = 0; listener->buffer = 0; listener->ref_ticks = term_to_ref_ticks(ref); - sys_register_listener(glb, &listener->base); socket_data->passive_listener = listener; + sys_register_listener(glb, &listener->base); } static NativeHandlerResult socket_consume_mailbox(Context *ctx) @@ -1194,31 +1218,26 @@ static NativeHandlerResult socket_consume_mailbox(Context *ctx) TRACE("close\n"); port_send_reply(ctx, pid, ref, OK_ATOM); SocketDriverData *socket_data = (SocketDriverData *) ctx->platform_data; - // Callbacks (active_recv_callback, passive_recv_callback) are called - // while glb->listeners lock is held. They may want to free the - // listener, causing a potential double free here. - // We acquire the lock on listeners here and set the listeners - // to NULL in the socket_data structure to prevent them from freeing - // the listeners. + // Callbacks (active_recv_callback, passive_recv_callback, accept_callback) + // are called while glb->listeners lock is held. They may free the + // listener and set the socket_data pointer to NULL. + // We must atomically detach the pointers AND unlink from the listeners + // list under the same lock hold, to prevent a race where the callback + // also unlinks the same listener node. synclist_wrlock(&glb->listeners); ActiveRecvListener *active_listener = socket_data->active_listener; PassiveRecvListener *passive_listener = socket_data->passive_listener; socket_data->active_listener = NULL; socket_data->passive_listener = NULL; - synclist_unlock(&glb->listeners); if (active_listener) { - // Then we unregister, which also acquires the lock. The callbacks - // may have returned NULL which means the listener would no longer - // be registered, but this will work. - sys_unregister_listener(glb, &active_listener->base); - // After the listener is unregistered, the callbacks can no longer - // be called, so we can eventually free the listener - free(active_listener); + sys_unregister_listener_nolock(glb, &active_listener->base); } if (passive_listener) { - sys_unregister_listener(glb, &passive_listener->base); - free(passive_listener); + sys_unregister_listener_nolock(glb, &passive_listener->base); } + synclist_unlock(&glb->listeners); + free(active_listener); + free(passive_listener); socket_driver_do_close(ctx); // We don't need to remove message. return NativeTerminate; diff --git a/src/platforms/generic_unix/lib/sys.c b/src/platforms/generic_unix/lib/sys.c index 709fd4ba8c..1056ff4d47 100644 --- a/src/platforms/generic_unix/lib/sys.c +++ b/src/platforms/generic_unix/lib/sys.c @@ -698,6 +698,12 @@ void sys_unregister_listener(GlobalContext *global, struct EventListener *listen synclist_remove(&global->listeners, &listener->listeners_list_head); } +void sys_unregister_listener_nolock(GlobalContext *global, struct EventListener *listener) +{ + listener_event_remove_from_polling_set(listener->fd, global); + list_remove(&listener->listeners_list_head); +} + void sys_register_select_event(GlobalContext *global, ErlNifEvent event, bool is_write) { struct GenericUnixPlatformData *platform = global->platform_data;