/*-
 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
 *
 * Copyright (c) 2021 Tobias Kortkamp <tobik@FreeBSD.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */
#include "config.h"

#include <sys/param.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include "array.h"
#include "flow.h"
#include "libgrapheme/grapheme.h"
#include "mem.h"
#include "mempool.h"
#include "peg.h"
#include "queue.h"
#include "set.h"
#include "stack.h"
#include "str.h"
#include "trait/compare.h"
#include "trait/peg.h"

// One recorded parse error.  peg_new() preallocates PEG_MAX_ERRORS of
// these; peg_backtrace() treats a NULL rule as an unused slot (presumably
// mempool_alloc() hands out zeroed memory -- confirm).
struct PEGError {
	size_t index;		// value of peg->error_index when the error was recorded
	size_t pos;		// byte offset into the input where the error was recorded
	const char *rule;	// name of the rule that reported the error
	const char *msg;	// optional message; NULL or "" when there is none
};

// Parser state for one input buffer.  Created by peg_new() and released
// by peg_free().  peg_reset() clears the position, depth, error_index,
// captures and rule trace, but leaves the error slots untouched.
struct PEG {
	const char *const buf;	// input buffer; the pointer is stored, not copied
	const size_t len;	// length of buf in bytes
	size_t pos;		// current byte offset into buf
	size_t depth;		// matcher nesting depth, limited by PEG_MAX_DEPTH

	struct {
		struct Mempool *pool;	// storage for PEGCapture entries
		struct Set *set;	// recorded captures, used to skip duplicates
		struct Queue *queue;	// recorded captures in the order they were added
		struct Stack *pos;	// start offsets of currently open captures
		bool enabled;		// false without a capture machine and inside peg_match_rule_to()
	} captures;

	bool debug;		// set when LIBIAS_PEG_DEBUG is present in the environment
	struct Array *errors;	// PEG_MAX_ERRORS preallocated PEGError slots
	size_t error_index;	// next value assigned to PEGError.index
	struct Queue *rule_trace;	// names of entered rules (filled in debug mode only)

	struct Mempool *pool;	// owns this struct and the containers above
};

// Prototypes
static DECLARE_COMPARE(compare_capture);
static DECLARE_COMPARE(compare_error);
static void peg_line_col_at_pos(struct PEG *, size_t, size_t *, size_t *);
static void peg_reset(struct PEG *);

// Constants
// Matchers fail once peg->depth exceeds this value (checked in MATCHER_INIT()).
static const size_t PEG_MAX_DEPTH = 10000;
// Number of error slots preallocated by peg_new() and reported by
// peg_backtrace(); peg_match_error() overwrites the last slot.
static const size_t PEG_MAX_ERRORS = 4;
// Capture ordering used by the capture set.  Not static, so presumably
// declared in peg.h for use elsewhere -- confirm.
struct CompareTrait *peg_capture_compare = &(struct CompareTrait){
	.compare = compare_capture,
	.compare_userdata = NULL,
};
// Error ordering used by peg_match_error().  compare_userdata is NULL; if it
// is passed through as compare_error()'s `this` (assumed -- confirm), the
// index tie-break in compare_error() is never applied here.
static struct CompareTrait *peg_error_compare = &(struct CompareTrait){
	.compare = compare_error,
	.compare_userdata = NULL,
};

/*
 * Bookkeeping shared by the peg_match_* matchers.
 *
 * MATCHER_INIT() must be the first statement of a matcher.  It declares
 * locals that record how many captures and rule trace entries existed on
 * entry.  It then increments peg->depth and fails the matcher early when
 * the depth exceeds PEG_MAX_DEPTH or the position is past the end of the
 * buffer.  In debug mode it appends `rule` to the rule trace.
 *
 * Every matcher must leave through MATCHER_RETURN() (or decrement
 * peg->depth itself) so that the depth stays balanced.
 */
#define MATCHER_INIT() \
	size_t MATCHER_INIT_captures_queue_len = queue_len(peg->captures.queue); \
	size_t MATCHER_INIT_rule_trace_len = queue_len(peg->rule_trace); \
	do { \
		peg->depth++; \
		if (peg->depth > PEG_MAX_DEPTH || peg->pos > peg->len) { \
			MATCHER_RETURN(0); \
		} \
		if (peg->debug) { \
			queue_push(peg->rule_trace, (char *)rule); \
		} \
	} while (0)
