/*
 * TR8 Assembler
 * Copyright (C) 2023 by Juan J. Martinez <jjm@usebox.net>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <strings.h>
#include <ctype.h>

#include "stb_image.h"

/* assembler capacity limits */
#define MAX_LINE 1024       /* size of the buffer each source line is read into */
#define MAX_ID 64           /* identifier/string buffers hold MAX_ID + 1 chars */
#define MAX_LABELS 512      /* entries in As.labels */
#define MAX_REFS 4096       /* entries in As.refs */
#define MAX_DEFS 512        /* entries in As.defs */

/* values ORed into the p2 (FHH/FHL) and p3 (FL) fields of emit() */
#define FHH 0x02
#define FHL 0x01
#define FL  0x20

/* a position in the assembly source, used in diagnostics */
typedef struct Location
{
    const char *filename;   /* name of the file being assembled */
    uint16_t line;          /* line number within that file */
} Location;

/* a label definition */
typedef struct Label
{
    uint32_t id;            /* str_hash() of the label name */
    uint16_t addr;          /* address the label was defined at */
    Location loc;           /* where it was defined */
} Label;

/* byte selector prefix on an operand: none, '<' (low byte) or '>' (high byte) */
typedef enum
{
    ModNone = 0,
    ModLow = 1,
    ModHigh = 2
} Mod;

/* a use of a label whose address is written into out[] by resolve() */
typedef struct Reference
{
    uint32_t id;                /* str_hash() of label */
    char label[MAX_ID + 1];     /* label name without any '<' / '>' prefix */
    Mod mod;                    /* which part of the address to store */
    uint16_t mask;              /* 0xff for a one-byte slot, 0xffff for a two-byte slot */
    uint16_t addr;              /* offset in out[] of the slot to patch */
    Location loc;               /* where the reference appeared */
} Reference;

/* an .equ definition */
typedef struct Define
{
    uint32_t id;                /* str_hash() of the defined id */
    char value[MAX_ID + 1];     /* value text, expanded through earlier defines */
    Location loc;               /* where it was defined */
} Define;

/* the whole assembler state */
typedef struct As
{
    uint8_t out[UINT16_MAX + 1];    /* output image, indexed by address */
    size_t size;                    /* high-water mark of addr */

    uint16_t addr;                  /* current output address */
    Location loc;                   /* current source position */

    Label labels[MAX_LABELS];
    uint16_t lcnt;                  /* used entries in labels */

    Reference refs[MAX_REFS];
    uint16_t rcnt;                  /* used entries in refs */

    Define defs[MAX_DEFS];
    uint16_t dcnt;                  /* used entries in defs */
} As;

/* one entry of the mnemonic/directive dispatch table */
typedef struct InstParse
{
    char id[9];                             /* name, e.g. ".include" (8 chars + NUL) */
    uint8_t (*parse)(As *as, char **c);     /* handler; returns 0 on error */
} InstParse;

/*
 * The 16 colours a PNG pixel may use with .incpng; a pixel's output byte is
 * the index of the matching entry (see parse_incpng).
 */
static uint8_t palette[16][3] =
{
    [0]  = { 0x00, 0x00, 0x00 },
    [1]  = { 0x00, 0x00, 0xaa },
    [2]  = { 0x00, 0xaa, 0x00 },
    [3]  = { 0x00, 0xaa, 0xaa },
    [4]  = { 0xaa, 0x00, 0x00 },
    [5]  = { 0xaa, 0x00, 0xaa },
    [6]  = { 0xaa, 0x55, 0x00 },
    [7]  = { 0xaa, 0xaa, 0xaa },
    [8]  = { 0x55, 0x55, 0x55 },
    [9]  = { 0x55, 0x55, 0xff },
    [10] = { 0x55, 0xff, 0x55 },
    [11] = { 0x55, 0xff, 0xff },
    [12] = { 0xff, 0x55, 0x55 },
    [13] = { 0xff, 0x55, 0xff },
    [14] = { 0xff, 0xff, 0x55 },
    [15] = { 0xff, 0xff, 0xff },
};

/*
 * Print "msg (file:line): reason" to stderr.
 * Always returns 0 so callers can write "return error_l(...)" to fail.
 */
static uint8_t error_l(const char *msg, Location *loc, const char *reason)
{
    const char *file = loc->filename;
    int line = loc->line;

    fprintf(stderr, "%s (%s:%d): %s\n", msg, file, line, reason);
    return 0;
}

/*
 * Print a diagnostic that has no source location to stderr:
 * "msg: reason", or "error: msg" when reason is NULL.
 * Always returns 0 so callers can write "return error(...)" to fail.
 */
static uint8_t error(const char *msg, const char *reason)
{
    if (reason)
        fprintf(stderr, "%s: %s\n", msg, reason);
    else
        fprintf(stderr, "error: %s\n", msg);

    return 0;
}

/*
 * 32-bit hash of a NUL-terminated string, processed in 4-byte chunks plus a
 * 1-3 byte tail, with a final avalanche step. Labels and defines are looked
 * up by this hash alone, so two names that collide are treated as the same.
 */
static uint32_t str_hash(char const *str)
{
    const uint32_t m = 0x5bd1e995u;
    uint32_t h = 0x31313137u;
    size_t length = strlen(str);

    while (length >= 4)
    {
        uint32_t k;

        /* memcpy instead of *(uint32_t *)str: str need not be 4-byte aligned,
         * and reading chars through a uint32_t lvalue breaks strict aliasing */
        memcpy(&k, str, sizeof k);
        k *= m;
        k ^= k >> 24;
        k *= m;
        h *= m;
        h ^= k;
        str += 4;
        length -= 4;
    }

    /* bytes are taken as unsigned so bytes >= 0x80 are never shifted as
     * negative ints; the result is unchanged for ASCII names */
    switch (length)
    {
        case 3:
            h ^= (uint32_t)(unsigned char)str[2] << 16;
            /* fallthrough */
        case 2:
            h ^= (uint32_t)(unsigned char)str[1] << 8;
            /* fallthrough */
        case 1:
            h ^= (uint32_t)(unsigned char)str[0];
            h *= m;
            break;
        default:
            break;
    }

    h ^= h >> 13;
    h *= m;
    h ^= h >> 15;

    return h;
}

