#ifdef _WIN32
#include<ws2tcpip.h>
#include<windows.h>
#include"wepoll.c"
#else
#include<sys/socket.h>
#include<sys/epoll.h>
#include<netinet/in.h>
#include<netdb.h>
#include<sys/types.h>
#include<fcntl.h>
#include<unistd.h>
#endif

#include<stdbool.h>
#include<string.h>
#include<stdlib.h>
#include<stdio.h>
#include<assert.h>
#include<errno.h>

#include"picohttpparser.h"
#include"base64.h"
#include"teeny-sha1.c"

#ifdef _WIN32
// Windows build: epoll comes from the bundled wepoll.c, and the value returned
// by epoll_create1() is stored in a HANDLE.
typedef HANDLE EPoll;
// NOTE(review): Winsock's own SOCKET type is wider than int on 64-bit targets;
// this presumably relies on socket handles fitting in an int — confirm.
typedef int Socket;
// Fallback for platforms without memmem(): finds the first occurrence of the
// `needle_len`-byte sequence `needle` inside the `haystack_len`-byte region
// `haystack`.
//
// Returns a pointer to the start of the match, or NULL when there is no match.
// NULL is also returned when either pointer is NULL or either length is zero.
void *memmem(const void *haystack, size_t haystack_len, const void * const needle, const size_t needle_len) {
	if(!haystack || !needle || haystack_len == 0 || needle_len == 0) {
		return NULL;
	}
	
	if(needle_len > haystack_len) {
		return NULL;
	}
	
	// `last` is the final position at which a full needle still fits.
	const char *cursor = haystack;
	const char *last = cursor + (haystack_len - needle_len);
	
	while(cursor <= last) {
		if(memcmp(cursor, needle, needle_len) == 0) {
			return (void*) cursor;
		}
		cursor++;
	}
	
	return NULL;
}
#else
// POSIX build: the epoll instance and every socket are plain file descriptors.
typedef int EPoll;
typedef int Socket;
// Lets the rest of the file use the Winsock spelling on every platform.
#define closesocket close
#endif

// Role of a connection, decided by handle() from the request path.
typedef enum ClientType {
	UNKNOWN,      // zero value: what a freshly calloc'd Client starts as
	CLI_STREAMER, // request path equals ValidStreamPath; pushes the stream
	CLI_VIEWER    // any other path; receives the stream over WebSocket
} ClientType;

// Which parser handle() applies to the bytes buffered for a client.
typedef enum ClientState {
	REQUEST,   // zero value: still waiting for a complete HTTP request head
	ACTIVE,    // streamer: buffer holds a chunked-encoded request body
	WEBSOCKET, // viewer: buffer holds WebSocket frames
} ClientState;

// Per-connection state. Instances are calloc'd in main(), so every field
// starts out zero/NULL.
typedef struct {
	Socket fd;
	
	ClientType type;
	ClientState state;
	
	// Receive buffer: `len` bytes are valid out of `cap` allocated.
	// `prevlen` is handed to phr_parse_request() as its last_len argument.
	size_t len, prevlen, cap;
	uint8_t *buf;
	
	// Only for streamers: decoder state passed to phr_decode_chunked()
	struct phr_chunked_decoder chudec;
	
	// Only for websockets: the message currently being received
	struct {
		int opcode;        // first non-zero opcode seen for this message, 0 when idle
		uint8_t *incoming; // unmasked payload bytes
		size_t incomingSz;
	} ws;
} Client;

// Phase of the relayed stream, as tracked by stream_step().
typedef enum {
	LOADING_HEADER, // zero value: accumulating bytes until the first Cluster ID is seen
	STREAMING,      // header captured; incoming bytes are forwarded to viewers
} StreamState;
// The single relayed stream (there is one global instance, also named Stream).
static struct Stream {
	StreamState state;
	
	// Bytes received before the first Cluster ID. While STREAMING this is
	// replayed to each viewer that connects.
	uint8_t *mkvHeader;
	size_t mkvHeaderSz;
	
	// How many leading bytes of the 1A 45 DF A3 marker have matched so far
	// while scanning STREAMING data; kept across stream_step() calls.
	int stateChangeIdx;
} Stream;

