554 lines
14 KiB
C
554 lines
14 KiB
C
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <netdb.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <openssl/bio.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
|
|
|
|
/* Size in bytes of the anonymous shared mapping created in main() (256 KiB). */
#define SHM_LEN (256 * 1024)

/* TCP port the server binds; the client process connects to it on 127.0.0.1. */
#define LISTEN_PORT 9125

/* PEM certificate and private key loaded into the server SSL_CTX. The same
 * strings are also used as the payloads the processes exchange over TLS. */
#define CERT_FILE "cert.pem"
#define KEY_FILE "key.pem"
|
|
|
|
/* printf() followed by an immediate fflush(stdout), so output from the forked
 * processes appears as soon as it is produced and stdio buffers are empty
 * whenever fork() runs. Wrapped in do/while(0) so a use such as
 * `if (x) P("..."); else ...` expands to a single statement. */
#define P(...) do { printf(__VA_ARGS__); fflush(stdout); } while (0)
|
|
|
|
/* Node of a singly linked list recording allocations handed out from the
 * shared mapping. */
typedef struct alloc_record {
    void *obj;                  /* start address of the recorded allocation */
    struct alloc_record *next;  /* following record, NULL at the tail */
} arec;

/* Per-process bookkeeping for a child created by fork1(): its pid and the two
 * ends of the socket pair fork1() opens before forking. */