/* Return a pointer to the first non-whitespace character of c (or its NUL). */
static char * skip_whitespace(char *c)
{
    /* <ctype.h> functions need a value representable as unsigned char;
     * a negative plain char (byte >= 0x80) is undefined behaviour */
    while (*c && isspace((unsigned char)*c))
        c++;
    return c;
}

/*
 * Read a double-quoted string at *c into word (MAX_ID + 1 bytes) and its
 * length into *wlen, leaving *c just past the closing quote.
 *
 * Returns 1 on success. Returns 0 after reporting an error when there is no
 * opening quote, the string is too long, or the closing quote is missing.
 */
static uint8_t next_string(As *as, char **c, char *word, uint8_t *wlen)
{
    *c = skip_whitespace(*c);

    *wlen = 0;
    if (**c != '"')
        return error_l("Syntax error", &as->loc, "expected string");
    (*c)++;

    while (**c && **c != '"')
    {
        word[(*wlen)++] = **c;
        (*c)++;
        if (*wlen == MAX_ID)
        {
            word[*wlen - 1] = 0;
            *wlen = 0;
            return error_l("String is too long", &as->loc, word);
        }
    }
    word[*wlen] = 0;

    /* without a closing quote *c points at the terminator; stepping over it
     * would make the caller read past the end of the line buffer */
    if (**c != '"')
    {
        *wlen = 0;
        return error_l("Syntax error", &as->loc, "expected closing \"");
    }
    (*c)++;

    return 1;
}

/* 1 if c is a non-alphanumeric character allowed inside a word ($ _ . # < >). */
static uint8_t isspecial(char c)
{
    switch (c)
    {
        case '$':
        case '_':
        case '.':
        case '#':
        case '<':
        case '>':
            return 1;
        default:
            return 0;
    }
}

/*
 * Expand word through the .equ table.
 *
 * If word is a define, follow it (and any define its value names in turn) and
 * return the final value text; otherwise return found (callers pass NULL to
 * test for existence, or word itself to get "the value or the word").
 * new_define() rejects a definition whose value expands back to its own id.
 * Lookup compares str_hash() values only.
 */
static char * find_define(As *as, char *word, char *found)
{
    /* dcnt can go past 255; a uint8_t index would wrap and never terminate */
    uint16_t i;
    uint32_t id = str_hash(word);

    for (i = 0; i < as->dcnt; i++)
        if (id == as->defs[i].id)
            return find_define(as, as->defs[i].value, as->defs[i].value);

    return found;
}

/*
 * Record ".equ id value".
 *
 * The stored value is already expanded through existing defines. Redefining
 * an id, or a value that expands to the id itself, is an error. Reaching
 * MAX_DEFS entries is also reported as an error, even though this entry was
 * stored. Returns 1 on success, 0 after reporting an error.
 */
static uint8_t new_define(As *as, char *id, char *value)
{
    Define *def;
    char *expanded;

    if (find_define(as, id, NULL) != NULL)
        return error_l("Id redefined", &as->loc, id);

    expanded = find_define(as, value, value);
    if (strcmp(id, expanded) == 0)
        return error_l("Recursive definition", &as->loc, id);

    def = &as->defs[as->dcnt];
    def->id = str_hash(id);
    strcpy(def->value, expanded);
    def->loc.filename = as->loc.filename;
    def->loc.line = as->loc.line;
    as->dcnt++;

    if (as->dcnt == MAX_DEFS)
        return error("Too many definitions", NULL);

    return 1;
}

/*
 * Read the next word (alphanumerics plus the isspecial() characters) at *c
 * into word (MAX_ID + 1 bytes) and its length into *wlen, leaving *c on the
 * first character after it. A zero *wlen means no word was found.
 *
 * Returns 1 on success, or 0 after reporting an error if the word is too long.
 */
static uint8_t next_word(As *as, char **c, char *word, uint8_t *wlen)
{
    *c = skip_whitespace(*c);

    *wlen = 0;
    /* isalnum() needs an unsigned char value; plain char may be negative */
    while (**c && (isalnum((unsigned char)**c) || isspecial(**c)))
    {
        word[(*wlen)++] = **c;
        (*c)++;
        if (*wlen == MAX_ID)
        {
            word[*wlen - 1] = 0;
            *wlen = 0;
            return error_l("Invalid input (too long)", &as->loc, word);
        }
    }
    word[*wlen] = 0;

    return 1;
}

/*
 * Parse a numeric immediate: decimal, "0x" hexadecimal or "0b" binary,
 * optionally prefixed with '<' (low byte) or '>' (high byte). If the word
 * after the prefix is an .equ id, its expanded value is parsed instead.
 *
 * Returns 1 and stores the value in *n on success. Returns 0 with *n = 0 and
 * no diagnostic when the word is not a complete number, has no digits, or
 * does not fit in 16 bits, so callers can fall back to treating it as a label.
 */
static uint8_t next_imm(As *as, char *word, uint16_t *n)
{
    Mod mod = ModNone;
    char *def;
    uint32_t value = 0;
    unsigned base = 10;
    unsigned digits = 0;

    *n = 0;

    if (*word == '<')
    {
        mod = ModLow;
        word++;
    }
    else if (*word == '>')
    {
        mod = ModHigh;
        word++;
    }

    def = find_define(as, word, NULL);
    if (def)
        word = def;

    if (word[0] == '0' && word[1] == 'x')
    {
        base = 16;
        word += 2;
    }
    else if (word[0] == '0' && word[1] == 'b')
    {
        base = 2;
        word += 2;
    }

    for (; *word; word++, digits++)
    {
        unsigned digit;

        if (*word >= '0' && *word <= '9')
            digit = (unsigned)(*word - '0');
        else if (*word >= 'a' && *word <= 'f')
            digit = 10u + (unsigned)(*word - 'a');
        else if (*word >= 'A' && *word <= 'F')
            digit = 10u + (unsigned)(*word - 'A');
        else
            return 0;

        if (digit >= base)
            return 0;

        value = value * base + digit;
        /* wrapping in 16 bits would let e.g. 65536 pass as 0 */
        if (value > UINT16_MAX)
            return 0;
    }

    /* "", "<", "0x" and "0b" have no digits and are not numbers */
    if (digits == 0)
        return 0;

    if (mod == ModLow)
        value &= 0xff;
    else if (mod == ModHigh)
        value >>= 8;

    *n = (uint16_t)value;
    return 1;
}