// Number of entries in `clients`.
static size_t clientsSz;
// Every connected client, in accept order: main() appends, rem_cli() removes
// and closes the gap.
static Client **clients;

// "/push/<key>", built in main(); the only request path treated as a streamer.
static char *ValidStreamPath = NULL;

// Discards the first `n` bytes of the client's receive buffer and slides the
// remainder to the front. The caller must ensure n <= cli->len.
static void consume(Client *cli, size_t n) {
	const size_t remaining = cli->len - n;
	
	memmove(cli->buf, cli->buf + n, remaining);
	cli->len = remaining;
}

// Winsock has no MSG_NOSIGNAL (and no SIGPIPE to suppress), so make the flag
// a no-op there instead of failing to compile.
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif

// Writes all `sz` bytes of `buf` to the client's socket, looping over short
// writes.
//
// Returns 1 when everything was handed to the kernel, 0 on the first send()
// error. Client sockets are non-blocking, so "would block" is also reported as
// 0; in that case part of the data may already have been sent.
static int transmit(Client *cli, const char *buf, size_t sz) {
	while(sz) {
		ssize_t s = send(cli->fd, buf, sz, MSG_NOSIGNAL);
		
		if(s < 0) {
			#ifndef _WIN32
			// An interrupted call sent nothing; simply try again.
			if(errno == EINTR) {
				continue;
			}
			#endif
			return 0;
		}
		
		buf += s;
		sz -= (size_t) s;
	}
	return 1;
}

static void transmit_all(const char *buf, size_t sz) {
	for(size_t i = 0; i < clientsSz; i++) {
		if(clients[i]->state == WEBSOCKET) {
			transmit(clients[i], buf, sz);
		}
	}
}

#define WS_BIN 2
#define WS_CLOSE 8
#define WS_FIN 128
#define WS_HEADER_MAX 10

// Writes the header of an unmasked, unfragmented (FIN) binary WebSocket frame
// carrying `sz` payload bytes into `hdr`.
//
// The length is encoded as 7 bits, or as 126 plus a 16-bit big-endian length,
// or as 127 plus a 64-bit big-endian length.
//
// Returns the number of header bytes written: 2, 4 or 10.
static int ws_header(size_t sz, uint8_t hdr[static WS_HEADER_MAX]) {
	// Shift a 64-bit copy: shifting a 32-bit size_t by 32 or more is undefined.
	const uint64_t len = sz;
	
	hdr[0] = WS_BIN | WS_FIN;
	
	if(len < 126) {
		hdr[1] = (uint8_t) len;
		return 2;
	}
	
	if(len < 65536) {
		hdr[1] = 126;
		hdr[2] = (uint8_t) (len >> 8);
		hdr[3] = (uint8_t) (len & 0xFF);
		return 4;
	}
	
	hdr[1] = 127;
	for(int k = 0; k < 8; k++) {
		hdr[2 + k] = (uint8_t) ((len >> (56 - 8 * k)) & 0xFF);
	}
	return WS_HEADER_MAX;
}

// Sends `sz` bytes to one client as a single binary WebSocket frame.
// Empty payloads are skipped. Send failures are not reported to the caller.
static void ws_send(Client *cli, const uint8_t *buf, size_t sz) {
	if(sz == 0) return;
	
	uint8_t wshdr[WS_HEADER_MAX];
	int wshdrsz = ws_header(sz, wshdr);
	
	// A payload without its header would be misread as a frame header by the
	// peer, so stop if the header did not go out.
	if(!transmit(cli, (const char*) wshdr, (size_t) wshdrsz)) {
		return;
	}
	
	transmit(cli, (const char*) buf, sz);
}

static void ws_broadcast(const uint8_t *buf, size_t sz) {
	if(sz == 0) return;
	
	uint8_t wshdr[WS_HEADER_MAX];
	int wshdrsz = ws_header(sz, wshdr);
	
	transmit_all(wshdr, wshdrsz);
	transmit_all(buf, sz);
}