/*
 * MATCHER_POP() discards every capture and rule trace entry added since
 * MATCHER_INIT().  It loops on the live queue length, so every entry above
 * the recorded length is removed.  An index counted upward against the
 * shrinking queue_len() would stop after removing only about half of them.
 *
 * The argument is unused; callers write MATCHER_POP().
 */
#define MATCHER_POP(x) \
do { \
	while (queue_len(peg->captures.queue) > MATCHER_INIT_captures_queue_len) { \
		struct PEGCapture *capture = queue_dequeue(peg->captures.queue); \
		set_remove(peg->captures.set, capture); \
		mempool_release(peg->captures.pool, capture); \
	} \
	if (peg->debug) { \
		while (queue_len(peg->rule_trace) > MATCHER_INIT_rule_trace_len) { \
			queue_dequeue(peg->rule_trace); \
		} \
	} \
} while (0)
/*
 * MATCHER_RETURN(x) leaves a matcher with result x, evaluated exactly once.
 * It decrements peg->depth.  When x is false, it also rolls back the
 * captures and trace entries recorded since MATCHER_INIT().
 */
#define MATCHER_RETURN(x) \
do { \
	bool MATCHER_RETURN_result = (x); \
	peg->depth--; \
	if (!MATCHER_RETURN_result) { \
		MATCHER_POP(); \
	} \
	return MATCHER_RETURN_result; \
} while (0)

/*
 * Total order on captures, used by the capture set to detect duplicates.
 * Captures are ordered by start position, then length, then state, then tag.
 */
DEFINE_COMPARE(compare_capture, struct PEGCapture, void)
{
	if (a->pos != b->pos) {
		return a->pos < b->pos ? -1 : 1;
	}
	if (a->len != b->len) {
		return a->len < b->len ? -1 : 1;
	}
	if (a->state != b->state) {
		return a->state < b->state ? -1 : 1;
	}
	if (a->tag != b->tag) {
		return a->tag < b->tag ? -1 : 1;
	}
	return 0;
}

/*
 * Order errors so that the one furthest into the input comes first.
 * Ties are broken as follows:
 *   - by insertion index, but only when a non-NULL `this` is supplied;
 *   - errors without a rule sort after errors with one;
 *   - otherwise by rule name.
 */
DEFINE_COMPARE(compare_error, struct PEGError, void)
{
	if (a->pos != b->pos) {
		return a->pos > b->pos ? -1 : 1;
	}
	if (this != NULL && a->index != b->index) {
		return a->index < b->index ? -1 : 1;
	}
	if (a->rule == NULL || b->rule == NULL) {
		if (a->rule == b->rule) {
			// Both are NULL.
			return 0;
		}
		return a->rule == NULL ? 1 : -1;
	}
	return strcmp(a->rule, b->rule);
}

/*
 * Run `rulefn` as the ":main" rule against the buffer, starting at offset 0.
 *
 * When `capture_machine` is non-NULL, captures are recorded.  If the match
 * succeeds, an accept capture (state 0) spanning the matched prefix
 * [0, peg->pos) is appended.  The captures are then handed to the machine:
 *   - `single` is called once per capture, stopping at the first rejection,
 *     which makes the whole match fail;
 *   - `multi` is called once with the whole queue, but only if no `single`
 *     call rejected the match.
 *
 * In debug mode the rule trace is printed to stdout.  The parser state is
 * reset before and after matching.  peg_reset() does not clear the error
 * slots, so they remain available to peg_backtrace().
 */
bool
peg_match(struct PEG *peg, RuleFn rulefn, struct PEGCaptureMachine *capture_machine)
{
	peg_reset(peg);

	peg->captures.enabled = capture_machine != NULL;
	bool result = peg_match_rule(peg, ":main", rulefn, NULL);

	if (result && capture_machine) {
		// Lives on the stack; the queue entry is dropped by peg_reset() below.
		struct PEGCapture accept_capture = {
			.peg = peg,
			.buf = peg->buf,
			.pos = 0,
			.len = peg->pos,
			.tag = -1,
			.state = 0, // Accept state
		};
		queue_push(peg->captures.queue, &accept_capture);
		if (capture_machine->single) {
			QUEUE_FOREACH(peg->captures.queue, struct PEGCapture *, capture) {
				unless (capture_machine->single(capture, capture_machine->userdata)) {
					result = false;
					break;
				}
			}
		}
		// Do not let `multi` turn a match rejected by `single` back into a success.
		if (result && capture_machine->multi) {
			result = capture_machine->multi(peg->captures.queue, capture_machine->userdata);
		}
	}

	if (peg->debug && queue_len(peg->rule_trace) > 0) {
		char *rule = queue_pop(peg->rule_trace);
		printf("%s", rule);
		while ((rule = queue_pop(peg->rule_trace))) {
			printf(" -> %s", rule);
		}
		printf(" -> end@%zu\n", peg->pos);
	}

	peg_reset(peg);

	return result;
}