typedef struct {
    int pid;
    int send_fd;
    int recv_fd;
} pinfo_t;
|
|
|
|
int listen_fd;
|
|
void *shm_start;
|
|
void **fresh_shm_head;
|
|
char *alloc_lock;
|
|
arec **alloc_recordings_head;
|
|
char *client_connect_ready;
|
|
char *begin_worker_handoff;
|
|
char *mp_error_state;
|
|
char **ssl_buf_mem_handoff;
|
|
SSL_CTX *ctx;
|
|
SSL **client_con_ssl_instance;
|
|
|
|
pinfo_t *young, *elder, *client;
|
|
|
|
// process loops
|
|
void client_loop();
|
|
void elder_worker_loop();
|
|
void young_worker_loop();
|
|
|
|
// mem mgmt
|
|
static void *shmalloc(size_t len, const char *file, int line);
|
|
static void *shmrealloc(void *old, size_t len, const char *file, int line);
|
|
static void shmfree(void *mem, const char *file, int line);
|
|
|
|
// utility functions
|
|
pinfo_t fork1(void (*cb)(void));
|
|
void prn_state(const char *procname);
|
|
int arec_find(arec **head, void *target);
|
|
void arec_append(arec **head, void *target);
|
|
void arec_delete(arec **head, void *target);
|
|
|
|
int main() {
|
|
void *ssl_memory = shm_start = mmap(NULL, SHM_LEN, PROT_READ | PROT_WRITE,
|
|
MAP_SHARED | MAP_ANONYMOUS, -1, 0);
|
|
if (ssl_memory == MAP_FAILED) {
|
|
P("failed to map shared memory for openssl");
|
|
return 1;
|
|
}
|
|
|
|
// prep fresh_shm_head
|
|
fresh_shm_head = ssl_memory;
|
|
*fresh_shm_head = (ssl_memory + sizeof(void *));
|
|
|
|
// prep alloc_recordings_head
|
|
alloc_recordings_head = *fresh_shm_head;
|
|
*alloc_recordings_head = NULL;
|
|
*fresh_shm_head += sizeof(void *);
|
|
|
|
// prep client_connect_ready
|
|
client_connect_ready = *fresh_shm_head;
|
|
*client_connect_ready = 0;
|
|
*fresh_shm_head += sizeof(char);
|
|
|
|
// prep begin_worker_handoff
|
|
begin_worker_handoff = *fresh_shm_head;
|
|
*begin_worker_handoff = 0;
|
|
*fresh_shm_head += sizeof(char);
|
|
|
|
// prep mp_error_state
|
|
mp_error_state = *fresh_shm_head;
|
|
*mp_error_state = 0;
|
|
*fresh_shm_head += sizeof(char);
|
|
|
|
// prep client_ssl_instance
|
|
client_con_ssl_instance = *fresh_shm_head;
|
|
*client_con_ssl_instance = NULL;
|
|
*fresh_shm_head += sizeof(void *);
|
|
|
|
// prep pinfo_t's
|
|
young = *fresh_shm_head;
|
|
young->pid = 0;
|
|
young->recv_fd = 0;
|
|
young->send_fd = 0;
|
|
*fresh_shm_head += sizeof(pinfo_t);
|
|
|
|
elder = *fresh_shm_head;
|
|
elder->pid = 0;
|
|
elder->recv_fd = 0;
|
|
elder->send_fd = 0;
|
|
*fresh_shm_head += sizeof(pinfo_t);
|
|
|
|
client = *fresh_shm_head;
|
|
client->pid = 0;
|
|
client->recv_fd = 0;
|
|
client->send_fd = 0;
|
|
*fresh_shm_head += sizeof(pinfo_t);
|
|
|
|
// prep alloc lock
|
|
alloc_lock = *fresh_shm_head;
|
|
*alloc_lock = 0;
|
|
*fresh_shm_head += sizeof(char);
|
|
|
|
// prep ssl_buf_mem_handoff
|
|
ssl_buf_mem_handoff = *fresh_shm_head;
|
|
*ssl_buf_mem_handoff = NULL;
|
|
*fresh_shm_head += sizeof(void *);
|
|
|
|
CRYPTO_set_mem_functions(shmalloc, shmrealloc, shmfree);
|
|
SSL_load_error_strings();
|
|
SSL_library_init();
|
|
OpenSSL_add_all_algorithms();
|
|
|
|
ctx = SSL_CTX_new(TLS_server_method());
|
|
if (!ctx) {
|
|
P("failed to create SSL Context\n");
|
|
return 1;
|
|
}
|
|
|
|
if (SSL_CTX_use_certificate_file(ctx, CERT_FILE, SSL_FILETYPE_PEM) <= 0) {
|
|
P("failed to load cert %s\n", CERT_FILE);
|
|
ERR_print_errors_fp(stderr);
|
|
return 1;
|
|
}
|
|
|
|
if (SSL_CTX_use_PrivateKey_file(ctx, KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
|
|
P("failed to load key %s\n", KEY_FILE);
|
|
ERR_print_errors_fp(stderr);
|
|
return 1;
|
|
}
|
|
|
|
// surely theres something better than this...
|
|
int on = 1;
|
|
if ((listen_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) {
|
|
P("failed to open new socket\n");
|
|
return 1;
|
|
}
|
|
|
|
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
|
|
P("failed to set SO_REUSEADDR");
|
|
return 1;
|
|
}
|
|
|
|
struct sockaddr_in address = {
|
|
AF_INET, htons(LISTEN_PORT), { INADDR_ANY }
|
|
};
|
|
if (bind(listen_fd, (struct sockaddr *) &address,
|
|
sizeof(address)) < 0) {
|
|
P("failed to bind to port 9123\n");
|
|
return 1;
|
|
}
|
|
|
|
listen(listen_fd, 5);
|
|
|
|
*client = fork1(&client_loop);
|
|
*elder = fork1(&elder_worker_loop);
|
|
*young = fork1(&young_worker_loop);
|
|
// wait for all children, kill all process in own process group if an error occurs
|
|
while (wait(NULL) != -1) if (*mp_error_state) kill(0, SIGKILL);
|
|
|
|
SSL_CTX_free(ctx);
|
|
ERR_free_strings();
|
|
EVP_cleanup();
|
|
}
|
|
|
|
void client_loop() {
|
|
while (!*client_connect_ready);
|
|
P("from client: client connect ready is set!\n");
|
|
|
|
SSL_CTX *client_ctx = SSL_CTX_new(TLS_client_method());
|
|
if (!client_ctx) {
|
|
P("from client: failed to create new ctx\n");
|
|
*mp_error_state = 1;
|
|
ERR_print_errors_fp(stderr);
|
|
exit(1);
|
|
}
|
|
|
|
SSL *client = SSL_new(client_ctx);
|
|
struct sockaddr_in serv_addr;
|
|
serv_addr.sin_family = AF_INET;
|
|
serv_addr.sin_port = htons(LISTEN_PORT);
|
|
if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) {
|
|
P("\nInvalid address/ Address not supported \n");
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
int client_fd;
|
|
if ((client_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
|
|
P("from client: failed to open socket\n");
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
if (connect(client_fd, (struct sockaddr*) &serv_addr, sizeof(serv_addr)) < 0) {
|
|
P("from client: connection failed \n");
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
SSL_set_fd(client, client_fd);
|
|
if (SSL_connect(client) != 1) {
|
|
P("from client: failed to SSL_connect\n");
|
|
*mp_error_state = 1;
|
|
ERR_print_errors_fp(stderr);
|
|
exit(1);
|
|
}
|
|
|
|
char buf[9] = {0}; // CERT_FILE and null terminator
|
|
int r = SSL_read(client, buf, 9);
|
|
if (r > 0 && r != 9) {
|
|
P("from client: warning: only %d bytes read\n", r);
|
|
} else if (r <= 0) {
|
|
P("from client: failed to SSL_read\n");
|
|
ERR_print_errors_fp(stderr);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
} else if (strcmp(CERT_FILE, buf)) {
|
|
P("from client: warning: unexpected data: %s\n", buf);
|
|
} else {
|
|
P("from client: received expected data from elder worker\n");
|
|
}
|
|
|
|
SSL_write(client, KEY_FILE, sizeof(KEY_FILE));
|
|
*client_connect_ready = 0;
|
|
while (!*client_connect_ready);
|
|
|
|
// ++++++ SHOULD BE READING FROM NEW WORKER HERE
|
|
char buf2[9] = {0};
|
|
int r2 = SSL_read(client, buf, 9);
|
|
if (r2 > 0 && r2 != 9) {
|
|
P("from client: truncated message from young worker %d\n", r2);
|
|
} else if (r2 <= 0) {
|
|
P("from client: failed to SSL_read from young worker\n");
|
|
ERR_print_errors_fp(stderr);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
} else if (strcmp(CERT_FILE, buf2)) {
|
|
P("from client: warning: unexpected data from young worker: %s\n", buf2);
|
|
} else {
|
|
P("from client: SUCCESS READING MESSAGE FROM YOUNG WORKER\n");
|
|
}
|
|
}
|
|
|
|
void elder_worker_loop() {
|
|
*client_connect_ready = 1;
|
|
|
|
struct sockaddr_in addr;
|
|
uint len = sizeof(addr); // this is awful tbh
|
|
int client_fd = accept(listen_fd, (struct sockaddr *) &addr, &len);
|
|
if (client_fd < 0) {
|
|
P("from elder: failed to accept incoming connection\n");
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
(*client_con_ssl_instance) = SSL_new(ctx);
|
|
SSL_set_fd(*client_con_ssl_instance, client_fd);
|
|
if (SSL_accept(*client_con_ssl_instance) <= 0) {
|
|
P("from elder: failed to SSL_accept\n");
|
|
ERR_print_errors_fp(stderr);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE));
|
|
char buf[8] = {0}; // KEY_FILE and null terminator
|
|
int r = SSL_read(*client_con_ssl_instance, buf, 8);
|
|
if (r < 0 && r != 8) {
|
|
P("from elder: warning: only %d bytes read\n", r);
|
|
} else if (r <= 0) {
|
|
P("fomr elder: failed to SSL_read\n");
|
|
ERR_print_errors_fp(stderr);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
} else if (strcmp(KEY_FILE, buf)) {
|
|
P("from elder: warning: unexpected data: %s\n", buf);
|
|
} else {
|
|
P("from elder: received expected data from client worker\n");
|
|
}
|
|
|
|
// wait for young worker to call recvmsg
|
|
while (!*begin_worker_handoff);
|
|
|
|
/* Allocate a char array of suitable size to hold the fd (int).
|
|
* However, since this buffer is in reality a 'struct cmsghdr', use a
|
|
* union to ensure that it is suitably aligned.
|
|
*/
|
|
union {
|
|
char buf[CMSG_SPACE(sizeof(int))];
|
|
struct cmsghdr align;
|
|
} control_msg;
|
|
|
|
// on linux we must send at least one byte of data in order to send
|
|
// the ancillary data (control messge) so here is some random int
|
|
int data = 0;
|
|
struct iovec iov;
|
|
iov.iov_base = &data;
|
|
iov.iov_len = 1;
|
|
struct msghdr msg = {
|
|
NULL, 0, // address (our socket is already connected)
|
|
&iov, 1, // one byte is even less than an int.
|
|
control_msg.buf, sizeof(control_msg.buf),
|
|
0
|
|
};
|
|
|
|
struct cmsghdr *cm = CMSG_FIRSTHDR(&msg);
|
|
cm->cmsg_level = SOL_SOCKET;
|
|
cm->cmsg_type = SCM_RIGHTS;
|
|
cm->cmsg_len = CMSG_LEN(sizeof(int));
|
|
memcpy(CMSG_DATA(cm), &client_fd, sizeof(int));
|
|
|
|
int res = sendmsg(elder->send_fd, &msg, 0);
|
|
if (res == -1) {
|
|
P("from elder: failed to sendmsg. errno = %d\n", errno);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
} else if (res != msg.msg_iovlen) {
|
|
P("from elder: sendmsg transmission truncated (len: %d) (errno: %d)\n",
|
|
res, errno);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
P("from elder: sent over my fd\n");
|
|
/* TODO: need to be able to send over the init buf somehow
|
|
struct ssl_connection_st *s = (struct ssl_connection_st *) *client_con_ssl_instance;
|
|
(*ssl_buf_mem_handoff) = _shmalloc_inner(s->init_buf->len);
|
|
*/
|
|
|
|
while (1);
|
|
}
|
|
|
|
void young_worker_loop() {
|
|
int data, fd;
|
|
struct iovec iov;
|
|
struct msghdr msg;
|
|
|
|
union {
|
|
char buf[CMSG_SPACE(sizeof(int))];
|
|
struct cmsghdr align;
|
|
} control_msg;
|
|
|
|
msg.msg_name = NULL;
|
|
msg.msg_namelen = 0;
|
|
msg.msg_iov = &iov;
|
|
msg.msg_iovlen = 1;
|
|
iov.iov_base = &data;
|
|
iov.iov_len = sizeof(int);
|
|
msg.msg_control = control_msg.buf;
|
|
msg.msg_controllen = sizeof(control_msg.buf);
|
|
|
|
P("from young: triggering handoff\n");
|
|
*begin_worker_handoff = 1;
|
|
|
|
int res = recvmsg(elder->recv_fd, &msg, 0);
|
|
if (res == -1) {
|
|
P("from young: failed to recvmsg. errno = %d\n", errno);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
P("from young: received message from elder\n");
|
|
|
|
struct cmsghdr *control_head = CMSG_FIRSTHDR(&msg);
|
|
if (control_head == NULL ||
|
|
control_head->cmsg_len != CMSG_LEN(sizeof(int)) ||
|
|
control_head->cmsg_level != SOL_SOCKET ||
|
|
control_head->cmsg_type != SCM_RIGHTS) {
|
|
P("from young: incoming message was invalid\n");
|
|
exit(1);
|
|
}
|
|
|
|
struct cmsghdr *cm = CMSG_FIRSTHDR(&msg);
|
|
int connection_fd = *CMSG_DATA(cm);
|
|
P("from young: received fd from elder: %d\n", connection_fd);
|
|
|
|
if(fcntl(connection_fd, F_GETFD) == -1) {
|
|
P("from young: inherited fd is not valid\n");
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
|
|
// must overwrite fd without SSL_set_fd. we dont want to deallocate the orig BIO
|
|
P("from young: getting BIOs... ");
|
|
int close = 0;
|
|
BIO *rbio, *wbio;
|
|
rbio = SSL_get_rbio(*client_con_ssl_instance);
|
|
wbio = SSL_get_wbio(*client_con_ssl_instance);
|
|
P(" Got BIOs (%p, %p)\n", rbio, wbio);
|
|
BIO_set_fd(rbio, connection_fd, BIO_NOCLOSE);
|
|
BIO_reset(rbio);
|
|
P("from young: set new fd in BIO\n");
|
|
|
|
if(SSL_write(*client_con_ssl_instance, CERT_FILE, sizeof(CERT_FILE)) <= 0) {
|
|
P("from young: failed to SSL_write\n");
|
|
ERR_print_errors_fp(stderr);
|
|
*mp_error_state = 1;
|
|
exit(1);
|
|
}
|
|
P("from young: SUCCESS WRITING TO INHERITED CONNECTION");
|
|
*client_connect_ready = 1;
|
|
}
|
|
|
|
pinfo_t fork1(void (*cb)()) {
|
|
pinfo_t ret;
|
|
if(socketpair(AF_UNIX, SOCK_STREAM, 0, &ret.send_fd) != 0) {
|
|
P("failed to create socket pair before fork()");
|
|
exit(-1);
|
|
}
|
|
|
|
P("new sock pair: %d and %d\n", ret.send_fd, ret.recv_fd);
|
|
|
|
ret.pid = fork();
|
|
if (ret.pid) return ret;
|
|
(*cb)();
|
|
exit(0);
|
|
}
|
|
|
|
void prn_state(const char *procname) {
|
|
P("======== %s ========\n", procname);
|
|
P("client_connect_ready: %p\n", client_connect_ready);
|
|
P("client_connect_ready: %d\n", *client_connect_ready);
|
|
P("begin_worker_handoff: %p\n", begin_worker_handoff);
|
|
P("begin_worker_handoff: %d\n", *begin_worker_handoff);
|
|
P("mp_error_state: %p\n", mp_error_state);
|
|
P("mp_error_state: %d\n", *mp_error_state);
|
|
P("fresh_shm_head: %p\n", *fresh_shm_head);
|
|
P("===========================\n");
|
|
}
|
|
|
|
/* The following memory management implementation is extremely elementary.
 * It is a bump allocator over the shared mapping: memory is handed out
 * sequentially and is never reclaimed, because freeing only drops the
 * bookkeeping record. It exists only so that dynamic allocations can be placed
 * in shared memory in a way that drives this proof of concept forward.
 */
|
|
|
|
/*
 * Report whether `target` is recorded in the allocation list rooted at *head.
 * Returns 1 if a record with obj == target exists, 0 otherwise. Exits if
 * `head` itself is NULL.
 */
int arec_find(arec **head, void *target) {
    if (!head) {
        P("alloc record list corrupt or uninitialized\n");
        exit(1);
    }

    // Plain iteration over the list; same result as the tail-recursive walk.
    for (const arec *node = *head; node; node = node->next) {
        if (node->obj == target) {
            return 1;
        }
    }
    return 0;
}
|
|
|
|
/*
 * Append a record for `target` at the tail of the list rooted at *head.
 *
 * The record is bump-allocated from the shared mapping via *fresh_shm_head,
 * rounded up to arec's alignment because it holds pointers. Exits if `head`
 * is NULL, or (after setting *mp_error_state) if the mapping cannot fit
 * another record.
 */
void arec_append(arec **head, void *target) {
    if (head == NULL) {
        P("alloc record list corrupt or uninitialized\n");
        exit(1);
    }

    // Walk to the terminating NULL link iteratively, so long lists cannot
    // exhaust the stack.
    arec **link = head;
    while (*link != NULL) {
        link = &(*link)->next;
    }

    const uintptr_t align = _Alignof(arec);
    const uintptr_t slot = ((uintptr_t) *fresh_shm_head + align - 1) & ~(align - 1);
    if (slot + sizeof(arec) > (uintptr_t) shm_start + SHM_LEN) {
        P("shared memory exhausted while recording allocation\n");
        *mp_error_state = 1;
        exit(1);
    }

    arec *rec = (arec *) slot;
    rec->obj = target;
    rec->next = NULL;
    *fresh_shm_head = (void *) (slot + sizeof(arec));
    *link = rec;
}
|
|
|
|
// doesnt actually delete anything, just skips a record.
|
|
void arec_delete(arec **head, void *target) {
|
|
if (head == NULL) {
|
|
P("alloc record list corrupt or uninitialized\n");
|
|
exit(1);
|
|
}
|
|
|
|
if (*head != NULL) {
|
|
if ((*head)->obj == target) {
|
|
arec *tmp = *head;
|
|
*head = tmp->next;
|
|
}
|
|
}
|
|
}
|
|
|
|
static void *_shmalloc_inner(size_t len) {
|
|
void *target = *fresh_shm_head;
|
|
*fresh_shm_head += len;
|
|
arec_append(alloc_recordings_head, target);
|
|
return target;
|
|
}
|
|
|
|
static void *shmalloc(size_t len, const char *file, int line) {
|
|
if (file && ( // need to catch the following
|
|
!strcmp(file, "ssl/ssl_lib.c") || // SSL
|
|
!strcmp(file, "crypto/bio/bio_lib.c") || // BIO
|
|
!strcmp(file, "ssl/ssl_sess.c") || // SSL_SESSION
|
|
0)) {
|
|
//P("caught SSL alloc @ %s:%d\n", file, line);
|
|
while(*alloc_lock);
|
|
*alloc_lock = 1;
|
|
void *t = _shmalloc_inner(len);
|
|
*alloc_lock = 0;
|
|
return t;
|
|
}
|
|
|
|
return malloc(len);
|
|
}
|
|
|
|
static void *shmrealloc(void *old, size_t len, const char *file, int line) {
|
|
if (arec_find(alloc_recordings_head, old)) {
|
|
//P("caught SSL realloc @ %s:%d\n", file, line);
|
|
arec_delete(alloc_recordings_head, old);
|
|
while(*alloc_lock);
|
|
*alloc_lock = 1;
|
|
void *t = _shmalloc_inner(len);
|
|
*alloc_lock = 0;
|
|
return t;
|
|
}
|
|
|
|
return realloc(old, len);
|
|
}
|
|
|
|
static void shmfree(void *mem, const char *file, int line) {
|
|
if (mem > shm_start && mem < (shm_start + SHM_LEN)) {
|
|
if (!arec_find(alloc_recordings_head, mem)) {
|
|
P("WARNING: MISSING RECORD FOR %p\n", mem);
|
|
}
|
|
//P("caught SSL delete @ %s:%d\n", file, line);
|
|
while(*alloc_lock);
|
|
*alloc_lock = 1;
|
|
arec_delete(alloc_recordings_head, mem);
|
|
*alloc_lock = 0;
|
|
return;
|
|
}
|
|
|
|
return free(mem);
|
|
}
|