// Feeds `newsz` freshly decoded stream bytes into the relay.
//
// LOADING_HEADER: bytes are appended to Stream.mkvHeader until the Matroska
// Cluster ID (1F 43 B6 75) appears. The header and the data from the Cluster
// onwards are then broadcast as two frames, the stored header is cut at the
// Cluster, and the state becomes STREAMING.
//
// STREAMING: bytes are broadcast as they arrive while scanning for the EBML
// magic (1A 45 DF A3), which marks a restarted stream. When it is found, the
// bytes before it are broadcast, a new header is started with the magic plus
// the rest of this buffer, and the state goes back to LOADING_HEADER.
static void stream_step(const uint8_t *newbuf, size_t newsz) {
	// Byte arrays rather than string literals: a plain char literal makes
	// 0xDF and 0xA3 negative where char is signed, and they would then never
	// compare equal to a uint8_t.
	static const uint8_t ebmlMagic[4] = {0x1A, 0x45, 0xDF, 0xA3};
	static const uint8_t clusterId[4] = {0x1F, 0x43, 0xB6, 0x75};
	
	if(Stream.state == LOADING_HEADER) {
		if(newsz > 0) {
			uint8_t *grown = realloc(Stream.mkvHeader, Stream.mkvHeaderSz + newsz);
			if(!grown) {
				// The old header is still valid; only these bytes are lost.
				fprintf(stderr, "stream_step: out of memory, dropping %lu header bytes\n", (unsigned long) newsz);
				return;
			}
			
			memcpy(grown + Stream.mkvHeaderSz, newbuf, newsz);
			Stream.mkvHeader = grown;
			Stream.mkvHeaderSz += newsz;
		}
		
		if(!Stream.mkvHeader) {
			return;
		}
		
		uint8_t *clusterEl = memmem(Stream.mkvHeader, Stream.mkvHeaderSz, clusterId, sizeof(clusterId));
		if(clusterEl) {
			ws_broadcast(Stream.mkvHeader, clusterEl - Stream.mkvHeader);
			ws_broadcast(clusterEl, Stream.mkvHeader + Stream.mkvHeaderSz - clusterEl);
			
			Stream.mkvHeaderSz = clusterEl - Stream.mkvHeader;
			Stream.state = STREAMING;
		}
		
		return;
	}
	
	// STREAMING: look for the magic, which may have started in an earlier call.
	bool restarted = false;
	size_t i;
	for(i = 0; i < newsz; i++) {
		if(newbuf[i] == ebmlMagic[Stream.stateChangeIdx]) {
			Stream.stateChangeIdx++;
			
			if(Stream.stateChangeIdx == (int) sizeof(ebmlMagic)) {
				i++; // i now indexes the first byte after the magic
				Stream.stateChangeIdx = 0;
				restarted = true;
				break;
			}
		} else {
			// 0x1A occurs only at the start of the magic, so a mismatching
			// byte can at most begin a new match.
			Stream.stateChangeIdx = (newbuf[i] == ebmlMagic[0]) ? 1 : 0;
		}
	}
	
	if(!restarted) {
		ws_broadcast(newbuf, newsz);
		return;
	}
	
	const size_t tailSz = newsz - i;
	uint8_t *fresh = malloc(sizeof(ebmlMagic) + tailSz);
	if(!fresh) {
		// Cannot start a new header: keep the old one and forward the bytes.
		ws_broadcast(newbuf, newsz);
		return;
	}
	
	// Forward what preceded the magic. When i <= 4 the magic began in an
	// earlier buffer and nothing in this one precedes it.
	if(i > sizeof(ebmlMagic)) {
		ws_broadcast(newbuf, i - sizeof(ebmlMagic));
	}
	
	memcpy(fresh, ebmlMagic, sizeof(ebmlMagic));
	memcpy(fresh + sizeof(ebmlMagic), newbuf + i, tailSz);
	
	free(Stream.mkvHeader);
	Stream.mkvHeader = fresh;
	Stream.mkvHeaderSz = sizeof(ebmlMagic) + tailSz;
	Stream.state = LOADING_HEADER;
}

// Hook called by handle() when a viewer frame with the FIN bit has been
// processed; the message payload is in cli->ws.incoming.
// Currently empty: messages from viewers are ignored.
static void receive_ws(Client *cli) {
}