/*
 * Match `rulefn` greedily as often as possible.  Succeed if it matched at
 * least `n` times; otherwise restore the position and fail.
 *
 * Each attempt goes through peg_match_rule().  That way an attempt that
 * fails after consuming input or recording captures is rolled back instead
 * of leaving its partial effects behind when the loop stops.
 *
 * An attempt that succeeds without consuming input would succeed forever.
 * It therefore ends the loop and counts as satisfying `n`.  The counter
 * saturates at UINT32_MAX instead of wrapping.
 */
bool
peg_match_atleast(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata, uint32_t n)
{
	MATCHER_INIT();
	size_t pos = peg->pos;
	uint32_t i = 0;
	for (;;) {
		size_t iteration_pos = peg->pos;
		if (!peg_match_rule(peg, rule, rulefn, userdata)) {
			break;
		}
		if (i < UINT32_MAX) {
			i++;
		}
		if (peg->pos == iteration_pos) {
			// Empty match: further attempts would match again without progress.
			if (i < n) {
				i = n;
			}
			break;
		}
	}
	if (i < n) {
		peg->pos = pos;
		MATCHER_RETURN(false);
	}
	MATCHER_RETURN(true);
}

/*
 * Match `rulefn` between `a` and `b` times (inclusive).
 *
 * Up to b + 1 attempts are made.  If the rule can match more than `b` times
 * in a row, the matcher fails.  On failure the position is restored to where
 * it was on entry.
 *
 * Each attempt goes through peg_match_rule(), so a failing attempt that
 * consumed input or recorded captures is rolled back before the count is
 * checked.  The counter is 64-bit so that b == UINT32_MAX cannot wrap it.
 * An attempt that succeeds without consuming input would keep succeeding
 * until the (b + 1)-th attempt, so it is treated as such.
 */
bool
peg_match_between(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata, uint32_t a, uint32_t b)
{
	MATCHER_INIT();
	size_t pos = peg->pos;
	uint64_t limit = (uint64_t)b + 1;
	uint64_t i = 0;
	while (i < limit) {
		size_t iteration_pos = peg->pos;
		if (!peg_match_rule(peg, rule, rulefn, userdata)) {
			break;
		}
		i++;
		if (peg->pos == iteration_pos) {
			i = limit;
			break;
		}
	}
	if (i >= a && i <= b) {
		MATCHER_RETURN(true);
	}
	peg->pos = pos;
	MATCHER_RETURN(false);
}

/*
 * Open a capture at the current position.  While captures are enabled, the
 * offset is pushed onto the capture start stack, where
 * peg_match_capture_end() pops it.  Always returns true.
 */
bool
peg_match_capture_start(struct PEG *peg)
{
	if (!peg->captures.enabled) {
		return true;
	}
	stack_push(peg->captures.pos, (void *)(uintptr_t)peg->pos);
	return true;
}

bool
peg_match_capture_end(struct PEG *peg, uint32_t tag, uint32_t state, bool retval)
{
	if (peg->captures.enabled && stack_len(peg->captures.pos) > 0) {
		size_t start = (size_t)stack_pop(peg->captures.pos);
		if (retval) {
			size_t len = peg->pos - start;
			struct PEGCapture c = { .tag = tag, .state = state, .pos = start, .len = len };
			if (!set_get(peg->captures.set, &c)) {
				struct PEGCapture *capture = mempool_alloc(peg->captures.pool, sizeof(struct PEGCapture));
				capture->tag = tag;
				capture->state = state;
				capture->buf = peg->buf + start;
				capture->pos = start;
				capture->len = len;
				capture->peg = peg;
				queue_push(peg->captures.queue, capture);
				set_add(peg->captures.set, capture);
			}
		}
	}
	return retval;
}