/*
 * Map a register name (a, b, x or y, any case) to its number 0-3.
 * Returns 0xff if word is not exactly one of those names.
 */
static uint8_t parse_register(char *word)
{
    /* every register name is exactly one character long */
    if (word[0] == 0 || word[1] != 0)
        return 0xff;

    switch (word[0])
    {
        case 'a': case 'A':
            return 0;
        case 'b': case 'B':
            return 1;
        case 'x': case 'X':
            return 2;
        case 'y': case 'Y':
            return 3;
        default:
            return 0xff;
    }
}

/*
 * Return the label whose str_hash() matches word's, or NULL if none does.
 * Names are compared by hash only.
 */
static Label * find_label(As *as, char *word)
{
    /* lcnt can go past 255; a uint8_t index would wrap and never terminate */
    uint16_t i;
    uint32_t id = str_hash(word);

    for (i = 0; i < as->lcnt; i++)
        if (id == as->labels[i].id)
            return &as->labels[i];

    return NULL;
}

/*
 * Define label word at the current address and source location.
 *
 * Redefinition is an error. Reaching MAX_LABELS entries is also reported as
 * an error, even though this label was stored. Returns 1 on success, 0 after
 * reporting an error.
 */
static uint8_t new_label(As *as, char *word)
{
    Label *label;

    if (find_label(as, word) != NULL)
        return error_l("Label redefined", &as->loc, word);

    label = &as->labels[as->lcnt++];
    label->id = str_hash(word);
    label->loc.filename = as->loc.filename;
    label->loc.line = as->loc.line;
    label->addr = as->addr;

    if (as->lcnt == MAX_LABELS)
        return error("Too many labels", NULL);

    return 1;
}

/*
 * Record that the label named by word (optionally prefixed with '<' or '>')
 * must be written at out[addr] once known; mask is 0xff for a one-byte slot
 * or 0xffff for a two-byte slot.
 *
 * A one-byte slot needs a '<' or '>' prefix. Reaching MAX_REFS entries is
 * reported as an error, even though this reference was stored.
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t new_ref(As *as, char *word, uint16_t mask, uint16_t addr)
{
    Reference *ref;
    Mod mod;
    char *name = word;

    switch (*name)
    {
        case '<':
            mod = ModLow;
            name++;
            break;
        case '>':
            mod = ModHigh;
            name++;
            break;
        default:
            mod = ModNone;
            break;
    }

    if (*name == 0)
        return error_l("Syntax error", &as->loc, "expected label");

    /* a full 16-bit address cannot go into one byte without selecting a half */
    if (mod == ModNone && mask == 0xff)
        return error_l("Overflow in immediate", &as->loc, word);

    ref = &as->refs[as->rcnt++];
    ref->id = str_hash(name);
    strcpy(ref->label, name);
    ref->mod = mod;
    ref->mask = mask;
    ref->loc = as->loc;
    ref->addr = addr;

    if (as->rcnt == MAX_REFS)
        return error("Too many references", NULL);

    return 1;
}

/*
 * Patch every recorded reference with its label's address (or the selected
 * byte of it), low byte first for two-byte slots.
 *
 * Returns 1 on success, or 0 after reporting the first unresolved label or a
 * two-byte slot that would end past the output buffer.
 */
static uint8_t resolve(As *as)
{
    uint16_t i;
    uint16_t v;
    Label *label;
    Reference *ref;

    for (i = 0; i < as->rcnt; i++)
    {
        ref = &as->refs[i];

        label = find_label(as, ref->label);
        if (!label)
            return error_l("Unresolved reference", &ref->loc, ref->label);

        /* a two-byte slot starting at the last address would write out[0x10000] */
        if (ref->mask == 0xffff && ref->addr == UINT16_MAX)
            return error_l("Memory overflow", &ref->loc, ref->label);

        v = label->addr;
        switch (ref->mod)
        {
            case ModLow:
                v &= 0xff;
                break;
            case ModHigh:
                v >>= 8;
                break;
            default:
                break;
        }

        as->out[ref->addr] = v & 0xff;
        if (ref->mask == 0xffff)
            as->out[ref->addr + 1] = v >> 8;
    }

    return 1;
}

/*
 * Emit one two-byte instruction at the current address: first p3, then
 * (instr << 4) | (p1 << 2) | p2.
 *
 * instr: instruction opcode
 * p1: r1
 * p2: either FHx or r2
 * p3: either FL, r3 or immediate
 *
 * Returns 1, or 0 after reporting an error if the bytes would not fit.
 */
static uint8_t emit(As *as, uint8_t instr, uint8_t p1, uint8_t p2, uint8_t p3)
{
    uint8_t packed = (uint8_t)((instr << 4) | (p1 << 2) | p2);

    if (as->addr + 2 > UINT16_MAX)
        return error_l("Memory overflow", &as->loc, "output is more than 65535 bytes");

    as->out[as->addr] = p3;
    as->out[as->addr + 1] = packed;
    as->addr += 2;

    if (as->size < as->addr)
        as->size = as->addr;

    return 1;
}

/*
 * Emit a 16-bit value at the current address, low byte first.
 * Returns 1, or 0 after reporting an error if the bytes would not fit.
 */
static uint8_t emit_imm(As *as, uint16_t imm)
{
    uint8_t lo = imm & 0xff;
    uint8_t hi = imm >> 8;

    if (as->addr + 2 > UINT16_MAX)
        return error_l("Memory overflow", &as->loc, "output is more than 65535 bytes");

    as->out[as->addr] = lo;
    as->out[as->addr + 1] = hi;
    as->addr += 2;

    if (as->size < as->addr)
        as->size = as->addr;

    return 1;
}

/*
 * .org imm
 *
 * Move the output address to imm; the size high-water mark follows it.
 */
static uint8_t parse_org(As *as, char **c)
{
    char tok[MAX_ID + 1];
    uint8_t len;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected immediate");
    if (!next_imm(as, tok, &as->addr))
        return error_l("Syntax error", &as->loc, tok);

    if (as->size < as->addr)
        as->size = as->addr;

    return 1;
}