// Limits applied to untrusted client input.
#define REQUEST_MAX (64 * 1024) // bytes of HTTP request head buffered before giving up
#define WS_KEY_MAX 128          // longest Sec-WebSocket-Key used (a valid key is 24 bytes)
#define WS_FRAME_MAX 100        // largest payload accepted in a single viewer frame
#define WS_MESSAGE_MAX 4096     // largest reassembled viewer message

// Result of processing one unit of buffered input.
typedef enum {
	STEP_CLOSE,    // protocol violation or close request: drop the client
	STEP_WAIT,     // more bytes are needed before anything else can happen
	STEP_CONTINUE, // input was consumed; the buffer may hold more
} StepResult;

// ASCII-only lowercase conversion, independent of the current locale.
static char hdr_lower(char c) {
	return (c >= 'A' && c <= 'Z') ? (char) (c - 'A' + 'a') : c;
}

// True when the `len`-byte string `s` equals the C string `lit`, ignoring
// ASCII case. Unlike strncmp(s, lit, len), a prefix or an empty/NULL `s` does
// not match.
static bool hdr_ieq(const char *s, size_t len, const char *lit) {
	if(s == NULL || len != strlen(lit)) {
		return false;
	}
	
	for(size_t k = 0; k < len; k++) {
		if(hdr_lower(s[k]) != hdr_lower(lit[k])) {
			return false;
		}
	}
	
	return true;
}

// True when `lit` occurs anywhere in the `len`-byte string `s`, ignoring ASCII
// case.
static bool hdr_icontains(const char *s, size_t len, const char *lit) {
	const size_t litlen = strlen(lit);
	
	if(s == NULL || litlen == 0 || len < litlen) {
		return false;
	}
	
	for(size_t start = 0; start + litlen <= len; start++) {
		if(hdr_ieq(s + start, litlen, lit)) {
			return true;
		}
	}
	
	return false;
}