/*
 * Match any one of the `len` codepoints in `chars` at the current position.
 *
 * Each codepoint is UTF-8 encoded and compared byte-wise against the input.
 * The first match advances the position past it.  The matcher fails if no
 * codepoint matches or if an encoding does not fit the local buffer.
 */
bool
peg_match_chars(struct PEG *peg, const char *rule, uint32_t chars[], size_t len)
{
	MATCHER_INIT();
	// MATCHER_INIT() already failed if pos > len, so this cannot underflow.
	size_t remaining = peg->len - peg->pos;
	for (size_t i = 0; i < len; i++) {
		char needle[16 + 1];
		size_t needle_len = grapheme_encode_utf8(chars[i], needle, sizeof(needle));
		if (needle_len > sizeof(needle)) {
			MATCHER_RETURN(false);
		}
		// The length check keeps memcmp() within the buffer.
		if (remaining >= needle_len &&
		    memcmp(peg->buf + peg->pos, needle, needle_len) == 0) {
			peg->pos += needle_len;
			MATCHER_RETURN(true);
		}
	}
	MATCHER_RETURN(false);
}

/*
 * Succeed only when the whole buffer has been consumed.  The position is
 * never moved.
 */
bool
peg_match_eos(struct PEG *peg, const char *rule)
{
	MATCHER_INIT();
	bool at_end = peg->len == peg->pos;
	MATCHER_RETURN(at_end);
}

bool
peg_match_error(struct PEG *peg, const char *rule, const char *msg)
{
	struct PEGError key = { .rule = rule, .msg = msg, .pos = peg->pos };
	if (array_find(peg->errors, &key, peg_error_compare) == -1) {
		struct PEGError *err = array_get(peg->errors, PEG_MAX_ERRORS - 1);
		err->index = peg->error_index++;
		err->rule = rule;
		err->msg = msg;
		err->pos = peg->pos;
		array_sort(peg->errors, peg_error_compare);
	}
	return false;
}

/*
 * Positive lookahead: succeed if `rulefn` matches at the current position,
 * without consuming input.
 *
 * In both outcomes the position is restored and any captures recorded by
 * `rulefn` are discarded.
 */
bool
peg_match_lookahead(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata)
{
	MATCHER_INIT();
	size_t pos = peg->pos;
	bool matched = rulefn(peg, userdata);
	peg->pos = pos;
	if (matched) {
		// MATCHER_RETURN(true) would keep the captures, so unwind by hand.
		// This includes the depth decrement; without it every successful
		// lookahead raises peg->depth permanently, until PEG_MAX_DEPTH is
		// reached and all matchers fail.
		peg->depth--;
		MATCHER_POP();
		return true;
	}
	MATCHER_RETURN(false);
}

/*
 * Consume one UTF-8 encoded codepoint.  If `cout` is non-NULL, the codepoint
 * is stored there.
 *
 * The matcher fails at the end of input and on malformed input.
 * grapheme_decode_utf8() returns the number of bytes processed and reports
 * malformed input by setting the codepoint to GRAPHEME_INVALID_CODEPOINT.
 * For a sequence truncated at the end of the buffer, it returns a length
 * larger than what remains.  Both cases are rejected so that the position
 * never moves past the end of the buffer.  The decoder cannot tell a
 * literally encoded U+FFFD apart from malformed input, so that is
 * rejected too.
 */
bool
peg_match_next_char(struct PEG *peg, const char *rule, uint32_t *cout)
{
	MATCHER_INIT();
	if (peg->pos >= peg->len) {
		MATCHER_RETURN(false);
	}

	uint32_t c;
	size_t remaining = peg->len - peg->pos;
	size_t clen = grapheme_decode_utf8(peg->buf + peg->pos, remaining, &c);
	if (c == GRAPHEME_INVALID_CODEPOINT || clen == 0 || clen > remaining) {
		MATCHER_RETURN(false);
	}
	peg->pos += clen;
	if (cout) {
		*cout = c;
	}
	MATCHER_RETURN(true);
}