static uint8_t parse(As *as, char *c);

/*
 * .include "filename"
 *
 * Assemble every line of the named file in place, with diagnostics reported
 * against that file's name and line numbers, then restore the including
 * file's location. A file that cannot be opened is reported and the function
 * still returns 1; a parse error inside the file makes it return 0.
 */
static uint8_t parse_include(As *as, char **c)
{
    /* XXX: may be support longer filenames */
    char word[MAX_ID + 1];
    uint8_t wlen;
    char line[MAX_LINE];
    Location old;
    FILE *fd;
    char *filename;
    /* stays 1 for an empty file, where the loop body never runs */
    uint8_t rc = 1;

    if (!next_string(as, c, word, &wlen))
        return 0;

    if (wlen == 0)
        return error_l("Syntax error", &as->loc, "expected string");

    fd = fopen(word, "rb");
    if (!fd)
    {
        /* NOTE(review): reported but not fatal; confirm this is intended */
        error_l("Failed to open include", &as->loc, word);
        return 1;
    }

    /* labels, references and defines created while parsing the file copy
     * as->loc, so the name must outlive this stack frame; the copy is
     * intentionally never freed */
    filename = malloc((size_t)wlen + 1);
    if (!filename)
    {
        fclose(fd);
        return error_l("Out of memory", &as->loc, word);
    }
    memcpy(filename, word, (size_t)wlen + 1);

    old = as->loc;

    as->loc.filename = filename;
    as->loc.line = 0;

    while (fgets(line, sizeof line, fd))
    {
        as->loc.line++;
        if ((rc = parse(as, line)) == 0)
            break;
    }

    fclose(fd);

    as->loc = old;

    if (!rc)
        return error_l("Error in include", &as->loc, word);

    return 1;
}

/*
 * .incpng "filename"
 *
 * Load a PNG and emit one byte per pixel: the index of the palette entry with
 * exactly the pixel's RGB value, or 128 ("transparent") when none matches.
 * Returns 0 after reporting an error if the file is not a loadable .png or the
 * pixels would not fit in memory.
 */
static uint8_t parse_incpng(As *as, char **c)
{
    /* XXX: may be support longer filenames */
    char word[MAX_ID + 1];
    uint8_t wlen;
    int x, y, n;
    uint8_t rc = 1;
    unsigned char *data;
    size_t pixels, room, i;
    uint16_t addr;
    uint8_t index;

    if (!next_string(as, c, word, &wlen))
        return 0;

    if (wlen == 0)
        return error_l("Syntax error", &as->loc, "expected string");

    /* for names shorter than ".png", word + wlen - 4 would point before word */
    if (wlen < 4 || strcasecmp(word + wlen - 4, ".png"))
        return error_l("Invalid input", &as->loc, "expected PNG file");

    /* ask for 3 channels, so data holds x * y RGB triplets */
    data = stbi_load(word, &x, &y, &n, 3);
    if (!data)
        return error_l("Invalid input", &as->loc, "expected PNG file");

    /* x * y in int could overflow for large images; compare by division */
    room = (size_t)(UINT16_MAX - as->addr);
    if (x < 0 || y < 0 || (y != 0 && (size_t)x > room / (size_t)y))
    {
        rc = error_l("Overflow in incpng", &as->loc, word);
        goto exit_parseinc;
    }
    pixels = (size_t)x * (size_t)y;

    addr = as->addr;
    for (i = 0; i < pixels * 3; i += 3)
    {
        for (index = 0; index < 16; index++)
            if (data[i] == palette[index][0]
                    && data[i + 1] == palette[index][1]
                    && data[i + 2] == palette[index][2])
                break;

        as->out[addr++] = index < 16 ? index : 128;
    }

    as->addr = addr;
    if (as->addr > as->size)
        as->size = as->addr;

exit_parseinc:
    stbi_image_free(data);

    return rc;
}

/*
 * .incbin "filename"
 *
 * Copy the raw bytes of the named file to the current address.
 */
static uint8_t parse_incbin(As *as, char **c)
{
    /* XXX: may be support longer filenames */
    char word[MAX_ID + 1];
    uint8_t wlen;
    FILE *fd;
    long end = -1;
    size_t size;

    if (!next_string(as, c, word, &wlen))
        return 0;

    if (wlen == 0)
        return error_l("Syntax error", &as->loc, "expected string");

    fd = fopen(word, "rb");
    if (!fd)
    {
        /* NOTE(review): this and the I/O/overflow errors below are reported
         * but the function still returns 1; confirm that is intended */
        error_l("Failed to open incbin", &as->loc, word);
        return 1;
    }

    /* ftell() returns -1 on failure; as a size_t that wraps the bounds check
     * below and lets fread() write past out[] */
    if (fseek(fd, 0, SEEK_END) != 0 || (end = ftell(fd)) < 0)
    {
        fclose(fd);
        error_l("Read error in incbin", &as->loc, word);
        return 1;
    }
    size = (size_t)end;

    /* compare against the room left so size + addr cannot overflow */
    if (size > (size_t)(UINT16_MAX - as->addr))
    {
        fclose(fd);
        error_l("Overflow in incbin", &as->loc, word);
        return 1;
    }

    if (fseek(fd, 0, SEEK_SET) != 0
            || fread(as->out + as->addr, 1, size, fd) != size)
    {
        fclose(fd);
        error_l("Read error in incbin", &as->loc, word);
        return 1;
    }
    fclose(fd);

    as->addr += size;
    if (as->addr > as->size)
        as->size = as->addr;

    return 1;
}

/*
 * .equ id value
 *
 * Read both words and register the definition with new_define().
 */