// REQUEST state: parses the HTTP request head and classifies the client.
// A request for ValidStreamPath must use chunked transfer encoding and becomes
// a streamer. Anything else must be a WebSocket upgrade and becomes a viewer,
// which is answered with the 101 handshake and, if available, the stored
// stream header.
static StepResult handle_request(Client *cli) {
	static const char wsGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	
	int minor_version;
	struct phr_header headers[96];
	const char *method, *path;
	size_t method_len, path_len, num_headers = sizeof(headers) / sizeof(headers[0]);
	int pret = phr_parse_request((const char*) cli->buf, cli->len, &method, &method_len, &path, &path_len, &minor_version, headers, &num_headers, cli->prevlen);
	
	if(pret == -1) {
		return STEP_CLOSE;
	}
	
	if(pret == -2) {
		// Incomplete: keep waiting, but not for an unbounded amount of data.
		return cli->len > REQUEST_MAX ? STEP_CLOSE : STEP_WAIT;
	}
	
	bool connectionUpgrade = false;
	bool upgradeWebSocket = false;
	bool chunked = false;
	
	size_t wsAcceptLen = 0;
	unsigned char *wsAccept = NULL;
	
	StepResult result = STEP_CLOSE;
	
	for(size_t h = 0; h < num_headers; h++) {
		const struct phr_header *hd = &headers[h];
		
		if(hdr_ieq(hd->name, hd->name_len, "Upgrade")) {
			if(hdr_ieq(hd->value, hd->value_len, "websocket")) {
				upgradeWebSocket = true;
			}
		} else if(hdr_ieq(hd->name, hd->name_len, "Connection")) {
			if(hdr_icontains(hd->value, hd->value_len, "Upgrade")) {
				connectionUpgrade = true;
			}
		} else if(hdr_ieq(hd->name, hd->name_len, "Transfer-Encoding")) {
			if(hdr_ieq(hd->value, hd->value_len, "chunked")) {
				chunked = true;
			}
		} else if(hdr_ieq(hd->name, hd->name_len, "Sec-WebSocket-Key")) {
			// An oversized key is skipped; a viewer without an accept value
			// is rejected below.
			if(hd->value_len > WS_KEY_MAX) {
				continue;
			}
			
			// Accept value = base64(sha1(key + GUID)).
			char acceptbuf[WS_KEY_MAX + sizeof(wsGuid) - 1];
			size_t acceptbufsz = hd->value_len + sizeof(wsGuid) - 1;
			memcpy(acceptbuf, hd->value, hd->value_len);
			memcpy(acceptbuf + hd->value_len, wsGuid, sizeof(wsGuid) - 1);
			
			char sha1bin[20];
			char sha1hex[41];
			sha1digest(sha1bin, sha1hex, acceptbuf, acceptbufsz);
			
			// If the header is repeated, the last one wins; release the previous value.
			free(wsAccept);
			wsAcceptLen = BASE64_ENCODE_OUT_SIZE(sizeof(sha1bin));
			wsAccept = malloc(wsAcceptLen);
			if(wsAccept) {
				base64_encode(sha1bin, sizeof(sha1bin), wsAccept);
			}
		}
	}
	
	if(path_len == strlen(ValidStreamPath) && memcmp(path, ValidStreamPath, path_len) == 0) {
		cli->type = CLI_STREAMER;
		cli->state = ACTIVE;
		
		if(upgradeWebSocket || connectionUpgrade || !chunked) {
			goto done;
		}
		
		printf("New streamer client\n");
	} else {
		cli->type = CLI_VIEWER;
		cli->state = WEBSOCKET;
		
		if(!upgradeWebSocket || !connectionUpgrade || chunked || !wsAccept) {
			goto done;
		}
		
		char buf[1024];
		int bufnum = snprintf(buf, sizeof(buf), "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %.*s\r\n\r\n", (int) wsAcceptLen, wsAccept);
		
		if(bufnum < 0 || (size_t) bufnum >= sizeof(buf)) {
			goto done;
		}
		
		// Without a delivered handshake the viewer cannot interpret frames.
		if(!transmit(cli, buf, (size_t) bufnum)) {
			goto done;
		}
		
		printf("New WS client\n");
		
		if(Stream.state == STREAMING && Stream.mkvHeader) {
			printf("Sending header\n");
			ws_send(cli, Stream.mkvHeader, Stream.mkvHeaderSz);
		}
	}
	
	consume(cli, (size_t) pret);
	cli->prevlen = 0;
	result = STEP_CONTINUE;
	
done:
	free(wsAccept);
	return result;
}

// ACTIVE state: decodes the chunked body in place and feeds the decoded bytes
// to the stream. The whole buffer is handed to the decoder, so it is emptied.
static StepResult handle_chunked(Client *cli) {
	size_t rsize = cli->len;
	int pret = phr_decode_chunked(&cli->chudec, (char*) cli->buf, &rsize);
	
	if(pret == -1) {
		return STEP_CLOSE;
	}
	
	stream_step(cli->buf, rsize);
	
	cli->len = 0;
	
	return pret == -2 ? STEP_WAIT : STEP_CONTINUE;
}