/*
 * Consume one UTF-8 encoded codepoint if it lies in [a, b].
 *
 * Malformed input (reported by grapheme_decode_utf8() as
 * GRAPHEME_INVALID_CODEPOINT) fails the match.  So does a sequence whose
 * decoded length runs past the end of the buffer, so the position never
 * moves beyond peg->len.
 */
bool
peg_match_range(struct PEG *peg, const char *rule, uint32_t a, uint32_t b)
{
	MATCHER_INIT();
	if (peg->pos >= peg->len) {
		MATCHER_RETURN(false);
	}

	uint32_t c;
	size_t remaining = peg->len - peg->pos;
	size_t clen = grapheme_decode_utf8(peg->buf + peg->pos, remaining, &c);
	if (c == GRAPHEME_INVALID_CODEPOINT || clen == 0 || clen > remaining ||
	    a > c || c > b) {
		MATCHER_RETURN(false);
	}
	peg->pos += clen;
	MATCHER_RETURN(true);
}

/*
 * Match `rulefn` exactly `n` times in a row.  If any attempt fails, the
 * position is restored to where it was on entry and the matcher fails.
 */
bool
peg_match_repeat(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata, uint32_t n)
{
	MATCHER_INIT();
	size_t start = peg->pos;
	uint32_t left = n;
	while (left > 0) {
		if (!rulefn(peg, userdata)) {
			peg->pos = start;
			MATCHER_RETURN(false);
		}
		left--;
	}
	MATCHER_RETURN(true);
}

/*
 * Match `rulefn` once at the current position.  `rule` is the name recorded
 * in the debug trace.
 *
 * On failure the position is restored, and MATCHER_RETURN() drops the
 * captures recorded since entry.
 */
bool
peg_match_rule(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata)
{
	MATCHER_INIT();
	size_t saved_pos = peg->pos;
	bool matched = rulefn(peg, userdata);
	if (!matched) {
		peg->pos = saved_pos;
	}
	MATCHER_RETURN(matched);
}

/*
 * Advance codepoint by codepoint until `rulefn` matches.
 *
 * On success the position is left at the start of that match; the match
 * itself is not consumed.  Captures are disabled while searching.  The
 * matcher fails, restoring the position, if the end of input or malformed
 * UTF-8 is reached first.
 *
 * A failed attempt may have moved the position, so the search always
 * resumes from where that attempt started, skipping exactly one codepoint.
 * Malformed input is detected from the decoded codepoint; the return value
 * of grapheme_decode_utf8() is a byte count, not a codepoint.
 */
bool
peg_match_rule_to(struct PEG *peg, const char *rule, RuleFn rulefn, void *userdata)
{
	MATCHER_INIT();
	size_t pos = peg->pos;
	bool captures_enabled = peg->captures.enabled;
	peg->captures.enabled = false;
	while (peg->pos < peg->len) {
		size_t pos_before_match = peg->pos;
		if (rulefn(peg, userdata)) {
			peg->pos = pos_before_match;
			peg->captures.enabled = captures_enabled;
			MATCHER_RETURN(true);
		}
		peg->pos = pos_before_match;
		uint32_t c;
		size_t remaining = peg->len - peg->pos;
		size_t clen = grapheme_decode_utf8(peg->buf + peg->pos, remaining, &c);
		if (c == GRAPHEME_INVALID_CODEPOINT || clen == 0 || clen > remaining) {
			break;
		}
		peg->pos += clen;
	}
	peg->pos = pos;
	peg->captures.enabled = captures_enabled;
	MATCHER_RETURN(false);
}

/*
 * Match the first of the `needlelen` strings in `needle` that is a prefix of
 * the remaining input, and advance past it.  Fails if none matches.
 */
bool
peg_match_string(struct PEG *peg, const char *rule, const char *needle[], size_t needlelen)
{
	MATCHER_INIT();
	size_t idx = 0;
	while (idx < needlelen) {
		const char *candidate = needle[idx];
		size_t candidate_len = strlen(candidate);
		bool fits = peg->len - peg->pos >= candidate_len;
		if (fits && strncmp(peg->buf + peg->pos, candidate, candidate_len) == 0) {
			peg->pos += candidate_len;
			MATCHER_RETURN(true);
		}
		idx++;
	}
	MATCHER_RETURN(false);
}