static uint8_t parse_equ(As *as, char **c)
{
    char name[MAX_ID + 1];
    char text[MAX_ID + 1];
    uint8_t len;

    if (!next_word(as, c, name, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected id");

    if (!next_word(as, c, text, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected value");

    return new_define(as, name, text);
}

/*
 * .db imm [, imm]
 *
 * Emit one byte per operand. An operand that is not a number is recorded as
 * a label reference (it needs a '<' or '>' prefix) and patched by resolve().
 */
static uint8_t parse_db(As *as, char **c)
{
    char word[MAX_ID + 1];
    uint8_t wlen;
    uint16_t imm = 0xff;
    char *next;

    while (1)
    {
        if (!next_word(as, c, word, &wlen))
            return 0;

        if (wlen == 0)
            return error_l("Syntax error", &as->loc, "expected immediate");

        /* same limit as emit(): keeps as->addr from wrapping around to 0 */
        if (as->addr + 1 > UINT16_MAX)
            return error_l("Memory overflow", &as->loc, "output is more than 65535 bytes");

        if (next_imm(as, word, &imm))
        {
            if (imm > 0xff)
                return error_l("Overflow in immediate", &as->loc, word);
        }
        else if (!new_ref(as, word, 0xff, as->addr))
            return 0;

        as->out[as->addr++] = imm & 0xff;
        if (as->addr > as->size)
            as->size = as->addr;

        /* allow whitespace before ','; *c only moves when a separator is found */
        next = skip_whitespace(*c);
        if (*next != ',')
            break;
        *c = next + 1;
    }

    return 1;
}

/*
 * .dw imm [, imm]
 *
 * Emit one 16-bit value per operand, low byte first. An operand that is not a
 * number is recorded as a label reference and patched by resolve().
 */
static uint8_t parse_dw(As *as, char **c)
{
    char word[MAX_ID + 1];
    uint8_t wlen;
    uint16_t imm = 0;
    char *next;

    while (1)
    {
        if (!next_word(as, c, word, &wlen))
            return 0;

        if (wlen == 0)
            return error_l("Syntax error", &as->loc, "expected immediate");

        /* same limit as emit(): both bytes must fit without as->addr wrapping,
         * which also keeps the slot resolve() patches inside out[] */
        if (as->addr + 2 > UINT16_MAX)
            return error_l("Memory overflow", &as->loc, "output is more than 65535 bytes");

        if (!next_imm(as, word, &imm)
                && !new_ref(as, word, 0xffff, as->addr))
            return 0;

        as->out[as->addr++] = imm & 0xff;
        as->out[as->addr++] = imm >> 8;
        if (as->addr > as->size)
            as->size = as->addr;

        /* allow whitespace before ','; *c only moves when a separator is found */
        next = skip_whitespace(*c);
        if (*next != ',')
            break;
        *c = next + 1;
    }

    return 1;
}

/*
 * Instructions without operands. Each emits a fixed encoding and ignores the
 * rest of the line pointer c.
 */

/* NOP */
static uint8_t parse_nop(As *as, char **c)
{
    (void)c;
    return emit(as, 0, 0, 0, 0);
}

/* SIF */
static uint8_t parse_sif(As *as, char **c)
{
    (void)c;
    return emit(as, 11, 0, FHL, 0);
}

/* CIF */
static uint8_t parse_cif(As *as, char **c)
{
    (void)c;
    return emit(as, 11, 0, 0, 0);
}

/* CCF */
static uint8_t parse_ccf(As *as, char **c)
{
    (void)c;
    return emit(as, 13, 0, FHH | FHL, 0);
}

/* SCF */
static uint8_t parse_scf(As *as, char **c)
{
    (void)c;
    return emit(as, 13, 0, FHH, 0);
}

/* SOF */
static uint8_t parse_sof(As *as, char **c)
{
    (void)c;
    return emit(as, 13, 0, FHL, 0);
}

/* COF */
static uint8_t parse_cof(As *as, char **c)
{
    (void)c;
    return emit(as, 13, 0, 0, 0);
}

/* HALT */
static uint8_t parse_halt(As *as, char **c)
{
    (void)c;
    return emit(as, 0, 0, FHH | FHL, 0);
}

/* IRET */
static uint8_t parse_iret(As *as, char **c)
{
    (void)c;
    return emit(as, 0, 0, FHL, 0);
}

/* RET */
static uint8_t parse_ret(As *as, char **c)
{
    (void)c;
    return emit(as, 0, 0, 0, FL);
}

/*
 * Parse "r1, r2", "r1, imm" or "r1, label".
 *
 * *r1 always gets a register. *r2 gets the second register, or 0xff when the
 * second operand is an 8-bit immediate (stored in *imm) or a label (recorded
 * with new_ref() for the current address; it needs a '<' or '>' prefix).
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_r1_r2_or_imm(As *as, char **c, uint8_t *r1, uint8_t *r2, uint16_t *imm)
{
    char tok[MAX_ID + 1];
    uint8_t len;

    /* first operand: a register */
    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((*r1 = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    *c = skip_whitespace(*c);
    if (**c != ',')
        return error_l("Syntax error", &as->loc, "expected ,");
    (*c)++;

    /* second operand: register, immediate or label */
    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register or immediate");

    *r2 = parse_register(tok);
    if (*r2 == 0xff)
    {
        if (!next_imm(as, tok, imm))
            return new_ref(as, tok, 0xff, as->addr);
        if (*imm > 0xff)
            return error_l("Overflow in immediate", &as->loc, tok);
    }

    return 1;
}

/*
 * Parse "r1, imm" where imm must start with a digit (so labels and '<' / '>'
 * prefixed operands are rejected). Range checking is left to the caller.
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_r1_imm(As *as, char **c, uint8_t *r1, uint16_t *imm)
{
    char tok[MAX_ID + 1];
    uint8_t len;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((*r1 = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    *c = skip_whitespace(*c);
    if (**c != ',')
        return error_l("Syntax error", &as->loc, "expected ,");
    (*c)++;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0 || !isdigit(*tok))
        return error_l("Syntax error", &as->loc, "expected immediate");
    if (!next_imm(as, tok, imm))
        return error_l("Syntax error", &as->loc, tok);

    return 1;
}

/*
 * Shared body of the two-operand ALU instructions: parse "r1, r2" or
 * "r1, imm/label" and emit opcode op with reg_p2 (register form) or imm_p2
 * (immediate form) in the p2 field.
 */
static uint8_t parse_alu_op(As *as, char **c, uint8_t op, uint8_t reg_p2, uint8_t imm_p2)
{
    uint16_t imm = 0xff;
    uint8_t r1, r2;

    if (!parse_r1_r2_or_imm(as, c, &r1, &r2, &imm))
        return 0;

    if (r2 == 0xff)
        return emit(as, op, r1, imm_p2, imm);

    return emit(as, op, r1, reg_p2, r2 << 6);
}

/* AND r1, r2 | AND r1, imm */
static uint8_t parse_and(As *as, char **c)
{
    return parse_alu_op(as, c, 4, FHH | FHL, FHH);
}

/* OR r1, r2 | OR r1, imm */
static uint8_t parse_or(As *as, char **c)
{
    return parse_alu_op(as, c, 4, FHL, 0);
}

/* XOR r1, r2 | XOR r1, imm */
static uint8_t parse_xor(As *as, char **c)
{
    return parse_alu_op(as, c, 5, FHH | FHL, FHH);
}

/* CMP r1, r2 | CMP r1, imm */
static uint8_t parse_cmp(As *as, char **c)
{
    return parse_alu_op(as, c, 5, FHL, 0);
}

/* ADD r1, r2 | ADD r1, imm */
static uint8_t parse_add(As *as, char **c)
{
    return parse_alu_op(as, c, 6, FHH | FHL, FHH);
}

/* SUB r1, r2 | SUB r1, imm */
static uint8_t parse_sub(As *as, char **c)
{
    return parse_alu_op(as, c, 6, FHL, 0);
}

/*
 * Shared body of the "r1, bit" instructions: parse "r1, imm" with imm in
 * 0-7 and emit opcode op with p2 in the p2 field and imm in p3.
 */
static uint8_t parse_r1_bit_op(As *as, char **c, uint8_t op, uint8_t p2)
{
    uint16_t imm = 0xff;
    uint8_t r1;

    if (!parse_r1_imm(as, c, &r1, &imm))
        return 0;

    if (imm > 7)
        return error_l("Immediate out of range", &as->loc, "expected bit (0-7)");

    return emit(as, op, r1, p2, imm);
}

/* BIT r1, imm */
static uint8_t parse_bit(As *as, char **c)
{
    return parse_r1_bit_op(as, c, 7, FHH);
}

/* SHL r1, imm */
static uint8_t parse_shl(As *as, char **c)
{
    return parse_r1_bit_op(as, c, 7, FHL);
}

/* SHR r1, imm */
static uint8_t parse_shr(As *as, char **c)
{
    return parse_r1_bit_op(as, c, 7, 0);
}

/* ROR r1, imm */
static uint8_t parse_ror(As *as, char **c)
{
    return parse_r1_bit_op(as, c, 8, 0);
}

/* ROL r1, imm */
static uint8_t parse_rol(As *as, char **c)
{
    return parse_r1_bit_op(as, c, 8, FHH);
}

/*
 * PUSH r1 | PUSH F
 *
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_push(As *as, char **c)
{
    char word[MAX_ID + 1];
    uint8_t wlen;
    uint8_t r1;

    if (!next_word(as, c, word, &wlen))
        return 0;

    if (wlen == 0)
        return error_l("Syntax error", &as->loc, "expected register");

    r1 = parse_register(word);
    if (r1 != 0xff)
        return emit(as, 3, r1, FHL, 0);

    /* PUSH F: the whole word must be F, so names like "foo" are rejected */
    if ((word[0] == 'f' || word[0] == 'F') && word[1] == 0)
        return emit(as, 3, 0, FHH | FHL, 0);

    return error_l("Syntax error", &as->loc, word);
}

/*
 * PORT r1, r2
 *
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_port(As *as, char **c)
{
    char tok[MAX_ID + 1];
    uint8_t len;
    uint8_t first, second;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((first = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    *c = skip_whitespace(*c);
    if (**c != ',')
        return error_l("Syntax error", &as->loc, "expected ,");
    (*c)++;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((second = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    return emit(as, 0, first, FHH, second << 6);
}

/*
 * POP r1 | POP F
 *
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_pop(As *as, char **c)
{
    char word[MAX_ID + 1];
    uint8_t wlen;
    uint8_t r1;

    if (!next_word(as, c, word, &wlen))
        return 0;

    if (wlen == 0)
        return error_l("Syntax error", &as->loc, "expected register");

    r1 = parse_register(word);
    if (r1 != 0xff)
        return emit(as, 3, r1, 0, 0);

    /* POP F: the whole word must be F, so names like "foo" are rejected */
    if ((word[0] == 'f' || word[0] == 'F') && word[1] == 0)
        return emit(as, 3, 0, FHH, 0);

    return error_l("Syntax error", &as->loc, word);
}

/*
 * Shared body of the single-register instructions: parse one register and
 * emit opcode op with the given p2 and p3 fields.
 */
static uint8_t parse_one_reg(As *as, char **c, uint8_t op, uint8_t p2, uint8_t p3)
{
    char tok[MAX_ID + 1];
    uint8_t len;
    uint8_t reg;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((reg = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    return emit(as, op, reg, p2, p3);
}

/* XSP r1 */
static uint8_t parse_xsp(As *as, char **c)
{
    return parse_one_reg(as, c, 3, 0, FL);
}

/* INC r1 */
static uint8_t parse_inc(As *as, char **c)
{
    return parse_one_reg(as, c, 11, FHH | FHL, 0);
}

/* DEC r1 */
static uint8_t parse_dec(As *as, char **c)
{
    return parse_one_reg(as, c, 11, FHH, 0);
}

/*
 * Conditional branches (opcode 10). p3 picks the condition (0 Z, 1 C, 2 O,
 * 3 S, 4 I by mnemonic); each BN* form differs from its B* form only by FHH
 * in p2. The rest of the line pointer c is not used.
 */

/* BZ */
static uint8_t parse_bz(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, 0, 0);
}

/* BNZ */
static uint8_t parse_bnz(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, FHH, 0);
}

/* BC */
static uint8_t parse_bc(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, 0, 1);
}

/* BNC */
static uint8_t parse_bnc(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, FHH, 1);
}

/* BO */
static uint8_t parse_bo(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, 0, 2);
}