// WEBSOCKET state: processes one complete client-to-server frame. The payload
// is unmasked and appended to cli->ws.incoming. On FIN, receive_ws() is called
// and the message state is reset, or the client is dropped if the message was
// a close.
static StepResult handle_ws_frame(Client *cli) {
	if(cli->len < 2) {
		return STEP_WAIT;
	}
	
	const bool fin = (cli->buf[0] & WS_FIN) != 0;
	const int opcode = cli->buf[0] & 15;
	const bool masked = (cli->buf[1] & 128) != 0;
	const uint8_t len7 = cli->buf[1] & 127;
	
	size_t hdrSz;
	uint64_t payloadSz;
	
	if(len7 < 126) {
		payloadSz = len7;
		hdrSz = 2;
	} else if(len7 == 126) {
		if(cli->len < 4) return STEP_WAIT;
		
		payloadSz = ((uint64_t) cli->buf[2] << 8) | cli->buf[3];
		hdrSz = 4;
	} else {
		if(cli->len < 10) return STEP_WAIT;
		
		payloadSz = 0;
		for(int k = 0; k < 8; k++) {
			payloadSz = (payloadSz << 8) | cli->buf[2 + k];
		}
		hdrSz = 10;
	}
	
	// Frames from a client must be masked; the code below reads a 4-byte
	// masking key, so an unmasked frame would be misparsed.
	if(!masked || payloadSz > WS_FRAME_MAX) {
		return STEP_CLOSE;
	}
	
	const size_t frameSz = hdrSz + 4 + (size_t) payloadSz;
	if(cli->len < frameSz) {
		return STEP_WAIT;
	}
	
	// Continuation frames carry opcode 0; keep the opcode of the first frame.
	if(cli->ws.opcode == 0 && opcode) {
		cli->ws.opcode = opcode;
	}
	
	const uint8_t *mask = cli->buf + hdrSz;
	uint8_t *payload = cli->buf + hdrSz + 4;
	
	for(size_t b = 0; b < payloadSz; b++) {
		payload[b] ^= mask[b % 4];
	}
	
	if(payloadSz > 0) {
		if(cli->ws.incomingSz + payloadSz > WS_MESSAGE_MAX) {
			return STEP_CLOSE;
		}
		
		uint8_t *grown = realloc(cli->ws.incoming, cli->ws.incomingSz + (size_t) payloadSz);
		if(!grown) {
			return STEP_CLOSE;
		}
		
		memcpy(grown + cli->ws.incomingSz, payload, (size_t) payloadSz);
		cli->ws.incoming = grown;
		cli->ws.incomingSz += (size_t) payloadSz;
	}
	
	// Remove the frame from the buffer so the caller's loop makes progress.
	consume(cli, frameSz);
	
	if(fin) {
		receive_ws(cli);
		
		if(cli->ws.opcode == WS_CLOSE) {
			return STEP_CLOSE;
		}
		
		cli->ws.incomingSz = 0;
		cli->ws.opcode = 0;
	}
	
	return STEP_CONTINUE;
}

// Processes as much of the client's buffered input as possible, according to
// its current state.
//
// Returns 1 to keep the connection (possibly waiting for more data) and 0 when
// the client must be disconnected.
static int handle(Client *cli) {
	while(cli->len != 0) {
		StepResult step;
		
		switch(cli->state) {
			case REQUEST:
				step = handle_request(cli);
				break;
			case ACTIVE:
				step = handle_chunked(cli);
				break;
			case WEBSOCKET:
				step = handle_ws_frame(cli);
				break;
			default:
				return 0;
		}
		
		if(step == STEP_CLOSE) {
			return 0;
		}
		
		if(step == STEP_WAIT) {
			return 1;
		}
	}
	
	return 1;
}

// Removes `cli` from the global client list, shifting later entries down by
// one. Does nothing if the client is not listed. The Client itself is not
// freed.
static void rem_cli(Client *cli) {
	size_t pos = 0;
	
	while(pos < clientsSz && clients[pos] != cli) {
		pos++;
	}
	
	if(pos == clientsSz) {
		return;
	}
	
	const size_t tail = clientsSz - pos - 1;
	memmove(&clients[pos], &clients[pos + 1], sizeof(*clients) * tail);
	clientsSz--;
}

// Listening socket; created, bound and registered with epoll in main().
static Socket ServSock;
// The epoll instance that main()'s event loop waits on.
static EPoll EP;

// Copy of main()'s arguments, used by the get_arg* helpers.
static int Argc;
static char **Argv;

// Looks up a command-line argument of the form `key=value`.
//
// Returns a pointer to the value part of the first argument (after argv[0])
// that starts with `key` immediately followed by '='. The value may be empty.
// Returns `def` when there is no such argument.
static const char *get_arg(const char *key, const char *def) {
	const size_t keylen = strlen(key);
	
	for(int a = 1; a < Argc; a++) {
		const char *arg = Argv[a];
		
		if(strncmp(arg, key, keylen) != 0) {
			continue;
		}
		
		if(arg[keylen] == '=') {
			return arg + keylen + 1;
		}
	}
	
	return def;
}

// Reads `key=<integer>` from the command line as a flag: true when the value
// parses (strtol, base auto-detected) to a non-zero number, false when the
// argument is missing or parses to zero.
static bool get_arg_bool(const char *key) {
	const long parsed = strtol(get_arg(key, "0"), NULL, 0);
	
	return parsed != 0;
}

