#include #include #include #include #include #include #include #include #include #include #include #include #include #define SHM_LEN 262144 #define LISTEN_PORT 9125 #define CERT_FILE "cert.pem" #define KEY_FILE "key.pem" #define P(...) printf(__VA_ARGS__); fflush(stdout); struct alloc_record { void *obj; struct alloc_record *next; }; typedef struct alloc_record arec; typedef struct { int pid; int send_fd; int recv_fd; } pinfo_t; int listen_fd; void *shm_start; void **fresh_shm_head; char *alloc_lock; arec **alloc_recordings_head; char *client_connect_ready; char *begin_worker_handoff; char *mp_error_state; char **ssl_buf_mem_handoff; SSL_CTX *ctx; SSL **client_con_ssl_instance; pinfo_t *young, *elder, *client; // process loops void client_loop(); void elder_worker_loop(); void young_worker_loop(); // mem mgmt static void *shmalloc(size_t len, const char *file, int line); static void *shmrealloc(void *old, size_t len, const char *file, int line); static void shmfree(void *mem, const char *file, int line); // utility functions pinfo_t fork1(void (*cb)(void)); void prn_state(const char *procname); int arec_find(arec **head, void *target); void arec_append(arec **head, void *target); void arec_delete(arec **head, void *target); int main() { void *ssl_memory = shm_start = mmap(NULL, SHM_LEN, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0); if (ssl_memory == MAP_FAILED) { P("failed to map shared memory for openssl"); return 1; } // prep fresh_shm_head fresh_shm_head = ssl_memory; *fresh_shm_head = (ssl_memory + sizeof(void *)); // prep alloc_recordings_head alloc_recordings_head = *fresh_shm_head; *alloc_recordings_head = NULL; *fresh_shm_head += sizeof(void *); // prep client_connect_ready client_connect_ready = *fresh_shm_head; *client_connect_ready = 0; *fresh_shm_head += sizeof(char); // prep begin_worker_handoff begin_worker_handoff = *fresh_shm_head; *begin_worker_handoff = 0; *fresh_shm_head += sizeof(char); // prep mp_error_state mp_error_state = *fresh_shm_head; *mp_error_state = 0; *fresh_shm_head += sizeof(char); // prep client_ssl_instance client_con_ssl_instance = *fresh_shm_head; *client_con_ssl_instance = NULL; *fresh_shm_head += sizeof(void *); // prep pinfo_t's young = *fresh_shm_head; young->pid = 0; young->recv_fd = 0; young->send_fd = 0; *fresh_shm_head += sizeof(pinfo_t); elder = *fresh_shm_head; elder->pid = 0; elder->recv_fd = 0; elder->send_fd = 0; *fresh_shm_head += sizeof(pinfo_t); client = *fresh_shm_head; client->pid = 0; client->recv_fd = 0; client->send_fd = 0; *fresh_shm_head += sizeof(pinfo_t); // prep alloc lock alloc_lock = *fresh_shm_head; *alloc_lock = 0; *fresh_shm_head += sizeof(char); // prep ssl_buf_mem_handoff ssl_buf_mem_handoff = *fresh_shm_head; *ssl_buf_mem_handoff = NULL; *fresh_shm_head += sizeof(void *); CRYPTO_set_mem_functions(shmalloc, shmrealloc, shmfree); SSL_load_error_strings(); SSL_library_init(); OpenSSL_add_all_algorithms(); ctx = SSL_CTX_new(TLS_server_method()); if (!ctx) { P("failed to create SSL Context\n"); return 1; } if (SSL_CTX_use_certificate_file(ctx, CERT_FILE, SSL_FILETYPE_PEM) <= 0) { P("failed to load cert %s\n", CERT_FILE); ERR_print_errors_fp(stderr); return 1; } if (SSL_CTX_use_PrivateKey_file(ctx, KEY_FILE, SSL_FILETYPE_PEM) <= 0) { P("failed to load key %s\n", KEY_FILE); ERR_print_errors_fp(stderr); return 1; } // surely theres something better than this... int on = 1; if ((listen_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) { P("failed to open new socket\n"); return 1; } if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) { P("failed to set SO_REUSEADDR"); return 1; } struct sockaddr_in address = { AF_INET, htons(LISTEN_PORT), { INADDR_ANY } }; if (bind(listen_fd, (struct sockaddr *) &address, sizeof(address)) < 0) { P("failed to bind to port 9123\n"); return 1; } listen(listen_fd, 5); *client = fork1(&client_loop); *elder = fork1(&elder_worker_loop); *young = fork1(&young_worker_loop); // wait for all children, kill all process in own process group if an error occurs while (wait(NULL) != -1) if (*mp_error_state) kill(0, SIGKILL); SSL_CTX_free(ctx); ERR_free_strings(); EVP_cleanup(); } void client_loop() { while (!*client_connect_ready); P("from client: client connect ready is set!\n"); SSL_CTX *client_ctx = SSL_CTX_new(TLS_client_method()); if (!client_ctx) { P("from client: failed to create new ctx\n"); *mp_error_state = 1; ERR_print_errors_fp(stderr); exit(1); } SSL *client = SSL_new(client_ctx); struct sockaddr_in serv_addr; serv_addr.sin_family = AF_INET; serv_addr.sin_port = htons(LISTEN_PORT); if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) { P("\nInvalid address/ Address not supported \n"); *mp_error_state = 1; exit(1); } int client_fd; if ((client_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { P("from client: failed to open socket\n"); *mp_error_state = 1; exit(1); } if (connect(client_fd, (struct sockaddr*) &serv_addr, sizeof(serv_addr)) < 0) { P("from client: connection failed \n"); *mp_error_state = 1; exit(1); } SSL_set_fd(client, client_fd); if (SSL_connect(client) != 1) { P("from client: failed to SSL_connect\n"); *mp_error_state = 1; ERR_print_errors_fp(stderr); exit(1); } char buf[9] = {0}; // CERT_FILE and null terminator int r = SSL_read(client, buf, 9); if (r > 0 && r != 9) { P("from client: warning: only %d bytes read\n", r); } else if (r <= 0) { P("from client: failed to SSL_read\n"); ERR_print_errors_fp(stderr); *mp_error_state = 1; exit(1); } else if (strcmp(CERT_FILE, buf)) { P("from client: warning: unexpected data: %s\n", buf); } else { P("from client: received expected data from elder worker\n"); } SSL_write(client, KEY_FILE, sizeof(KEY_FILE)); *client_connect_ready = 0; while (!*client_connect_ready); // ++++++ SHOULD BE READING FROM NEW WORKER HERE char buf2[9] = {0}; int r2 = SSL_read(client, buf, 9); if (r2 > 0 && r2 != 9) { P("from client: truncated message from young worker %d\n", r2); } else if (r2 <= 0) { P("from client: failed to SSL_read from young worker\n"); ERR_print_errors_fp(stderr); *mp_error_state = 1; exit(1); } else if (strcmp(CERT_FILE, buf2)) { P("from client: warning: unexpected data from young worker: %s\n", buf2); } else { P("from client: SUCCESS READING MESSAGE FROM YOUNG WORKER\n"); } } void elder_worker_loop() { *client_connect_ready = 1; struct sockaddr_in addr; uint len = sizeof(addr); // this is awful tbh int client_fd = accept(listen_fd, (struct sockaddr *) &addr, &len); if (client_fd < 0) { P("from elder: failed to accept incoming connection\n"); *mp_error_state = 1; exit(1); } (*client_con_ssl_instance) = SSL_new(ctx); SSL_set_fd(*client_con_ssl_instance, client_fd); if (SSL_accept(*client_con_ssl_instance) <= 0) { P("from elder: failed to SSL_accept\n"); ERR_print_errors_fp(stderr); *mp_error_state = 1; exit(1); } SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)); char buf[8] = {0}; // KEY_FILE and null terminator int r = SSL_read(*client_con_ssl_instance, buf, 8); if (r < 0 && r != 8) { P("from elder: warning: only %d bytes read\n", r); } else if (r <= 0) { P("fomr elder: failed to SSL_read\n"); ERR_print_errors_fp(stderr); *mp_error_state = 1; exit(1); } else if (strcmp(KEY_FILE, buf)) { P("from elder: warning: unexpected data: %s\n", buf); } else { P("from elder: received expected data from client worker\n"); } // wait for young worker to call recvmsg while (!*begin_worker_handoff); /* Allocate a char array of suitable size to hold the fd (int). * However, since this buffer is in reality a 'struct cmsghdr', use a * union to ensure that it is suitably aligned. */ union { char buf[CMSG_SPACE(sizeof(int))]; struct cmsghdr align; } control_msg; // on linux we must send at least one byte of data in order to send // the ancillary data (control messge) so here is some random int int data = 0; struct iovec iov; iov.iov_base = &data; iov.iov_len = 1; struct msghdr msg = { NULL, 0, // address (our socket is already connected) &iov, 1, // one byte is even less than an int. control_msg.buf, sizeof(control_msg.buf), 0 }; struct cmsghdr *cm = CMSG_FIRSTHDR(&msg); cm->cmsg_level = SOL_SOCKET; cm->cmsg_type = SCM_RIGHTS; cm->cmsg_len = CMSG_LEN(sizeof(int)); memcpy(CMSG_DATA(cm), &client_fd, sizeof(int)); int res = sendmsg(elder->send_fd, &msg, 0); if (res == -1) { P("from elder: failed to sendmsg. errno = %d\n", errno); *mp_error_state = 1; exit(1); } else if (res != msg.msg_iovlen) { P("from elder: sendmsg transmission truncated (len: %d) (errno: %d)\n", res, errno); *mp_error_state = 1; exit(1); } P("from elder: sent over my fd\n"); /* TODO: need to be able to send over the init buf somehow struct ssl_connection_st *s = (struct ssl_connection_st *) *client_con_ssl_instance; (*ssl_buf_mem_handoff) = _shmalloc_inner(s->init_buf->len); */ while (1); } void young_worker_loop() { int data, fd; struct iovec iov; struct msghdr msg; union { char buf[CMSG_SPACE(sizeof(int))]; struct cmsghdr align; } control_msg; msg.msg_name = NULL; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; iov.iov_base = &data; iov.iov_len = sizeof(int); msg.msg_control = control_msg.buf; msg.msg_controllen = sizeof(control_msg.buf); P("from young: triggering handoff\n"); *begin_worker_handoff = 1; int res = recvmsg(elder->recv_fd, &msg, 0); if (res == -1) { P("from young: failed to recvmsg. errno = %d\n", errno); *mp_error_state = 1; exit(1); } P("from young: received message from elder\n"); struct cmsghdr *control_head = CMSG_FIRSTHDR(&msg); if (control_head == NULL || control_head->cmsg_len != CMSG_LEN(sizeof(int)) || control_head->cmsg_level != SOL_SOCKET || control_head->cmsg_type != SCM_RIGHTS) { P("from young: incoming message was invalid\n"); exit(1); } struct cmsghdr *cm = CMSG_FIRSTHDR(&msg); int connection_fd = *CMSG_DATA(cm); P("from young: received fd from elder: %d\n", connection_fd); if(fcntl(connection_fd, F_GETFD) == -1) { P("from young: inherited fd is not valid\n"); *mp_error_state = 1; exit(1); } // must overwrite fd without SSL_set_fd. we dont want to deallocate the orig BIO P("from young: getting BIOs... "); int close = 0; BIO *rbio, *wbio; rbio = SSL_get_rbio(*client_con_ssl_instance); wbio = SSL_get_wbio(*client_con_ssl_instance); P(" Got BIOs (%p, %p)\n", rbio, wbio); BIO_set_fd(rbio, connection_fd, BIO_NOCLOSE); BIO_reset(rbio); P("from young: set new fd in BIO\n"); if(SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)) <= 0) { P("from young: failed to SSL_write\n"); ERR_print_errors_fp(stderr); *mp_error_state = 1; exit(1); } P("from young: SUCCESS WRITING TO INHERITED CONNECTION"); *client_connect_ready = 1; } pinfo_t fork1(void (*cb)()) { pinfo_t ret; if(socketpair(AF_UNIX, SOCK_STREAM, 0, &ret.send_fd) != 0) { P("failed to create socket pair before fork()"); exit(-1); } P("new sock pair: %d and %d\n", ret.send_fd, ret.recv_fd); ret.pid = fork(); if (ret.pid) return ret; (*cb)(); exit(0); } void prn_state(const char *procname) { P("======== %s ========\n", procname); P("client_connect_ready: %p\n", client_connect_ready); P("client_connect_ready: %d\n", *client_connect_ready); P("begin_worker_handoff: %p\n", begin_worker_handoff); P("begin_worker_handoff: %d\n", *begin_worker_handoff); P("mp_error_state: %p\n", mp_error_state); P("mp_error_state: %d\n", *mp_error_state); P("fresh_shm_head: %p\n", *fresh_shm_head); P("===========================\n"); } /* the following memory management implementation is extremely elementary * and will fragment massively. It has no ability to reclaim deallocated * memory. it is only implemented here so that I can manage dynamic allocations * into shared memory in a way that drives forward this proof of concept. */ int arec_find(arec **head, void *target) { if (head == NULL) { P("alloc record list corrupt or uninitialized\n"); exit(1); } if (*head != NULL) { return (*head)->obj == target ? 1 : arec_find(&(*head)->next, target); } return 0; } void arec_append(arec **head, void *target) { if (head == NULL) { P("alloc record list corrupt or uninitialized\n"); exit(1); } if (*head != NULL) return arec_append(&(*head)->next, target); *head = *fresh_shm_head; (*head)->obj = target; (*head)->next = NULL; *fresh_shm_head += sizeof(arec); } // doesnt actually delete anything, just skips a record. void arec_delete(arec **head, void *target) { if (head == NULL) { P("alloc record list corrupt or uninitialized\n"); exit(1); } if (*head != NULL) { if ((*head)->obj == target) { arec *tmp = *head; *head = tmp->next; } } } static void *_shmalloc_inner(size_t len) { void *target = *fresh_shm_head; *fresh_shm_head += len; arec_append(alloc_recordings_head, target); return target; } static void *shmalloc(size_t len, const char *file, int line) { if (file && ( // need to catch the following !strcmp(file, "ssl/ssl_lib.c") || // SSL !strcmp(file, "crypto/bio/bio_lib.c") || // BIO !strcmp(file, "ssl/ssl_sess.c") || // SSL_SESSION 0)) { //P("caught SSL alloc @ %s:%d\n", file, line); while(*alloc_lock); *alloc_lock = 1; void *t = _shmalloc_inner(len); *alloc_lock = 0; return t; } return malloc(len); } static void *shmrealloc(void *old, size_t len, const char *file, int line) { if (arec_find(alloc_recordings_head, old)) { //P("caught SSL realloc @ %s:%d\n", file, line); arec_delete(alloc_recordings_head, old); while(*alloc_lock); *alloc_lock = 1; void *t = _shmalloc_inner(len); *alloc_lock = 0; return t; } return realloc(old, len); } static void shmfree(void *mem, const char *file, int line) { if (mem > shm_start && mem < (shm_start + SHM_LEN)) { if (!arec_find(alloc_recordings_head, mem)) { P("WARNING: MISSING RECORD FOR %p\n", mem); } //P("caught SSL delete @ %s:%d\n", file, line); while(*alloc_lock); *alloc_lock = 1; arec_delete(alloc_recordings_head, mem); *alloc_lock = 0; return; } return free(mem); }