From ed69931ccf4d7dc29af17917dc19d0b1b2e9200d Mon Sep 17 00:00:00 2001 From: Ava Affine Date: Mon, 8 Sep 2025 16:58:36 -0700 Subject: [PATCH] added WIP poc Signed-off-by: Ava Affine --- Makefile | 15 ++ cross_process_ssl_poc.c | 554 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 569 insertions(+) create mode 100644 Makefile create mode 100644 cross_process_ssl_poc.c diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..406a6f6 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +CC ?= gcc +FLAGS = -lssl -lcrypto + +.PHONY: run +run: + $(CC) $(FLAGS) cross_process_ssl_poc.c -g -O0 -o test + openssl req -x509 \ + -newkey rsa:4096 \ + -keyout key.pem \ + -out cert.pem \ + -sha256 -nodes \ + -days 3650 \ + -quiet \ + -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=example.com" + ./test diff --git a/cross_process_ssl_poc.c b/cross_process_ssl_poc.c new file mode 100644 index 0000000..80182f1 --- /dev/null +++ b/cross_process_ssl_poc.c @@ -0,0 +1,554 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define SHM_LEN 262144 +#define LISTEN_PORT 9125 +#define CERT_FILE "cert.pem" +#define KEY_FILE "key.pem" + +#define P(...) printf(__VA_ARGS__); fflush(stdout); + +struct alloc_record { + void *obj; + struct alloc_record *next; +}; +typedef struct alloc_record arec; + +typedef struct { + int pid; + int send_fd; + int recv_fd; +} pinfo_t; + +int listen_fd; +void *shm_start; +void **fresh_shm_head; +char *alloc_lock; +arec **alloc_recordings_head; +char *client_connect_ready; +char *begin_worker_handoff; +char *mp_error_state; +char **ssl_buf_mem_handoff; +SSL_CTX *ctx; +SSL **client_con_ssl_instance; + +pinfo_t *young, *elder, *client; + +// process loops +void client_loop(); +void elder_worker_loop(); +void young_worker_loop(); + +// mem mgmt +static void *shmalloc(size_t len, const char *file, int line); +static void *shmrealloc(void *old, size_t len, const char *file, int line); +static void shmfree(void *mem, const char *file, int line); + +// utility functions +pinfo_t fork1(void (*cb)(void)); +void prn_state(const char *procname); +int arec_find(arec **head, void *target); +void arec_append(arec **head, void *target); +void arec_delete(arec **head, void *target); + +int main() { + void *ssl_memory = shm_start = mmap(NULL, SHM_LEN, PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_ANONYMOUS, -1, 0); + if (ssl_memory == MAP_FAILED) { + P("failed to map shared memory for openssl"); + return 1; + } + + // prep fresh_shm_head + fresh_shm_head = ssl_memory; + *fresh_shm_head = (ssl_memory + sizeof(void *)); + + // prep alloc_recordings_head + alloc_recordings_head = *fresh_shm_head; + *alloc_recordings_head = NULL; + *fresh_shm_head += sizeof(void *); + + // prep client_connect_ready + client_connect_ready = *fresh_shm_head; + *client_connect_ready = 0; + *fresh_shm_head += sizeof(char); + + // prep begin_worker_handoff + begin_worker_handoff = *fresh_shm_head; + *begin_worker_handoff = 0; + *fresh_shm_head += sizeof(char); + + // prep mp_error_state + mp_error_state = *fresh_shm_head; + *mp_error_state = 0; + *fresh_shm_head += sizeof(char); + + // prep client_ssl_instance + client_con_ssl_instance = *fresh_shm_head; + *client_con_ssl_instance = NULL; + *fresh_shm_head += sizeof(void *); + + // prep pinfo_t's + young = *fresh_shm_head; + young->pid = 0; + young->recv_fd = 0; + young->send_fd = 0; + *fresh_shm_head += sizeof(pinfo_t); + + elder = *fresh_shm_head; + elder->pid = 0; + elder->recv_fd = 0; + elder->send_fd = 0; + *fresh_shm_head += sizeof(pinfo_t); + + client = *fresh_shm_head; + client->pid = 0; + client->recv_fd = 0; + client->send_fd = 0; + *fresh_shm_head += sizeof(pinfo_t); + + // prep alloc lock + alloc_lock = *fresh_shm_head; + *alloc_lock = 0; + *fresh_shm_head += sizeof(char); + + // prep ssl_buf_mem_handoff + ssl_buf_mem_handoff = *fresh_shm_head; + *ssl_buf_mem_handoff = NULL; + *fresh_shm_head += sizeof(void *); + + CRYPTO_set_mem_functions(shmalloc, shmrealloc, shmfree); + SSL_load_error_strings(); + SSL_library_init(); + OpenSSL_add_all_algorithms(); + + ctx = SSL_CTX_new(TLS_server_method()); + if (!ctx) { + P("failed to create SSL Context\n"); + return 1; + } + + if (SSL_CTX_use_certificate_file(ctx, CERT_FILE, SSL_FILETYPE_PEM) <= 0) { + P("failed to load cert %s\n", CERT_FILE); + ERR_print_errors_fp(stderr); + return 1; + } + + if (SSL_CTX_use_PrivateKey_file(ctx, KEY_FILE, SSL_FILETYPE_PEM) <= 0) { + P("failed to load key %s\n", KEY_FILE); + ERR_print_errors_fp(stderr); + return 1; + } + + // surely theres something better than this... + int on = 1; + if ((listen_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) { + P("failed to open new socket\n"); + return 1; + } + + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) { + P("failed to set SO_REUSEADDR"); + return 1; + } + + struct sockaddr_in address = { + AF_INET, htons(LISTEN_PORT), { INADDR_ANY } + }; + if (bind(listen_fd, (struct sockaddr *) &address, + sizeof(address)) < 0) { + P("failed to bind to port 9123\n"); + return 1; + } + + listen(listen_fd, 5); + + *client = fork1(&client_loop); + *elder = fork1(&elder_worker_loop); + *young = fork1(&young_worker_loop); + // wait for all children, kill all process in own process group if an error occurs + while (wait(NULL) != -1) if (*mp_error_state) kill(0, SIGKILL); + + SSL_CTX_free(ctx); + ERR_free_strings(); + EVP_cleanup(); +} + +void client_loop() { + while (!*client_connect_ready); + P("from client: client connect ready is set!\n"); + + SSL_CTX *client_ctx = SSL_CTX_new(TLS_client_method()); + if (!client_ctx) { + P("from client: failed to create new ctx\n"); + *mp_error_state = 1; + ERR_print_errors_fp(stderr); + exit(1); + } + + SSL *client = SSL_new(client_ctx); + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(LISTEN_PORT); + if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) { + P("\nInvalid address/ Address not supported \n"); + *mp_error_state = 1; + exit(1); + } + + int client_fd; + if ((client_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + P("from client: failed to open socket\n"); + *mp_error_state = 1; + exit(1); + } + + if (connect(client_fd, (struct sockaddr*) &serv_addr, sizeof(serv_addr)) < 0) { + P("from client: connection failed \n"); + *mp_error_state = 1; + exit(1); + } + + SSL_set_fd(client, client_fd); + if (SSL_connect(client) != 1) { + P("from client: failed to SSL_connect\n"); + *mp_error_state = 1; + ERR_print_errors_fp(stderr); + exit(1); + } + + char buf[9] = {0}; // CERT_FILE and null terminator + int r = SSL_read(client, buf, 9); + if (r > 0 && r != 9) { + P("from client: warning: only %d bytes read\n", r); + } else if (r <= 0) { + P("from client: failed to SSL_read\n"); + ERR_print_errors_fp(stderr); + *mp_error_state = 1; + exit(1); + } else if (strcmp(CERT_FILE, buf)) { + P("from client: warning: unexpected data: %s\n", buf); + } else { + P("from client: received expected data from elder worker\n"); + } + + SSL_write(client, KEY_FILE, sizeof(KEY_FILE)); + *client_connect_ready = 0; + while (!*client_connect_ready); + + // ++++++ SHOULD BE READING FROM NEW WORKER HERE + char buf2[9] = {0}; + int r2 = SSL_read(client, buf, 9); + if (r2 > 0 && r2 != 9) { + P("from client: truncated message from young worker %d\n", r2); + } else if (r2 <= 0) { + P("from client: failed to SSL_read from young worker\n"); + ERR_print_errors_fp(stderr); + *mp_error_state = 1; + exit(1); + } else if (strcmp(CERT_FILE, buf2)) { + P("from client: warning: unexpected data from young worker: %s\n", buf2); + } else { + P("from client: SUCCESS READING MESSAGE FROM YOUNG WORKER\n"); + } +} + +void elder_worker_loop() { + *client_connect_ready = 1; + + struct sockaddr_in addr; + uint len = sizeof(addr); // this is awful tbh + int client_fd = accept(listen_fd, (struct sockaddr *) &addr, &len); + if (client_fd < 0) { + P("from elder: failed to accept incoming connection\n"); + *mp_error_state = 1; + exit(1); + } + + (*client_con_ssl_instance) = SSL_new(ctx); + SSL_set_fd(*client_con_ssl_instance, client_fd); + if (SSL_accept(*client_con_ssl_instance) <= 0) { + P("from elder: failed to SSL_accept\n"); + ERR_print_errors_fp(stderr); + *mp_error_state = 1; + exit(1); + } + + SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)); + char buf[8] = {0}; // KEY_FILE and null terminator + int r = SSL_read(*client_con_ssl_instance, buf, 8); + if (r < 0 && r != 8) { + P("from elder: warning: only %d bytes read\n", r); + } else if (r <= 0) { + P("fomr elder: failed to SSL_read\n"); + ERR_print_errors_fp(stderr); + *mp_error_state = 1; + exit(1); + } else if (strcmp(KEY_FILE, buf)) { + P("from elder: warning: unexpected data: %s\n", buf); + } else { + P("from elder: received expected data from client worker\n"); + } + + // wait for young worker to call recvmsg + while (!*begin_worker_handoff); + + /* Allocate a char array of suitable size to hold the fd (int). + * However, since this buffer is in reality a 'struct cmsghdr', use a + * union to ensure that it is suitably aligned. + */ + union { + char buf[CMSG_SPACE(sizeof(int))]; + struct cmsghdr align; + } control_msg; + + // on linux we must send at least one byte of data in order to send + // the ancillary data (control messge) so here is some random int + int data = 0; + struct iovec iov; + iov.iov_base = &data; + iov.iov_len = 1; + struct msghdr msg = { + NULL, 0, // address (our socket is already connected) + &iov, 1, // one byte is even less than an int. + control_msg.buf, sizeof(control_msg.buf), + 0 + }; + + struct cmsghdr *cm = CMSG_FIRSTHDR(&msg); + cm->cmsg_level = SOL_SOCKET; + cm->cmsg_type = SCM_RIGHTS; + cm->cmsg_len = CMSG_LEN(sizeof(int)); + memcpy(CMSG_DATA(cm), &client_fd, sizeof(int)); + + int res = sendmsg(elder->send_fd, &msg, 0); + if (res == -1) { + P("from elder: failed to sendmsg. errno = %d\n", errno); + *mp_error_state = 1; + exit(1); + } else if (res != msg.msg_iovlen) { + P("from elder: sendmsg transmission truncated (len: %d) (errno: %d)\n", + res, errno); + *mp_error_state = 1; + exit(1); + } + + P("from elder: sent over my fd\n"); + /* TODO: need to be able to send over the init buf somehow + struct ssl_connection_st *s = (struct ssl_connection_st *) *client_con_ssl_instance; + (*ssl_buf_mem_handoff) = _shmalloc_inner(s->init_buf->len); + */ + + while (1); +} + +void young_worker_loop() { + int data, fd; + struct iovec iov; + struct msghdr msg; + + union { + char buf[CMSG_SPACE(sizeof(int))]; + struct cmsghdr align; + } control_msg; + + msg.msg_name = NULL; + msg.msg_namelen = 0; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + iov.iov_base = &data; + iov.iov_len = sizeof(int); + msg.msg_control = control_msg.buf; + msg.msg_controllen = sizeof(control_msg.buf); + + P("from young: triggering handoff\n"); + *begin_worker_handoff = 1; + + int res = recvmsg(elder->recv_fd, &msg, 0); + if (res == -1) { + P("from young: failed to recvmsg. errno = %d\n", errno); + *mp_error_state = 1; + exit(1); + } + P("from young: received message from elder\n"); + + struct cmsghdr *control_head = CMSG_FIRSTHDR(&msg); + if (control_head == NULL || + control_head->cmsg_len != CMSG_LEN(sizeof(int)) || + control_head->cmsg_level != SOL_SOCKET || + control_head->cmsg_type != SCM_RIGHTS) { + P("from young: incoming message was invalid\n"); + exit(1); + } + + struct cmsghdr *cm = CMSG_FIRSTHDR(&msg); + int connection_fd = *CMSG_DATA(cm); + P("from young: received fd from elder: %d\n", connection_fd); + + if(fcntl(connection_fd, F_GETFD) == -1) { + P("from young: inherited fd is not valid\n"); + *mp_error_state = 1; + exit(1); + } + + // must overwrite fd without SSL_set_fd. we dont want to deallocate the orig BIO + P("from young: getting BIOs... "); + int close = 0; + BIO *rbio, *wbio; + rbio = SSL_get_rbio(*client_con_ssl_instance); + wbio = SSL_get_wbio(*client_con_ssl_instance); + P(" Got BIOs (%p, %p)\n", rbio, wbio); + BIO_set_fd(rbio, connection_fd, BIO_NOCLOSE); + BIO_reset(rbio); + P("from young: set new fd in BIO\n"); + + if(SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)) <= 0) { + P("from young: failed to SSL_write\n"); + ERR_print_errors_fp(stderr); + *mp_error_state = 1; + exit(1); + } + P("from young: SUCCESS WRITING TO INHERITED CONNECTION"); + *client_connect_ready = 1; +} + +pinfo_t fork1(void (*cb)()) { + pinfo_t ret; + if(socketpair(AF_UNIX, SOCK_STREAM, 0, &ret.send_fd) != 0) { + P("failed to create socket pair before fork()"); + exit(-1); + } + + P("new sock pair: %d and %d\n", ret.send_fd, ret.recv_fd); + + ret.pid = fork(); + if (ret.pid) return ret; + (*cb)(); + exit(0); +} + +void prn_state(const char *procname) { + P("======== %s ========\n", procname); + P("client_connect_ready: %p\n", client_connect_ready); + P("client_connect_ready: %d\n", *client_connect_ready); + P("begin_worker_handoff: %p\n", begin_worker_handoff); + P("begin_worker_handoff: %d\n", *begin_worker_handoff); + P("mp_error_state: %p\n", mp_error_state); + P("mp_error_state: %d\n", *mp_error_state); + P("fresh_shm_head: %p\n", *fresh_shm_head); + P("===========================\n"); +} + +/* the following memory management implementation is extremely elementary + * and will fragment massively. It has no ability to reclaim deallocated + * memory. it is only implemented here so that I can manage dynamic allocations + * into shared memory in a way that drives forward this proof of concept. + */ + +int arec_find(arec **head, void *target) { + if (head == NULL) { + P("alloc record list corrupt or uninitialized\n"); + exit(1); + } + + if (*head != NULL) { + return (*head)->obj == target ? 1 : arec_find(&(*head)->next, target); + } + + return 0; +} + +void arec_append(arec **head, void *target) { + if (head == NULL) { + P("alloc record list corrupt or uninitialized\n"); + exit(1); + } + + if (*head != NULL) return arec_append(&(*head)->next, target); + + *head = *fresh_shm_head; + (*head)->obj = target; + (*head)->next = NULL; + *fresh_shm_head += sizeof(arec); +} + +// doesnt actually delete anything, just skips a record. +void arec_delete(arec **head, void *target) { + if (head == NULL) { + P("alloc record list corrupt or uninitialized\n"); + exit(1); + } + + if (*head != NULL) { + if ((*head)->obj == target) { + arec *tmp = *head; + *head = tmp->next; + } + } +} + +static void *_shmalloc_inner(size_t len) { + void *target = *fresh_shm_head; + *fresh_shm_head += len; + arec_append(alloc_recordings_head, target); + return target; +} + +static void *shmalloc(size_t len, const char *file, int line) { + if (file && ( // need to catch the following + !strcmp(file, "ssl/ssl_lib.c") || // SSL + !strcmp(file, "crypto/bio/bio_lib.c") || // BIO + !strcmp(file, "ssl/ssl_sess.c") || // SSL_SESSION + 0)) { + //P("caught SSL alloc @ %s:%d\n", file, line); + while(*alloc_lock); + *alloc_lock = 1; + void *t = _shmalloc_inner(len); + *alloc_lock = 0; + return t; + } + + return malloc(len); +} + +static void *shmrealloc(void *old, size_t len, const char *file, int line) { + if (arec_find(alloc_recordings_head, old)) { + //P("caught SSL realloc @ %s:%d\n", file, line); + arec_delete(alloc_recordings_head, old); + while(*alloc_lock); + *alloc_lock = 1; + void *t = _shmalloc_inner(len); + *alloc_lock = 0; + return t; + } + + return realloc(old, len); +} + +static void shmfree(void *mem, const char *file, int line) { + if (mem > shm_start && mem < (shm_start + SHM_LEN)) { + if (!arec_find(alloc_recordings_head, mem)) { + P("WARNING: MISSING RECORD FOR %p\n", mem); + } + //P("caught SSL delete @ %s:%d\n", file, line); + while(*alloc_lock); + *alloc_lock = 1; + arec_delete(alloc_recordings_head, mem); + *alloc_lock = 0; + return; + } + + return free(mem); +}