/* BNO */
static uint8_t parse_bno(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, FHH, 2);
}

/* BS */
static uint8_t parse_bs(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, 0, 3);
}

/* BNS */
static uint8_t parse_bns(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, FHH, 3);
}

/* BI */
static uint8_t parse_bi(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, 0, 4);
}

/* BNI */
static uint8_t parse_bni(As *as, char **c)
{
    (void)c;
    return emit(as, 10, 0, FHH, 4);
}

/*
 * Parse the inside of an indirect operand (the caller consumed '[') up to and
 * including the closing ']'.
 *
 * "r1:r2]" stores both registers. When imm is not NULL, "sp + imm]" is also
 * accepted: *r1 is then 0xff, *r2 is left untouched and the 8-bit offset goes
 * to *imm. Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_indirect(As *as, char **c, uint8_t *r1, uint8_t *r2, uint16_t *imm)
{
    char tok[MAX_ID + 1];
    uint8_t len;

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");

    *r1 = parse_register(tok);
    if (*r1 != 0xff)
    {
        /* [r1:r2] */
        *c = skip_whitespace(*c);
        if (**c != ':')
            return error_l("Syntax error", &as->loc, "expected :");
        (*c)++;

        if (!next_word(as, c, tok, &len))
            return 0;
        if (len == 0)
            return error_l("Syntax error", &as->loc, "expected register");

        *r2 = parse_register(tok);
        if (*r2 == 0xff)
            return error_l("Syntax error", &as->loc, tok);
    }
    else
    {
        /* [SP + imm], only where the caller asked for it */
        if (imm == NULL || strcasecmp(tok, "sp") != 0)
            return error_l("Syntax error", &as->loc, tok);

        *c = skip_whitespace(*c);
        if (**c != '+')
            return error_l("Syntax error", &as->loc, "expected +");
        (*c)++;

        if (!next_word(as, c, tok, &len))
            return 0;
        if (len == 0)
            return error_l("Syntax error", &as->loc, "expected immediate");
        if (!next_imm(as, tok, imm))
            return error_l("Syntax error", &as->loc, tok);
        if (*imm > 0xff)
            return error_l("Overflow in immediate", &as->loc, tok);
    }

    *c = skip_whitespace(*c);
    if (**c != ']')
        return error_l("Syntax error", &as->loc, "expected ]");
    (*c)++;

    return 1;
}