/*
 * Convert byte offset `pos` (clamped to the buffer length) into a 1-based
 * line and column.  Columns are counted in codepoints.
 *
 * If malformed UTF-8 is found before `pos`, the function falls back to
 * reporting line 1 with the byte offset of the malformed sequence as the
 * column.  grapheme_decode_utf8() signals malformed input through the
 * decoded codepoint (GRAPHEME_INVALID_CODEPOINT), not through its return
 * value, which is a byte count.
 */
void
peg_line_col_at_pos(struct PEG *peg, size_t pos, size_t *line, size_t *col)
{
	*line = 1;
	*col = 1;
	size_t len = MIN(pos, peg->len);
	for (size_t i = 0; i < len;) {
		uint32_t c;
		size_t clen = grapheme_decode_utf8(peg->buf + i, peg->len - i, &c);
		if (c == GRAPHEME_INVALID_CODEPOINT || clen == 0) {
			*line = 1;
			*col = i;
			return;
		}
		if (c == '\n') {
			(*line)++;
			*col = 1;
		} else {
			(*col)++;
		}
		i += clen;
	}
}

/*
 * Format the recorded errors as "LINE:COL: in RULE[: MSG]" strings, in the
 * order they are stored.  Formatting stops at the first slot without a rule.
 * The strings and the returned array are allocated from `pool`.
 */
struct Array *
peg_backtrace(struct PEG *peg, struct Mempool *pool)
{
	struct Array *lines = mempool_array(pool);
	ARRAY_FOREACH(peg->errors, struct PEGError *, perr) {
		if (!perr->rule) {
			break;
		}
		size_t line, col;
		peg_line_col_at_pos(peg, perr->pos, &line, &col);
		bool has_msg = perr->msg != NULL && *perr->msg != '\0';
		char *entry = has_msg
			? str_printf(pool, "%zu:%zu: in %s: %s", line, col, perr->rule, perr->msg)
			: str_printf(pool, "%zu:%zu: in %s", line, col, perr->rule);
		array_append(lines, entry);
	}
	return lines;
}

/*
 * Create a parser for the `len` bytes at `buf`; the buffer is referenced,
 * not copied.
 *
 * All parser state is allocated from a fresh mempool, and the PEG itself is
 * handed to that pool with mempool_take(); peg_free() releases it.
 * Debugging is enabled when LIBIAS_PEG_DEBUG is set in the environment.
 */
struct PEG *
peg_new(const char *const buf, size_t len)
{
	// buf and len are const members, so they are set via a copied initializer.
	struct PEG *peg = xmalloc(sizeof(*peg));
	memcpy(peg, &(struct PEG){ .buf = buf, .len = len }, sizeof(*peg));

	peg->pool = mempool_new();
	mempool_take(peg->pool, peg);

	peg->errors = mempool_array(peg->pool);
	for (size_t slot = 0; slot < PEG_MAX_ERRORS; slot++) {
		struct PEGError *err = mempool_alloc(peg->pool, sizeof(*err));
		array_append(peg->errors, err);
	}

	peg->debug = getenv("LIBIAS_PEG_DEBUG") != NULL;
	peg->rule_trace = mempool_queue(peg->pool);

	peg->captures.pool = mempool_pool(peg->pool);
	peg->captures.set = mempool_set(peg->pool, peg_capture_compare);
	peg->captures.queue = mempool_queue(peg->pool);
	peg->captures.pos = mempool_stack(peg->pool);

	return peg;
}

/*
 * Return the parser to its initial state before or after a match.
 *
 * The position, depth and error counter are zeroed.  All captures, open
 * capture offsets and trace entries are dropped.  The error slots are left
 * as they are, which keeps them available to peg_backtrace() after
 * peg_match() returns.
 */
void
peg_reset(struct PEG *peg)
{
	peg->error_index = 0;
	peg->depth = 0;
	peg->pos = 0;

	mempool_release_all(peg->captures.pool);
	stack_truncate(peg->captures.pos);
	queue_truncate(peg->captures.queue);
	set_truncate(peg->captures.set);
	queue_truncate(peg->rule_trace);
}

/*
 * Release a parser created by peg_new().  NULL is accepted and ignored.
 * peg_new() transfers the PEG to its own pool, so freeing the pool releases
 * the PEG together with all of its state.
 */
void
peg_free(struct PEG *peg)
{
	if (peg != NULL) {
		mempool_free(peg->pool);
	}
}