#define BUFSZ 8192
#define EPOLL_EVS 2048

// True when the most recent recv() failed only because no data is available
// yet. Winsock reports errors through WSAGetLastError() rather than errno.
static bool recv_would_block(void) {
	#ifdef _WIN32
	return WSAGetLastError() == WSAEWOULDBLOCK;
	#else
	return errno == EAGAIN || errno == EWOULDBLOCK;
	#endif
}

// Accepts one pending connection, makes it non-blocking, and registers a new
// Client for it. On any failure the connection is closed and nothing is
// registered.
static void accept_client(void) {
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr); // in/out: must hold the buffer size on entry
	
	Socket clisock = accept(ServSock, (struct sockaddr*) &addr, &addrlen);
	if(clisock == -1) {
		return;
	}
	
	#ifdef _WIN32
	ioctlsocket(clisock, FIONBIO, &(u_long) {1});
	#else
	if(fcntl(clisock, F_SETFL, fcntl(clisock, F_GETFL, 0) | O_NONBLOCK) == -1) {
		closesocket(clisock);
		return;
	}
	#endif
	
	Client *cli = calloc(1, sizeof(*cli));
	uint8_t *rxbuf = malloc(BUFSZ);
	Client **grown = realloc(clients, sizeof(*clients) * (clientsSz + 1));
	
	// A successful realloc must be kept even if another allocation failed.
	if(grown) {
		clients = grown;
	}
	
	if(!cli || !rxbuf || !grown) {
		free(cli);
		free(rxbuf);
		closesocket(clisock);
		return;
	}
	
	cli->fd = clisock;
	cli->len = 0;
	cli->buf = rxbuf;
	cli->cap = BUFSZ;
	
	if(epoll_ctl(EP, EPOLL_CTL_ADD, clisock, &(struct epoll_event) {.events = EPOLLIN | EPOLLRDHUP | EPOLLHUP, .data = {.ptr = cli}}) == -1) {
		free(rxbuf);
		free(cli);
		closesocket(clisock);
		return;
	}
	
	clients[clientsSz++] = cli;
}

// Drains the client's socket into its receive buffer.
//
// Returns true when reading stopped because no more data is available right
// now. Returns false when the peer closed the connection, a socket error
// occurred, or the buffer could not be grown.
static bool read_client(Client *cli) {
	char buf[BUFSZ];
	
	// Remember how much was buffered before this batch. handle() passes it to
	// the request parser, which starts looking for the end of the headers
	// near that offset.
	cli->prevlen = cli->len;
	
	for(;;) {
		ssize_t readcount = recv(cli->fd, buf, sizeof(buf), 0);
		
		if(readcount == 0) {
			// Orderly shutdown by the peer; the error code says nothing here.
			return false;
		}
		
		if(readcount < 0) {
			#ifndef _WIN32
			if(errno == EINTR) {
				continue;
			}
			#endif
			return recv_would_block();
		}
		
		const size_t need = cli->len + (size_t) readcount;
		if(need > cli->cap) {
			// Grow to the next multiple of 4096.
			const size_t newcap = (need + 4095) & ~(size_t) 4095;
			uint8_t *grown = realloc(cli->buf, newcap);
			if(!grown) {
				return false;
			}
			cli->buf = grown;
			cli->cap = newcap;
		}
		
		memcpy(cli->buf + cli->len, buf, (size_t) readcount);
		cli->len = need;
	}
}

// Unregisters, closes and frees a client.
static void drop_client(Client *cli) {
	epoll_ctl(EP, EPOLL_CTL_DEL, cli->fd, NULL);
	closesocket(cli->fd);
	
	rem_cli(cli);
	
	free(cli->buf);
	free(cli->ws.incoming);
	free(cli);
	
	printf("Client left, now at %lu\n", (unsigned long) clientsSz);
}

