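/* xp_ssL_poc/cross_process_ssl_poc.c
 *
 * Proof of concept: hand a live TLS connection from one forked worker process
 * to another by passing the socket over SCM_RIGHTS and keeping OpenSSL's
 * objects in memory shared between the processes.
 */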

#include <errno.h>
#include <signal.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <unistd.h>

#include <openssl/bio.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
/* Size in bytes of the anonymous MAP_SHARED region set up in main(). */
#define SHM_LEN 262144
/* TCP port the server listens on and the client connects to. */
#define LISTEN_PORT 9125
/* Certificate/key files loaded into the server SSL_CTX. The same strings are
 * also exchanged as test payloads between the processes. */
#define CERT_FILE "cert.pem"
#define KEY_FILE "key.pem"
/* printf() followed by an immediate fflush(stdout). Wrapped in
 * do { } while (0) so that `P(...);` is a single statement: the original
 * two-statement expansion made only the printf conditional under an
 * unbraced `if`, and broke `if (...) P(...); else ...` entirely. */
#define P(...)                  \
    do {                        \
        printf(__VA_ARGS__);    \
        fflush(stdout);         \
    } while (0)
/* One node of the singly linked list that records each block handed out by
 * the shared-memory allocator; the nodes themselves are carved from the
 * shared mapping as well. */
typedef struct alloc_record arec;
struct alloc_record {
    void *obj;   /* address returned to the allocator's caller */
    arec *next;  /* following record; NULL terminates the list */
};

/* Bookkeeping for one forked process, as returned by fork1(). Keep the field
 * order as-is; do not reorder without checking fork1(). */
typedef struct {
    int pid;      /* child pid, as seen by the parent */
    int send_fd;  /* first end of the socketpair() created for the child */
    int recv_fd;  /* second end of that socketpair() */
} pinfo_t;
/* Listening TCP socket, opened in main() before the children are forked. */
int listen_fd;

/* ---- Pointers into the shared MAP_SHARED mapping (wired up in main()) ---- */
void *shm_start;                /* base address of the mapping */
void **fresh_shm_head;          /* allocator bump pointer: next unused byte */
char *alloc_lock;               /* spin-lock flag guarding the allocator */
arec **alloc_recordings_head;   /* head of the allocation-record list */
char *client_connect_ready;     /* set by elder/young, cleared by the client */
char *begin_worker_handoff;     /* set by the young worker before recvmsg() */
char *mp_error_state;           /* nonzero once any process hit a fatal error */
char **ssl_buf_mem_handoff;     /* only referenced by commented-out code so far */

/* Server-side OpenSSL context created in main(). */
SSL_CTX *ctx;
/* Slot holding the SSL object of the accepted client connection. */
SSL **client_con_ssl_instance;

/* Per-process records of the three forked children. */
pinfo_t *young;
pinfo_t *elder;
pinfo_t *client;
/* Entry points of the forked processes, run in the child by fork1().
 * Declared with (void): an empty () list is a non-prototype declaration. */
void client_loop(void);
void elder_worker_loop(void);
void young_worker_loop(void);

/* Shared-memory-aware allocator installed with CRYPTO_set_mem_functions();
 * the signatures follow OpenSSL's malloc/realloc/free callback types. */
static void *shmalloc(size_t len, const char *file, int line);
static void *shmrealloc(void *old, size_t len, const char *file, int line);
static void shmfree(void *mem, const char *file, int line);

/* Utility functions. */
pinfo_t fork1(void (*cb)(void));
void prn_state(const char *procname);
int arec_find(arec **head, void *target);
void arec_append(arec **head, void *target);
void arec_delete(arec **head, void *target);
int main() {
void *ssl_memory = shm_start = mmap(NULL, SHM_LEN, PROT_READ | PROT_WRITE,
MAP_SHARED | MAP_ANONYMOUS, -1, 0);
if (ssl_memory == MAP_FAILED) {
P("failed to map shared memory for openssl");
return 1;
}
// prep fresh_shm_head
fresh_shm_head = ssl_memory;
*fresh_shm_head = (ssl_memory + sizeof(void *));
// prep alloc_recordings_head
alloc_recordings_head = *fresh_shm_head;
*alloc_recordings_head = NULL;
*fresh_shm_head += sizeof(void *);
// prep client_connect_ready
client_connect_ready = *fresh_shm_head;
*client_connect_ready = 0;
*fresh_shm_head += sizeof(char);
// prep begin_worker_handoff
begin_worker_handoff = *fresh_shm_head;
*begin_worker_handoff = 0;
*fresh_shm_head += sizeof(char);
// prep mp_error_state
mp_error_state = *fresh_shm_head;
*mp_error_state = 0;
*fresh_shm_head += sizeof(char);
// prep client_ssl_instance
client_con_ssl_instance = *fresh_shm_head;
*client_con_ssl_instance = NULL;
*fresh_shm_head += sizeof(void *);
// prep pinfo_t's
young = *fresh_shm_head;
young->pid = 0;
young->recv_fd = 0;
young->send_fd = 0;
*fresh_shm_head += sizeof(pinfo_t);
elder = *fresh_shm_head;
elder->pid = 0;
elder->recv_fd = 0;
elder->send_fd = 0;
*fresh_shm_head += sizeof(pinfo_t);
client = *fresh_shm_head;
client->pid = 0;
client->recv_fd = 0;
client->send_fd = 0;
*fresh_shm_head += sizeof(pinfo_t);
// prep alloc lock
alloc_lock = *fresh_shm_head;
*alloc_lock = 0;
*fresh_shm_head += sizeof(char);
// prep ssl_buf_mem_handoff
ssl_buf_mem_handoff = *fresh_shm_head;
*ssl_buf_mem_handoff = NULL;
*fresh_shm_head += sizeof(void *);
CRYPTO_set_mem_functions(shmalloc, shmrealloc, shmfree);
SSL_load_error_strings();
SSL_library_init();
OpenSSL_add_all_algorithms();
ctx = SSL_CTX_new(TLS_server_method());
if (!ctx) {
P("failed to create SSL Context\n");
return 1;
}
if (SSL_CTX_use_certificate_file(ctx, CERT_FILE, SSL_FILETYPE_PEM) <= 0) {
P("failed to load cert %s\n", CERT_FILE);
ERR_print_errors_fp(stderr);
return 1;
}
if (SSL_CTX_use_PrivateKey_file(ctx, KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
P("failed to load key %s\n", KEY_FILE);
ERR_print_errors_fp(stderr);
return 1;
}
// surely theres something better than this...
int on = 1;
if ((listen_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) {
P("failed to open new socket\n");
return 1;
}
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
P("failed to set SO_REUSEADDR");
return 1;
}
struct sockaddr_in address = {
AF_INET, htons(LISTEN_PORT), { INADDR_ANY }
};
if (bind(listen_fd, (struct sockaddr *) &address,
sizeof(address)) < 0) {
P("failed to bind to port 9123\n");
return 1;
}
listen(listen_fd, 5);
*client = fork1(&client_loop);
*elder = fork1(&elder_worker_loop);
*young = fork1(&young_worker_loop);
// wait for all children, kill all process in own process group if an error occurs
while (wait(NULL) != -1) if (*mp_error_state) kill(0, SIGKILL);
SSL_CTX_free(ctx);
ERR_free_strings();
EVP_cleanup();
}
/* Client process.
 *
 * Waits for *client_connect_ready, then opens a TLS connection to
 * 127.0.0.1:LISTEN_PORT. It reads CERT_FILE (sent by the elder worker) and
 * answers with KEY_FILE. It then waits for *client_connect_ready again, which
 * the young worker sets once it has taken over the connection, and reads
 * CERT_FILE a second time on the same TLS session. Fatal errors set
 * *mp_error_state and exit(1).
 */
void client_loop(void) {
    /* Atomic load: a plain read may be hoisted out of the empty loop. */
    while (!__atomic_load_n(client_connect_ready, __ATOMIC_ACQUIRE))
        ;
    P("from client: client connect ready is set!\n");

    SSL_CTX *client_ctx = SSL_CTX_new(TLS_client_method());
    if (!client_ctx) {
        P("from client: failed to create new ctx\n");
        *mp_error_state = 1;
        ERR_print_errors_fp(stderr);
        exit(1);
    }
    /* Named conn so it does not shadow the global pinfo_t *client. */
    SSL *conn = SSL_new(client_ctx);
    if (!conn) {
        P("from client: failed to create SSL object\n");
        *mp_error_state = 1;
        ERR_print_errors_fp(stderr);
        exit(1);
    }

    struct sockaddr_in serv_addr = {
        .sin_family = AF_INET,
        .sin_port = htons(LISTEN_PORT),
    };
    if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) {
        P("\nInvalid address/ Address not supported \n");
        *mp_error_state = 1;
        exit(1);
    }
    int client_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (client_fd < 0) {
        P("from client: failed to open socket\n");
        *mp_error_state = 1;
        exit(1);
    }
    if (connect(client_fd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
        P("from client: connection failed \n");
        *mp_error_state = 1;
        exit(1);
    }
    if (SSL_set_fd(conn, client_fd) != 1) {
        P("from client: failed to SSL_set_fd\n");
        *mp_error_state = 1;
        ERR_print_errors_fp(stderr);
        exit(1);
    }
    if (SSL_connect(conn) != 1) {
        P("from client: failed to SSL_connect\n");
        *mp_error_state = 1;
        ERR_print_errors_fp(stderr);
        exit(1);
    }

    /* First message: CERT_FILE including its NUL. The spare byte keeps buf
     * NUL-terminated even if the peer sent no terminator. */
    char buf[sizeof(CERT_FILE) + 1] = {0};
    int r = SSL_read(conn, buf, sizeof(CERT_FILE));
    if (r > 0 && r != (int) sizeof(CERT_FILE)) {
        P("from client: warning: only %d bytes read\n", r);
    } else if (r <= 0) {
        P("from client: failed to SSL_read\n");
        ERR_print_errors_fp(stderr);
        *mp_error_state = 1;
        exit(1);
    } else if (strcmp(CERT_FILE, buf)) {
        P("from client: warning: unexpected data: %s\n", buf);
    } else {
        P("from client: received expected data from elder worker\n");
    }

    /* Clear the flag BEFORE replying. The elder only hands the connection to
     * the young worker after reading this reply, and the young worker sets
     * the flag once done. Clearing it after the write could erase that
     * signal and leave the client spinning forever. */
    __atomic_store_n(client_connect_ready, 0, __ATOMIC_RELEASE);
    if (SSL_write(conn, KEY_FILE, sizeof(KEY_FILE)) <= 0) {
        P("from client: failed to SSL_write\n");
        ERR_print_errors_fp(stderr);
        *mp_error_state = 1;
        exit(1);
    }
    while (!__atomic_load_n(client_connect_ready, __ATOMIC_ACQUIRE))
        ;

    /* Second message: written by the young worker on the same TLS session.
     * Read into buf2 (previously read into buf while checking buf2). */
    char buf2[sizeof(CERT_FILE) + 1] = {0};
    int r2 = SSL_read(conn, buf2, sizeof(CERT_FILE));
    if (r2 > 0 && r2 != (int) sizeof(CERT_FILE)) {
        P("from client: truncated message from young worker %d\n", r2);
    } else if (r2 <= 0) {
        P("from client: failed to SSL_read from young worker\n");
        ERR_print_errors_fp(stderr);
        *mp_error_state = 1;
        exit(1);
    } else if (strcmp(CERT_FILE, buf2)) {
        P("from client: warning: unexpected data from young worker: %s\n", buf2);
    } else {
        P("from client: SUCCESS READING MESSAGE FROM YOUNG WORKER\n");
    }
}
void elder_worker_loop() {
*client_connect_ready = 1;
struct sockaddr_in addr;
uint len = sizeof(addr); // this is awful tbh
int client_fd = accept(listen_fd, (struct sockaddr *) &addr, &len);
if (client_fd < 0) {
P("from elder: failed to accept incoming connection\n");
*mp_error_state = 1;
exit(1);
}
(*client_con_ssl_instance) = SSL_new(ctx);
SSL_set_fd(*client_con_ssl_instance, client_fd);
if (SSL_accept(*client_con_ssl_instance) <= 0) {
P("from elder: failed to SSL_accept\n");
ERR_print_errors_fp(stderr);
*mp_error_state = 1;
exit(1);
}
SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE));
char buf[8] = {0}; // KEY_FILE and null terminator
int r = SSL_read(*client_con_ssl_instance, buf, 8);
if (r < 0 && r != 8) {
P("from elder: warning: only %d bytes read\n", r);
} else if (r <= 0) {
P("fomr elder: failed to SSL_read\n");
ERR_print_errors_fp(stderr);
*mp_error_state = 1;
exit(1);
} else if (strcmp(KEY_FILE, buf)) {
P("from elder: warning: unexpected data: %s\n", buf);
} else {
P("from elder: received expected data from client worker\n");
}
// wait for young worker to call recvmsg
while (!*begin_worker_handoff);
/* Allocate a char array of suitable size to hold the fd (int).
* However, since this buffer is in reality a 'struct cmsghdr', use a
* union to ensure that it is suitably aligned.
*/
union {
char buf[CMSG_SPACE(sizeof(int))];
struct cmsghdr align;
} control_msg;
// on linux we must send at least one byte of data in order to send
// the ancillary data (control messge) so here is some random int
int data = 0;
struct iovec iov;
iov.iov_base = &data;
iov.iov_len = 1;
struct msghdr msg = {
NULL, 0, // address (our socket is already connected)
&iov, 1, // one byte is even less than an int.
control_msg.buf, sizeof(control_msg.buf),
0
};
struct cmsghdr *cm = CMSG_FIRSTHDR(&msg);
cm->cmsg_level = SOL_SOCKET;
cm->cmsg_type = SCM_RIGHTS;
cm->cmsg_len = CMSG_LEN(sizeof(int));
memcpy(CMSG_DATA(cm), &client_fd, sizeof(int));
int res = sendmsg(elder->send_fd, &msg, 0);
if (res == -1) {
P("from elder: failed to sendmsg. errno = %d\n", errno);
*mp_error_state = 1;
exit(1);
} else if (res != msg.msg_iovlen) {
P("from elder: sendmsg transmission truncated (len: %d) (errno: %d)\n",
res, errno);
*mp_error_state = 1;
exit(1);
}
P("from elder: sent over my fd\n");
/* TODO: need to be able to send over the init buf somehow
struct ssl_connection_st *s = (struct ssl_connection_st *) *client_con_ssl_instance;
(*ssl_buf_mem_handoff) = _shmalloc_inner(s->init_buf->len);
*/
while (1);
}
void young_worker_loop() {
int data, fd;
struct iovec iov;
struct msghdr msg;
union {
char buf[CMSG_SPACE(sizeof(int))];
struct cmsghdr align;
} control_msg;
msg.msg_name = NULL;
msg.msg_namelen = 0;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
iov.iov_base = &data;
iov.iov_len = sizeof(int);
msg.msg_control = control_msg.buf;
msg.msg_controllen = sizeof(control_msg.buf);
P("from young: triggering handoff\n");
*begin_worker_handoff = 1;
int res = recvmsg(elder->recv_fd, &msg, 0);
if (res == -1) {
P("from young: failed to recvmsg. errno = %d\n", errno);
*mp_error_state = 1;
exit(1);
}
P("from young: received message from elder\n");
struct cmsghdr *control_head = CMSG_FIRSTHDR(&msg);
if (control_head == NULL ||
control_head->cmsg_len != CMSG_LEN(sizeof(int)) ||
control_head->cmsg_level != SOL_SOCKET ||
control_head->cmsg_type != SCM_RIGHTS) {
P("from young: incoming message was invalid\n");
exit(1);
}
struct cmsghdr *cm = CMSG_FIRSTHDR(&msg);
int connection_fd = *CMSG_DATA(cm);
P("from young: received fd from elder: %d\n", connection_fd);
if(fcntl(connection_fd, F_GETFD) == -1) {
P("from young: inherited fd is not valid\n");
*mp_error_state = 1;
exit(1);
}
// must overwrite fd without SSL_set_fd. we dont want to deallocate the orig BIO
P("from young: getting BIOs... ");
int close = 0;
BIO *rbio, *wbio;
rbio = SSL_get_rbio(*client_con_ssl_instance);
wbio = SSL_get_wbio(*client_con_ssl_instance);
P(" Got BIOs (%p, %p)\n", rbio, wbio);
BIO_set_fd(rbio, connection_fd, BIO_NOCLOSE);
BIO_reset(rbio);
P("from young: set new fd in BIO\n");
if(SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)) <= 0) {
P("from young: failed to SSL_write\n");
ERR_print_errors_fp(stderr);
*mp_error_state = 1;
exit(1);
}
P("from young: SUCCESS WRITING TO INHERITED CONNECTION");
*client_connect_ready = 1;
}
/* Create a UNIX-domain socketpair and fork.
 *
 * In the child, run cb() and then exit(0); this function never returns there.
 * In the parent, return the child's pid together with both socketpair ends;
 * the parent keeps them open, so later children inherit them too. Exits with
 * status -1 if socketpair() or fork() fails.
 */
pinfo_t fork1(void (*cb)(void)) {
    pinfo_t ret = { 0 };
    /* A real int[2]: passing &ret.send_fd relied on struct layout (UB). */
    int sv[2];
    if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0) {
        P("failed to create socket pair before fork()\n");
        exit(-1);
    }
    ret.send_fd = sv[0];
    ret.recv_fd = sv[1];
    P("new sock pair: %d and %d\n", ret.send_fd, ret.recv_fd);

    ret.pid = fork();
    if (ret.pid < 0) {
        /* Previously a failed fork was treated as "parent with a child". */
        int err = errno;
        P("failed to fork(): errno = %d\n", err);
        exit(-1);
    }
    if (ret.pid > 0)
        return ret;

    cb();
    exit(0);
}
/* Debug helper: print the addresses and current values of the shared flags,
 * plus the allocator's bump pointer, under a banner naming procname. */
void prn_state(const char *procname) {
    P("======== %s ========\n", procname);
    /* %p requires a void * argument, so the char * pointers are cast. */
    P("client_connect_ready: %p\n", (void *) client_connect_ready);
    P("client_connect_ready: %d\n", *client_connect_ready);
    P("begin_worker_handoff: %p\n", (void *) begin_worker_handoff);
    P("begin_worker_handoff: %d\n", *begin_worker_handoff);
    P("mp_error_state: %p\n", (void *) mp_error_state);
    P("mp_error_state: %d\n", *mp_error_state);
    P("fresh_shm_head: %p\n", *fresh_shm_head);
    P("===========================\n");
}
/* the following memory management implementation is extremely elementary
* and will fragment massively. It has no ability to reclaim deallocated
* memory. it is only implemented here so that I can manage dynamic allocations
* into shared memory in a way that drives forward this proof of concept.
*/
int arec_find(arec **head, void *target) {
if (head == NULL) {
P("alloc record list corrupt or uninitialized\n");
exit(1);
}
if (*head != NULL) {
return (*head)->obj == target ? 1 : arec_find(&(*head)->next, target);
}
return 0;
}
void arec_append(arec **head, void *target) {
if (head == NULL) {
P("alloc record list corrupt or uninitialized\n");
exit(1);
}
if (*head != NULL) return arec_append(&(*head)->next, target);
*head = *fresh_shm_head;
(*head)->obj = target;
(*head)->next = NULL;
*fresh_shm_head += sizeof(arec);
}
// doesnt actually delete anything, just skips a record.
void arec_delete(arec **head, void *target) {
if (head == NULL) {
P("alloc record list corrupt or uninitialized\n");
exit(1);
}
if (*head != NULL) {
if ((*head)->obj == target) {
arec *tmp = *head;
*head = tmp->next;
}
}
}
static void *_shmalloc_inner(size_t len) {
void *target = *fresh_shm_head;
*fresh_shm_head += len;
arec_append(alloc_recordings_head, target);
return target;
}
static void *shmalloc(size_t len, const char *file, int line) {
if (file && ( // need to catch the following
!strcmp(file, "ssl/ssl_lib.c") || // SSL
!strcmp(file, "crypto/bio/bio_lib.c") || // BIO
!strcmp(file, "ssl/ssl_sess.c") || // SSL_SESSION
0)) {
//P("caught SSL alloc @ %s:%d\n", file, line);
while(*alloc_lock);
*alloc_lock = 1;
void *t = _shmalloc_inner(len);
*alloc_lock = 0;
return t;
}
return malloc(len);
}
static void *shmrealloc(void *old, size_t len, const char *file, int line) {
if (arec_find(alloc_recordings_head, old)) {
//P("caught SSL realloc @ %s:%d\n", file, line);
arec_delete(alloc_recordings_head, old);
while(*alloc_lock);
*alloc_lock = 1;
void *t = _shmalloc_inner(len);
*alloc_lock = 0;
return t;
}
return realloc(old, len);
}
static void shmfree(void *mem, const char *file, int line) {
if (mem > shm_start && mem < (shm_start + SHM_LEN)) {
if (!arec_find(alloc_recordings_head, mem)) {
P("WARNING: MISSING RECORD FOR %p\n", mem);
}
//P("caught SSL delete @ %s:%d\n", file, line);
while(*alloc_lock);
*alloc_lock = 1;
arec_delete(alloc_recordings_head, mem);
*alloc_lock = 0;
return;
}
return free(mem);
}