/*
 * Shared body of JMP and CALL (opcode 9); fl is FL for JMP and 0 for CALL.
 *
 * "[r1:r2]" emits a single instruction. An immediate or label emits the
 * instruction followed by a 16-bit target; a label target is recorded with
 * new_ref() and patched by resolve().
 */
static uint8_t parse_jump(As *as, char **c, uint8_t fl)
{
    char tok[MAX_ID + 1];
    uint8_t len;
    uint16_t target = 0xffff;
    uint8_t hi, lo;

    *c = skip_whitespace(*c);

    if (**c == '[')
    {
        (*c)++;
        if (!parse_indirect(as, c, &hi, &lo, NULL))
            return 0;
        return emit(as, 9, hi, 0, (lo << 6) | fl);
    }

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected label or immediate");

    /* the target word follows the 2-byte instruction, hence addr + 2 */
    if (!next_imm(as, tok, &target) && !new_ref(as, tok, 0xffff, as->addr + 2))
        return 0;

    return emit(as, 9, 0, FHH, fl) && emit_imm(as, target);
}

/* JMP [r1:r2] | JMP imm | JMP label */
static uint8_t parse_jmp(As *as, char **c)
{
    return parse_jump(as, c, FL);
}

/* CALL [r1:r2] | CALL imm | CALL label */
static uint8_t parse_call(As *as, char **c)
{
    return parse_jump(as, c, 0);
}

/*
 * LD in all its forms:
 *   LD [r1:r2], r3     LD [sp + imm], r3
 *   LD r1, [r2:r3]     LD r1, [sp + imm]
 *   LD r1, r2          LD r1, imm          LD r1, label (needs '<' or '>')
 *
 * Returns 1 on success, 0 after reporting an error.
 */
static uint8_t parse_ld(As *as, char **c)
{
    char tok[MAX_ID + 1];
    uint8_t len;
    uint16_t imm = 0xff;
    uint8_t dst, ra, rb, src;

    *c = skip_whitespace(*c);

    if (**c == '[')
    {
        /* indirect first operand; ra == 0xff means the [sp + imm] form */
        (*c)++;
        if (!parse_indirect(as, c, &ra, &rb, &imm))
            return 0;

        *c = skip_whitespace(*c);
        if (**c != ',')
            return error_l("Syntax error", &as->loc, "expected ,");
        (*c)++;

        if (!next_word(as, c, tok, &len))
            return 0;
        if (len == 0)
            return error_l("Syntax error", &as->loc, "expected register");
        if ((src = parse_register(tok)) == 0xff)
            return error_l("Syntax error", &as->loc, tok);

        if (ra == 0xff)
            return emit(as, 12, src, FHH, imm);
        return emit(as, 2, ra, rb, (src << 6) | FL);
    }

    /* register first operand */
    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register");
    if ((dst = parse_register(tok)) == 0xff)
        return error_l("Syntax error", &as->loc, tok);

    *c = skip_whitespace(*c);
    if (**c != ',')
        return error_l("Syntax error", &as->loc, "expected ,");
    (*c)++;

    *c = skip_whitespace(*c);
    if (**c == '[')
    {
        /* indirect second operand; ra == 0xff means the [sp + imm] form */
        (*c)++;
        if (!parse_indirect(as, c, &ra, &rb, &imm))
            return 0;

        if (ra == 0xff)
            return emit(as, 12, dst, 0, imm);
        return emit(as, 2, ra, rb, dst << 6);
    }

    if (!next_word(as, c, tok, &len))
        return 0;
    if (len == 0)
        return error_l("Syntax error", &as->loc, "expected register or immediate");

    src = parse_register(tok);
    if (src != 0xff)
        return emit(as, 1, dst, FHH, src << 6);

    /* immediate, or a label patched later by resolve() */
    if (next_imm(as, tok, &imm))
    {
        if (imm > 0xff)
            return error_l("Overflow in immediate", &as->loc, tok);
    }
    else if (!new_ref(as, tok, 0xff, as->addr))
        return 0;

    return emit(as, 1, dst, 0, imm);
}