// Entry point.
//
// Arguments are `name=value` pairs:
//   key=<secret>   required; the stream is pushed to /push/<secret>
//   port=<port>    listening port, default 25404
//   reuseaddr=<n>  non-zero sets SO_REUSEADDR on the listening socket
//
// Runs the accept/read/dispatch loop until a fatal error occurs.
int main(int argc, char **argv) {
	Argc = argc, Argv = argv;
	
	const char *streamkey = get_arg("key", NULL);
	if(!streamkey) {
		puts("Missing stream key parameter key=...");
		return 0;
	}
	
	const size_t pathSz = 6 + strlen(streamkey) + 1;
	ValidStreamPath = malloc(pathSz);
	if(!ValidStreamPath) {
		perror("malloc");
		return EXIT_FAILURE;
	}
	snprintf(ValidStreamPath, pathSz, "/push/%s", streamkey);
	
	#ifdef _WIN32
	if(WSAStartup(MAKEWORD(2, 2), &(WSADATA) {0}) != 0) {
		fputs("WSAStartup failed\n", stderr);
		return EXIT_FAILURE;
	}
	
	ServSock = socket(AF_INET6, SOCK_STREAM, 0);
	if(ServSock != -1) {
		ioctlsocket(ServSock, FIONBIO, &(u_long) {1});
	}
	#else
	ServSock = socket(AF_INET6, SOCK_STREAM | SOCK_NONBLOCK, 0);
	#endif
	
	// These are real runtime checks, not assert(): the calls must still run
	// and be checked when the program is built with NDEBUG.
	if(ServSock == -1) {
		perror("socket");
		return EXIT_FAILURE;
	}
	
	EP = epoll_create1(0);
	#ifdef _WIN32
	if(EP == NULL) {
	#else
	if(EP == -1) {
	#endif
		perror("epoll_create1");
		return EXIT_FAILURE;
	}
	
	if(get_arg_bool("reuseaddr")) {
		setsockopt(ServSock, SOL_SOCKET, SO_REUSEADDR, &(int) {1}, sizeof(int));
	}
	// Accept IPv4 connections on the IPv6 socket as well.
	setsockopt(ServSock, IPPROTO_IPV6, IPV6_V6ONLY, &(int) {0}, sizeof(int));
	
	struct addrinfo *res = NULL;
	int gai = getaddrinfo(NULL, get_arg("port", "25404"), &(struct addrinfo) {.ai_flags = AI_PASSIVE, .ai_family = AF_INET6}, &res);
	if(gai != 0 || res == NULL) {
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(gai));
		return EXIT_FAILURE;
	}
	
	if(bind(ServSock, res->ai_addr, res->ai_addrlen) < 0) {
		perror("bind");
		freeaddrinfo(res);
		return EXIT_FAILURE;
	}
	
	freeaddrinfo(res);
	
	if(listen(ServSock, 16) < 0) {
		perror("listen");
		return EXIT_FAILURE;
	}
	
	// The epoll user data is a union, and client entries store a Client
	// pointer in it. The listener is tagged with a pointer too (the address
	// of ServSock), so the same union member is read for every event.
	if(epoll_ctl(EP, EPOLL_CTL_ADD, ServSock, &(struct epoll_event) {.events = EPOLLIN, .data = {.ptr = &ServSock}}) == -1) {
		perror("epoll_ctl");
		return EXIT_FAILURE;
	}
	
	while(1) {
		struct epoll_event events[EPOLL_EVS];
		int nfds = epoll_wait(EP, events, EPOLL_EVS, -1);
		
		if(nfds < 0) {
			if(errno == EINTR) {
				continue;
			}
			perror("epoll_wait");
			return EXIT_FAILURE;
		}
		
		for(int i = 0; i < nfds; i++) {
			if(events[i].data.ptr == &ServSock) {
				accept_client();
				continue;
			}
			
			Client *cli = events[i].data.ptr;
			
			bool forceclose = false;
			
			if(events[i].events & EPOLLIN) {
				if(!read_client(cli)) {
					forceclose = true;
				}
				
				// Process whatever was buffered even if the peer has just
				// closed the connection.
				if(handle(cli) == 0) {
					forceclose = true;
				}
			}
			
			// EPOLLERR is reported even though it is never requested; if it
			// were ignored, a failed socket would be reported again and again.
			if(forceclose || (events[i].events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))) {
				drop_client(cli);
			}
		}
	}
}