static InstParse insts[] =
{
    { ".include", parse_include },
    { ".incpng", parse_incpng },
    { ".incbin", parse_incbin },
    { ".org", parse_org },
    { ".equ", parse_equ },
    { ".db", parse_db },
    { ".dw", parse_dw },
    { "halt", parse_halt },
    { "push", parse_push },
    { "port", parse_port },
    { "iret", parse_iret },
    { "call", parse_call },
    { "xsp", parse_xsp },
    { "and", parse_and },
    { "cmp", parse_cmp },
    { "add", parse_add },
    { "sub", parse_sub },
    { "bit", parse_bit },
    { "shl", parse_shl },
    { "shr", parse_shr },
    { "ror", parse_ror },
    { "rol", parse_rol },
    { "xor", parse_xor },
    { "ret", parse_ret },
    { "pop", parse_pop },
    { "nop", parse_nop },
    { "inc", parse_inc },
    { "dec", parse_dec },
    { "bnz", parse_bnz },
    { "bnc", parse_bnc },
    { "bno", parse_bno },
    { "bns", parse_bns },
    { "bni", parse_bni },
    { "jmp", parse_jmp },
    { "sif", parse_sif },
    { "cif", parse_cif },
    { "ccf", parse_ccf },
    { "scf", parse_scf },
    { "sof", parse_sof },
    { "cof", parse_cof },
    { "or", parse_or },
    { "bz", parse_bz },
    { "bc", parse_bc },
    { "bo", parse_bo },
    { "bs", parse_bs },
    { "bi", parse_bi },
    { "ld", parse_ld },
    { "", NULL },
};

/*
 * Parse one line of assembler source.
 *
 * The first word of the line decides what it is:
 *   - nothing, or a word starting with ';'  -> blank / comment line
 *   - a word immediately followed by ':'    -> label definition
 *   - otherwise                             -> directive or instruction,
 *                                              dispatched through insts[]
 * After a label or an instruction only a comment (or nothing) may follow.
 *
 * Returns 1 on success and 0 on error.
 */
static uint8_t parse(As *as, char *c)
{
    char word[MAX_ID + 1];
    uint8_t wlen;
    const InstParse *inst;

    if (!next_word(as, &c, word, &wlen))
        return 0;

    /* blank line, or a line holding only a comment */
    if (wlen == 0 || word[0] == ';')
        return 1;

    if (*c == ':')
    {
        /* "name:" defines a label; consume the ':' */
        c++;
        if (!new_label(as, word))
            return 0;
    }
    else
    {
        /* walk the table until a match or the NULL-parser sentinel */
        inst = insts;
        while (inst->parse && strcasecmp(inst->id, word))
            inst++;

        if (!inst->parse)
            return error_l("Parse error", &as->loc, word);

        if (!inst->parse(as, &c))
            return 0;
    }

    /* the rest of the line must be empty or a trailing comment */
    if (!next_word(as, &c, word, &wlen))
        return 0;

    if (wlen == 0 || word[0] == ';')
        return 1;

    return error_l("Parse error", &as->loc, word);
}

/*
 * Assemble the source read from `in` into `as`.
 *
 * `as` is zeroed first, so any previous state is discarded. `filename` is
 * only stored (the pointer, not a copy) in as->loc for error reporting.
 *
 * Every line is passed to parse(); the first failing line aborts.
 * A line that does not fit in the MAX_LINE buffer, or a read error on `in`,
 * is reported as an error rather than silently splitting or truncating
 * the input.
 *
 * Returns the result of resolve() once all lines are parsed, 0 on error.
 *
 * NOTE(review): `asm` is a keyword in GNU C dialects (-std=gnu*); this name
 * only compiles in a strict ISO mode such as -std=c99/-std=c11.
 */
static uint8_t asm(As *as, const char *filename, FILE *in)
{
    char line[MAX_LINE];
    size_t len;
    int next;

    memset(as, 0, sizeof(As));
    as->loc.filename = filename;

    /* fgets already reserves room for the terminating NUL */
    while (fgets(line, sizeof(line), in))
    {
        as->loc.line++;

        /*
         * A full buffer without a trailing newline means fgets stopped
         * early. Peek one char: a newline or EOF means the line fitted
         * exactly; anything else means it was too long, and the rest would
         * otherwise be parsed as a new line with a wrong line number.
         */
        len = strlen(line);
        if (len == sizeof(line) - 1 && line[len - 1] != '\n')
        {
            next = fgetc(in);
            if (next != '\n' && next != EOF)
                return error_l("Syntax error", &as->loc, "line too long");
        }

        if (!parse(as, line))
            return 0;
    }

    /* fgets returns NULL both at EOF and on error; don't resolve a
     * partially read program as if it were complete */
    if (ferror(in))
        return error_l("Read error", &as->loc, filename);

    return resolve(as);
}

#ifdef DO_MAIN

/*
 * Command line entry point: <prog> input.asm output.tr8
 *
 * Assembles input.asm and writes the raw output image to output.tr8.
 * Exit status: 0 on success, 1 on usage or file open/write errors,
 * 11 when assembly fails.
 */
int main(int argc, char *argv[])
{
    /*
     * As holds the whole 64K output image plus the label, reference and
     * define tables (hundreds of KB); keep it off the stack so small
     * default stacks (e.g. 1 MB on Windows) are not at risk.
     */
    static As as;
    FILE *src, *out;
    int rc = 0;

    if (argc < 3)
    {
        error("usage", "input.asm output.tr8");
        return 1;
    }

    /* "r" is the portable text-read mode; "rt" is not an ISO C mode */
    src = fopen(argv[1], "r");
    if (!src)
    {
        error("Failed to open input", argv[1]);
        return 1;
    }

    if (!asm(&as, argv[1], src))
    {
        fclose(src);
        return 11;
    }
    fclose(src);

    out = fopen(argv[2], "wb");
    if (!out)
    {
        error("Failed to open output", argv[2]);
        return 1;
    }

    /* element size 1 so an empty image (size 0) is not taken as a failure */
    if (fwrite(as.out, 1, as.size, out) != as.size)
        rc = 1;
    /* fclose flushes buffered data and may fail too */
    if (fclose(out) != 0)
        rc = 1;

    if (rc)
    {
        error("Failed to write output", argv[2]);
        return rc;
    }

    /* as.size is a size_t; cast to match %lu portably (it is at most 64K) */
    fprintf(stderr, "%s: %lu bytes, OK\n", argv[2], (unsigned long)as.size);
    return 0;
}

#endif /* DO_MAIN */