diff options
author | Juan J. Martinez <jjm@usebox.net> | 2022-07-18 07:45:58 +0100 |
---|---|---|
committer | Juan J. Martinez <jjm@usebox.net> | 2022-07-18 07:45:58 +0100 |
commit | 8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642 (patch) | |
tree | c53977d1284347bb1d5963ddb4dc7723c40c6e55 /interpreter | |
download | micro-lang-8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642.tar.gz micro-lang-8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642.zip |
First public release
Diffstat (limited to 'interpreter')
-rw-r--r-- | interpreter/callable.go | 227 | ||||
-rw-r--r-- | interpreter/interpreter.go | 801 | ||||
-rw-r--r-- | interpreter/interpreter_test.go | 754 |
3 files changed, 1782 insertions, 0 deletions
diff --git a/interpreter/callable.go b/interpreter/callable.go new file mode 100644 index 0000000..e9dc6aa --- /dev/null +++ b/interpreter/callable.go @@ -0,0 +1,227 @@ +package interpreter + +import ( + "fmt" + + "usebox.net/lang/ast" + "usebox.net/lang/errors" + "usebox.net/lang/tokens" +) + +type Callable interface { + Name() string + String() string + Call(interp *Interpreter, args []any, loc tokens.Location) (any, error) + Type() *ast.Type + Params() []ast.Type + Ret() *ast.Type +} + +var builtInFuncs = map[string]Callable{ + "len": builtInLen{}, + "println": builtInPrintln{}, + "panic": builtInPanic{}, +} + +func BuiltInTypes() map[string]ast.Type { + types := map[string]ast.Type{} + for k, v := range builtInFuncs { + types[k] = *v.Type() + } + return types +} + +type builtInLen struct{} + +func (n builtInLen) Name() string { + return "'len'" +} + +func (n builtInLen) Type() *ast.Type { + return ast.NewFuncType(tokens.Token{}, n.Params(), n.Ret()) +} + +func (n builtInLen) String() string { + return n.Type().String() +} + +func (n builtInLen) Params() []ast.Type { + // won't be arity or type checked + return []ast.Type{ast.TypeArray} +} + +func (n builtInLen) Ret() *ast.Type { + return &ast.TypeNumber +} + +func (n builtInLen) Call(interp *Interpreter, args []any, loc tokens.Location) (any, error) { + vals, ok := args[0].([]any) + if !ok { + // shouldn't happen + return nil, errors.NewError(errors.Unexpected, loc, "type mismatch in call to 'len'") + } + // return + panic(&PanicJump{typ: PanicJumpReturn, value: ast.Number(len(vals))}) +} + +type builtInPrintln struct{} + +func (n builtInPrintln) Name() string { + return "'println'" +} + +func (n builtInPrintln) Type() *ast.Type { + return ast.NewFuncType(tokens.Token{}, n.Params(), n.Ret()) +} + +func (n builtInPrintln) String() string { + return n.Type().String() +} + +func (n builtInPrintln) Params() []ast.Type { + // won't be arity or type checked + return nil +} + +func (n builtInPrintln) Ret() *ast.Type { + return &ast.TypeNumber +} + +func (n builtInPrintln) Call(interp *Interpreter, args []any, loc tokens.Location) (any, error) { + var count int + + for i := range args { + if args[i] == nil { + continue + } + written, err := fmt.Print(args[i]) + if err != nil { + return nil, err + } + count += written + } + fmt.Println() + count++ + + // return + panic(&PanicJump{typ: PanicJumpReturn, value: ast.Number(count)}) +} + +type builtInPanic struct{} + +func (n builtInPanic) Name() string { + return "'panic'" +} + +func (n builtInPanic) Type() *ast.Type { + return ast.NewFuncType(tokens.Token{}, n.Params(), n.Ret()) +} + +func (n builtInPanic) String() string { + return n.Type().String() +} + +func (n builtInPanic) Params() []ast.Type { + return []ast.Type{ast.TypeString} +} + +func (n builtInPanic) Ret() *ast.Type { + // no return (returns none) + return nil +} + +func (n builtInPanic) Call(interp *Interpreter, args []any, loc tokens.Location) (any, error) { + return nil, fmt.Errorf("[%s] panic: %s", loc, args[0]) +} + +type Function struct { + fun ast.Func + closure *Env +} + +func (f Function) Name() string { + return f.fun.Name.String() +} + +func (f Function) Type() *ast.Type { + return ast.NewFuncType(f.fun.Name, f.Params(), f.Ret()) +} + +func (f Function) String() string { + return f.Type().String() +} + +func (f Function) Params() []ast.Type { + params := make([]ast.Type, 0, 1) + for _, p := range f.fun.Params { + params = append(params, p.Type) + } + return params +} + +func (f Function) Ret() *ast.Type { + return f.fun.Ret +} + +func (f Function) Call(interp *Interpreter, args []any, loc tokens.Location) (result any, err error) { + pEnv := interp.env + interp.env = NewEnv(f.closure) + defer func() { + interp.env = pEnv + }() + + for n, v := range f.fun.Params { + interp.env = interp.env.Set(v.Name.Value, Var{Value: args[n], Loc: v.Name.Loc}) + } + + // tail call optimization + for { + // wrap the evaluation in a function + var tailCall *PanicJump + tailCall, result, err = func() (tailCall *PanicJump, result any, err error) { + // handle tail call + // will call this function again without setting up a new call frame + defer func() { + if r := recover(); r != nil { + if val, ok := r.(*PanicJump); ok && val.typ == PanicJumpTailCall { + tailCall = val + return + } + panic(r) + } + }() + result, err = interp.evaluate(f.fun.Body) + return nil, result, err + }() + if tailCall == nil { + break + } + + // XXX: can we optimize this? + // if the callee can be a variable expression, we probably can't + call := tailCall.value.(ast.Call) + callee, err := interp.evaluate(call.Callee) + if err != nil { + return nil, err + } + if fun, ok := callee.(Callable); !ok || fun.Name() != f.Name() { + // can't be optimized + val, err := interp.evaluate(call) + if err != nil { + return nil, err + } + panic(&PanicJump{typ: PanicJumpReturn, value: val, loc: tailCall.loc}) + } + + args, err := interp.evalArgs(f.Name(), call.Loc, f.Params(), call.Args) + if err != nil { + return nil, err + } + + for n, v := range f.fun.Params { + interp.env.Update(v.Name.Value, args[n]) + } + } + + return result, err +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go new file mode 100644 index 0000000..8d5b7a8 --- /dev/null +++ b/interpreter/interpreter.go @@ -0,0 +1,801 @@ +package interpreter + +import ( + "fmt" + "strings" + + "usebox.net/lang/ast" + "usebox.net/lang/errors" + "usebox.net/lang/tokens" +) + +type Var struct { + Value any + Loc tokens.Location +} + +type Env struct { + env map[string]Var + parent *Env +} + +func NewEnv(parent *Env) *Env { + return &Env{make(map[string]Var), parent} +} + +func (e *Env) Get(key string, local bool) (Var, bool) { + p := e + for { + if v, ok := p.env[key]; ok { + return v, true + } + + if !local && p.parent != nil { + p = p.parent + continue + } + + return Var{}, false + } +} + +func (e *Env) GetArray(key string, local bool) ([]any, bool) { + v, ok := e.Get(key, local) + if !ok { + return nil, false + } + arr, ok := v.Value.([]any) + return arr, ok +} + +// Set will create a new enviroment if this new value shadows an existing one. +func (e *Env) Set(key string, value Var) *Env { + p := e + if _, ok := e.Get(key, false); ok { + p = NewEnv(e) + } + p.env[key] = value + return p +} + +func (e *Env) Update(key string, value any) error { + p := e + for { + if v, ok := p.env[key]; ok { + v.Value = value + p.env[key] = v + return nil + } + + if p.parent != nil { + p = p.parent + continue + } + + // should not happen + return fmt.Errorf("variable not found '%s'", key) + } +} + +func (e *Env) string(pad string) string { + b := strings.Builder{} + + b.WriteString(fmt.Sprintf("%senv {\n", pad)) + for k, v := range e.env { + b.WriteString(fmt.Sprintf("%s%s = %v\n", pad+" ", k, v)) + } + if e.parent != nil { + b.WriteString(e.parent.string(pad + " ")) + } + b.WriteString(fmt.Sprintf("%s}", pad)) + + if pad != "" { + b.WriteRune('\n') + } + + return b.String() +} + +func (e Env) String() string { + return e.string("") +} + +func (e *Env) Copy() *Env { + ne := NewEnv(e.parent) + for k, v := range e.env { + val := Var{Value: v.Value, Loc: v.Loc} + // FIXME: structs + switch v := val.Value.(type) { + // update the closures to the new environment + case Function: + v.closure = ne + val.Value = v + // copy the array + case []any: + arr := make([]any, len(v)) + copy(arr, v) + val.Value = arr + } + ne.env[k] = val + } + + return ne +} + +type Struct struct { + name string + env *Env + inst bool +} + +func (s Struct) String() string { + if s.inst { + return fmt.Sprintf("<struct %s>", s.name) + } else { + return fmt.Sprintf("struct %s", s.name) + } +} + +func (s Struct) Instance() Struct { + return Struct{name: s.name, env: s.env.Copy(), inst: true} +} + +type PanicJumpType int + +const ( + PanicJumpReturn PanicJumpType = iota + PanicJumpTailCall + PanicJumpContinue + PanicJumpBreak +) + +type PanicJump struct { + typ PanicJumpType + value any + loc tokens.Location +} + +type Interpreter struct { + env *Env + // location of the last node evaluated + last tokens.Location + fError bool +} + +func NewInterpreter() Interpreter { + global := NewEnv(nil) + for k, v := range builtInFuncs { + // not shadowing any value + global.Set(k, Var{Value: v}) + } + return Interpreter{env: global} +} + +func (i *Interpreter) Interp(expr any) (val any, err error) { + val, err = i.evaluate(expr) + if err != nil { + return nil, err + } + return val, nil +} + +func (i *Interpreter) isTrue(expr any) bool { + switch v := expr.(type) { + case ast.Number: + return v != 0 + case string: + return v != "" + case bool: + return v + case Function, Callable: + return true + default: + return false + } +} + +func (i *Interpreter) unaryNumeric(op tokens.Token, right any) (ast.Number, error) { + if vr, ok := right.(ast.Number); ok { + return vr, nil + } + + // shouldn't happen + return 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operand") +} + +func (i *Interpreter) binaryNumeric(left any, op tokens.Token, right any) (ast.Number, ast.Number, error) { + vl, okl := left.(ast.Number) + vr, okr := right.(ast.Number) + if okl && okr { + return vl, vr, nil + } + + // shouldn't happen + return 0, 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operands") +} + +func (i *Interpreter) evaluate(expr any) (any, error) { + switch v := expr.(type) { + case ast.Var: + i.last = v.Name.Loc + return i.declare(v) + case ast.Variable: + i.last = v.Name.Loc + return i.variable(v) + case ast.Assign: + return i.assignment(v) + case ast.Binary: + i.last = v.Op.Loc + return i.binaryExpr(v) + case ast.Unary: + i.last = v.Op.Loc + return i.unaryExpr(v) + case ast.Literal: + i.last = v.Value.Loc + return i.literalExpr(v) + case ast.Group: + return i.groupExpr(v) + case ast.Block: + return i.block(v) + case ast.IfElse: + return i.ifElse(v) + case ast.For: + return i.forStmt(v) + case ast.ForIn: + return i.forInStmt(v) + case ast.Call: + i.last = v.Loc + return i.call(v) + case ast.Func: + i.last = v.Name.Loc + return i.fun(v) + case ast.Return: + return i.returnStmt(v) + case ast.Continue: + return i.continueStmt(v) + case ast.Break: + return i.breakStmt(v) + case ast.ExprList: + return i.exprList(v) + case ast.Struct: + return i.strct(v) + case ast.GetExpr: + return i.getExpr(v) + default: + // XXX + return nil, errors.Error{ + Code: errors.Unimplemented, + Message: fmt.Sprintf("evaluation unimplemented: %s", v), + } + } +} + +func (i *Interpreter) defVal(typ ast.Type) any { + switch typ.Value.Id { + case tokens.TNumber: + return ast.Number(0) + case tokens.TBool: + return false + case tokens.TString: + return "" + case tokens.TArray: + arr := make([]any, *typ.Len) + for n := range arr { + arr[n] = i.defVal(*typ.Ret) + } + return arr + case tokens.TStruct: + val, _ := i.env.Get(typ.Value.Value, false) + strct := val.Value.(Struct) + return strct.Instance() + default: + return nil + } +} + +func (i *Interpreter) declare(v ast.Var) (any, error) { + var initv any + var err error + + if v.Type.Value.Id == tokens.TArray && *v.Type.Len <= 0 { + return nil, errors.NewError(errors.InvalidLength, v.Name.Loc, "invalid array length") + } + + if v.Initv != nil { + initv, err = i.evaluate(v.Initv) + if err != nil { + return 0, err + } + + i.env = i.env.Set(v.Name.Value, Var{initv, v.Name.Loc}) + return initv, nil + } else { + // defaults + initv = i.defVal(v.Type) + value := Var{initv, v.Name.Loc} + i.env = i.env.Set(v.Name.Value, value) + return initv, nil + } +} + +func (i *Interpreter) variable(v ast.Variable) (any, error) { + if val, ok := i.env.Get(v.Name.Value, false); ok { + value := val.Value + if v.Index != nil { + index, err := i.evaluate(v.Index) + if err != nil { + return nil, err + } + idx, ok := index.(ast.Number) + if !ok || idx < 0 || int(idx) >= *v.Type.Len { + return nil, errors.NewError(errors.InvalidIndex, v.Name.Loc, "invalid index for", v.Type.String()) + } + arr, ok := i.env.GetArray(v.Name.Value, false) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "expected array") + } + value = arr[idx] + } + return value, nil + } + // shouldn't happen + return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "undefined variable", v.String()) +} + +func (i *Interpreter) assignment(v ast.Assign) (any, error) { + env := i.env + vLeft := v.Left + + if getExpr, ok := v.Left.(ast.GetExpr); ok { + obj, err := i.evaluate(getExpr.Object) + if err != nil { + return nil, err + } + strct := obj.(Struct) + env = strct.env + vLeft = getExpr.Expr + } + + left, ok := vLeft.(ast.Variable) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected variable in assignation") + } + right, err := i.evaluate(v.Right) + if err != nil { + return nil, err + } + + if left.Type.Value.Id == tokens.TArray && v.Right.Resolve().Value.Id != tokens.TArray { + if left.Index == nil { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected index in assignation") + } + index, err := i.evaluate(left.Index) + if err != nil { + return nil, err + } + idx, ok := index.(ast.Number) + if !ok || idx < 0 || int(idx) >= *left.Type.Len { + return nil, errors.NewError(errors.InvalidIndex, v.Loc, "invalid index for", left.Type.String()) + } + arr, ok := i.env.GetArray(left.Name.Value, false) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected array") + } + arr[idx] = right + right = arr + } + + err = env.Update(left.Name.Value, right) + if err != nil { + return nil, errors.NewError(errors.InvalidValue, left.Name.Loc, err.Error()) + } + return right, nil +} + +func (i *Interpreter) binaryExpr(expr ast.Binary) (any, error) { + // first lazy operators + if expr.Op.Id == tokens.And || expr.Op.Id == tokens.Or { + left, err := i.evaluate(expr.Left) + if err != nil { + return nil, err + } + if expr.Op.Id == tokens.And { + if !i.isTrue(left) { + return false, nil + } + } else { + if i.isTrue(left) { + return true, nil + } + } + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + return i.isTrue(right), nil + } + + left, err := i.evaluate(expr.Left) + if err != nil { + return nil, err + } + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + + // equality compares two values as long as they are the same type + switch expr.Op.Id { + case tokens.Ne: + return right != left, nil + case tokens.Eq: + return right == left, nil + } + + vl, vr, err := i.binaryNumeric(left, expr.Op, right) + if err != nil { + return 0, err + } + + switch expr.Op.Id { + case tokens.Sub: + return vl - vr, nil + case tokens.Add: + return vl + vr, nil + case tokens.Mul: + return vl * vr, nil + case tokens.Div: + if vr == 0 { + return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero") + } + return vl / vr, nil + case tokens.Mod: + if vr == 0 { + return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero") + } + return vl % vr, nil + case tokens.BitAnd: + return vl & vr, nil + case tokens.BitShl: + return vl << vr, nil + case tokens.BitShr: + return vl >> vr, nil + case tokens.BitOr: + return vl | vr, nil + case tokens.BitXor: + return vl ^ vr, nil + case tokens.Lt: + return vl < vr, nil + case tokens.Le: + return vl <= vr, nil + case tokens.Gt: + return vl > vr, nil + case tokens.Ge: + return vl >= vr, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value) + } + +} + +func (i *Interpreter) unaryExpr(expr ast.Unary) (any, error) { + if expr.Op.Id == tokens.TestE { + i.fError = false + } + + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + + switch expr.Op.Id { + case tokens.TestE: + return i.fError, nil + case tokens.Sub: + vr, err := i.unaryNumeric(expr.Op, right) + if err != nil { + return nil, err + } + return -vr, nil + case tokens.Neg: + vr, err := i.unaryNumeric(expr.Op, right) + if err != nil { + return nil, err + } + return ^vr, nil + case tokens.Not: + vr := i.isTrue(right) + return !vr, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value) + } +} + +func (i *Interpreter) literalExpr(expr ast.Literal) (any, error) { + switch expr.Value.Id { + case tokens.Number: + return expr.Numval, nil + case tokens.String: + return expr.Value.Value, nil + case tokens.True: + return true, nil + case tokens.False: + return false, nil + case tokens.None: + return nil, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Value.Loc, "unimplemented type", expr.Value.String()) + } +} + +func (i *Interpreter) groupExpr(expr ast.Group) (any, error) { + return i.evaluate(expr.Expr) +} + +func (i *Interpreter) blockNoEnv(block ast.Block) (any, error) { + var v any + var err error + for _, expr := range block.Stmts { + v, err = i.evaluate(expr) + if err != nil { + return nil, err + } + } + return v, nil +} + +func (i *Interpreter) block(block ast.Block) (any, error) { + pEnv := i.env + i.env = NewEnv(pEnv) + defer func() { + i.env = pEnv + }() + + return i.blockNoEnv(block) +} + +func (i *Interpreter) ifElse(ifElse ast.IfElse) (any, error) { + cond, err := i.evaluate(ifElse.Cond) + if err != nil { + return nil, err + } + if i.isTrue(cond) { + return i.evaluate(ifElse.True) + } else { + return i.evaluate(ifElse.False) + } +} + +func (i *Interpreter) forEval(fcond func() (bool, error), stmts ast.Expr) (any, error) { + var last any + + for { + end, result, err := func() (end bool, result any, err error) { + // handle "interruptions" + defer func() { + if r := recover(); r != nil { + if val, ok := r.(*PanicJump); ok { + if val.typ == PanicJumpContinue { + end = false + return + } + if val.typ == PanicJumpBreak { + end = true + result = last + err = nil + return + } + } + panic(r) + } + }() + for { + if fcond != nil { + cond, err := fcond() + if err != nil { + return true, nil, err + } + if !cond { + return true, last, nil + } + } + + last, err = i.evaluate(stmts) + if err != nil { + return true, nil, err + } + } + }() + if end { + return result, err + } + } +} + +func (i *Interpreter) forStmt(forStmt ast.For) (any, error) { + if forStmt.Cond == nil { + return i.forEval(nil, forStmt.Stmts) + } else { + return i.forEval(func() (bool, error) { + cond, err := i.evaluate(forStmt.Cond) + if err != nil { + return false, err + } + return i.isTrue(cond), nil + }, forStmt.Stmts) + } +} + +func (i *Interpreter) forInStmt(forInStmt ast.ForIn) (any, error) { + expr, err := i.evaluate(forInStmt.Expr) + if err != nil { + return nil, err + } + + index := 0 + arr := expr.([]any) + + pEnv := i.env + i.env = NewEnv(pEnv) + defer func() { + i.env = pEnv + }() + i.env = i.env.Set(forInStmt.Name.Value, Var{Loc: forInStmt.Name.Loc}) + + return i.forEval(func() (bool, error) { + if index == len(arr) { + return false, nil + } + i.env.Update(forInStmt.Name.Value, arr[index]) + index++ + return true, nil + }, forInStmt.Stmts) +} + +func (i *Interpreter) evalArgs(name string, loc tokens.Location, params []ast.Type, args []ast.Expr) ([]any, error) { + vals := make([]any, 0, 16) + for _, a := range args { + arg, err := i.evaluate(a) + if err != nil { + return nil, err + } + vals = append(vals, arg) + } + + return vals, nil +} + +func (i *Interpreter) call(call ast.Call) (result any, err error) { + callee, err := i.evaluate(call.Callee) + if err != nil { + return nil, err + } + + fun, ok := callee.(Callable) + if !ok { + return nil, errors.NewError(errors.NotCallable, call.Loc, "value is not callable") + } + + args, err := i.evalArgs(fun.Name(), call.Loc, fun.Params(), call.Args) + if err != nil { + return nil, err + } + + // handle return via panic call + defer func() { + if r := recover(); r != nil { + if val, ok := r.(*PanicJump); ok && val.typ == PanicJumpReturn { + result = val.value + err = nil + } else { + // won't be handled here + panic(r) + } + } + }() + + _, err = fun.Call(i, args, call.Loc) + if err != nil { + e := errors.NewError(errors.CallError, call.Loc, "error calling", fun.Name()).(errors.Error) + e.Err = err + return nil, e + } + + // clear return if there's no return value + if fun.Ret() == nil { + return nil, nil + } + + return nil, errors.NewError(errors.NoReturn, i.last, "no return value in", fun.Name(), "expected", fun.Ret().String()) +} + +func (i *Interpreter) fun(v ast.Func) (any, error) { + callable := Function{fun: v, closure: i.env} + i.env = i.env.Set(v.Name.Value, Var{callable, v.Name.Loc}) + return nil, nil +} + +func (i *Interpreter) strct(v ast.Struct) (any, error) { + pEnv := i.env + sEnv := NewEnv(pEnv) + i.env = sEnv + defer func() { + i.env = pEnv + strct := Struct{name: v.Name.Value, env: sEnv} + i.env = i.env.Set(v.Name.Value, Var{Value: strct, Loc: v.Name.Loc}) + }() + + _, err := i.blockNoEnv(v.Body) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (i *Interpreter) getExpr(v ast.GetExpr) (any, error) { + obj, err := i.evaluate(v.Object) + if err != nil { + return nil, err + } + + if val, ok := obj.(Struct); ok { + name := v.Expr.(ast.Variable) + if name.Index == nil { + expr, _ := val.env.Get(name.Name.Value, true) + return expr.Value, nil + } + + arr, _ := val.env.GetArray(name.Name.Value, true) + index, err := i.evaluate(name.Index) + if err != nil { + return nil, err + } + return arr[index.(ast.Number)], nil + } + + // shouldn't happen + return nil, errors.NewError(errors.Unimplemented, v.Resolve().Value.Loc, "unimplemented get-expr") +} + +func (i *Interpreter) returnStmt(v ast.Return) (any, error) { + var val any + + if v.Value != nil { + // could be a tail call we could optimize + if call, ok := v.Value.(ast.Call); ok { + i.fError = v.Error + panic(&PanicJump{typ: PanicJumpTailCall, value: call, loc: v.Loc}) + } + + var err error + val, err = i.evaluate(v.Value) + if err != nil { + return nil, err + } + } + + i.fError = v.Error + panic(&PanicJump{typ: PanicJumpReturn, value: val, loc: v.Loc}) +} + +func (i *Interpreter) continueStmt(v ast.Continue) (any, error) { + panic(&PanicJump{typ: PanicJumpContinue, loc: v.Loc}) +} + +func (i *Interpreter) breakStmt(v ast.Break) (any, error) { + panic(&PanicJump{typ: PanicJumpBreak, loc: v.Loc}) +} + +func (i *Interpreter) exprList(v ast.ExprList) (any, error) { + // XXX: should we set last? + vals := make([]any, len(v.Exprs)) + for n, expr := range v.Exprs { + val, err := i.evaluate(expr) + if err != nil { + return nil, err + } + vals[n] = val + } + return vals, nil +} diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go new file mode 100644 index 0000000..d9c63ae --- /dev/null +++ b/interpreter/interpreter_test.go @@ -0,0 +1,754 @@ +package interpreter + +import ( + "fmt" + "strings" + "testing" + + "usebox.net/lang/ast" + "usebox.net/lang/errors" + "usebox.net/lang/parser" + "usebox.net/lang/tokenizer" +) + +func run(input string) (any, error) { + tr := tokenizer.NewTokenizer("-", strings.NewReader(input)) + toks, err := tr.Scan() + if err != nil { + return nil, err + } + p := parser.NewParser(BuiltInTypes()) + tree, err := p.Parse(toks) + if err != nil { + errs := err.(errors.ErrorList) + return nil, errs.Errors[0] + } + + i := NewInterpreter() + var val any + for _, expr := range tree { + if val, err = i.Interp(expr); err != nil { + return nil, err + } + } + return val, nil +} + +func expectNumber(t *testing.T, input string, n ast.Number) { + val, err := run(input) + if err != nil { + t.Errorf("Unexpected error %s (input: %s)", err, input) + } + i, ok := val.(ast.Number) + if !ok { + t.Errorf("Number expected, got %T instead (input: %s)", val, input) + } else if i != n { + t.Errorf("%d expected, got %d instead (input: %s)", n, i, input) + } +} + +func expectBool(t *testing.T, input string, b bool) { + val, err := run(input) + if err != nil { + t.Errorf("Unexpected error %s (input: %s)", err, input) + } + i, ok := val.(bool) + if !ok { + t.Errorf("bool expected, got %T instead (input: %s)", val, input) + } else if i != b { + t.Errorf("%v expected, got %v instead (input: %s)", b, i, input) + } +} + +func expectString(t *testing.T, input string, s string) { + val, err := run(input) + if err != nil { + t.Errorf("unexpected error %s (input: %s)", err, input) + } + i, ok := val.(string) + if !ok { + t.Errorf("string expected, got %T instead (input: %s)", val, input) + } else if i != s { + t.Errorf("%s expected, got %s instead (input: %s)", s, i, input) + } +} + +func expectError(t *testing.T, input string, code errors.ErrCode) { + _, err := run(input) + if err == nil { + t.Errorf("expected error and didn't happen (input: %s)", input) + } + if e, ok := err.(errors.Error); !ok { + t.Errorf("error not an Error (input: %s): %s", input, err) + } else { + if e.Code != code { + t.Errorf("error %d expected, got %d -> %s (input: %s)", code, e.Code, e, input) + } + } +} + +func TestNumExpr(t *testing.T) { + expectNumber(t, "1;", 1) + expectNumber(t, "123;", 123) + expectNumber(t, "0x123;", 0x123) + expectNumber(t, "0b101;", 0b101) + expectNumber(t, "1 + 1;", 2) + expectNumber(t, "2 + 3 * 4;", 14) + expectNumber(t, "(5 - (3 - 1)) + -1;", 2) +} + +func TestInvalidOp(t *testing.T) { + // these operators are only valid for numbers + + expectError(t, "true + \"a\";", errors.InvalidOperation) + expectError(t, "\"a\" + false;", errors.InvalidOperation) + expectError(t, "\"a\" + \"a\";", errors.InvalidOperation) + expectError(t, "-false;", errors.InvalidOperation) + expectError(t, "-true;", errors.InvalidOperation) + expectError(t, "~false;", errors.InvalidOperation) + expectError(t, "~true;", errors.InvalidOperation) + expectError(t, "-\"a\";", errors.InvalidOperation) + expectError(t, "~\"a\";", errors.InvalidOperation) + + for _, op := range []string{ + "-", "+", "*", "/", "%", "&", "<<", ">>", "|", "^", "<", ">", "<=", ">=", + } { + expectError(t, fmt.Sprintf("1 %s true;", op), errors.InvalidOperation) + expectError(t, fmt.Sprintf("true %s 1;", op), errors.InvalidOperation) + expectError(t, fmt.Sprintf("1 %s \"a\";", op), errors.InvalidOperation) + expectError(t, fmt.Sprintf("\"a\" %s 1;", op), errors.InvalidOperation) + } + + // these are invalid with numbers too + expectError(t, "1 / 0;", errors.InvalidOperation) + expectError(t, "1 % 0;", errors.InvalidOperation) +} + +func TestAssign(t *testing.T) { + expectNumber(t, "var a number; a = 1;", 1) + expectString(t, "var a string = \"v\"; var b string = a; b;", "v") + expectBool(t, "var a number = 1; var b number = 2; var c number = 3; a = b = c; a == b && a== 3;", true) + expectBool(t, "var a number; var b number = a = 1; a == b && a == 1;", true) + + // local + expectNumber(t, `{ + var a number = 1; + if a != 1 { + panic("1 expected"); + } + a = 2; + if a != 2 { + panic("2 expected"); + } + println(a = 3); + a; + }`, 3) + + // scopes and shadowing + expectNumber(t, ` + var a number = 1; + if a != 1 { + panic("1 expected"); + } + { + var a number = 2; + if a != 2 { + panic("2 expected"); + } + } + a;`, 1) + + expectError(t, "a = 1;", errors.UndefinedIdent) + + // syntax error; perhaps should be in the parser + if _, err := run("var a number; (a) = 1;"); err == nil { + t.Error("error expected by didn't happen") + } + if _, err := run("var a number; a + a = 1;"); err == nil { + t.Error("error expected by didn't happen") + } +} + +func TestPrintln(t *testing.T) { + expectNumber(t, "println();", 1) + expectNumber(t, "println(1);", 2) + expectNumber(t, "println(true);", 5) + expectNumber(t, "println(\"hello\");", 6) + expectNumber(t, "println(1, \"one\");", 5) +} + +func TestEquality(t *testing.T) { + expectBool(t, "true == true;", true) + expectBool(t, "true == false;", false) + expectBool(t, "false == true;", false) + expectBool(t, "false == false;", true) + + expectBool(t, "1 == 1;", true) + expectBool(t, "0 != 1;", true) + expectBool(t, "\"a\" == \"a\";", true) + expectBool(t, "\"a\" != \"b\";", true) + + expectBool(t, "true != true;", false) + expectBool(t, "true != false;", true) + expectBool(t, "false != true;", true) + expectBool(t, "false != false;", false) + + expectError(t, "true == 1;", errors.TypeMismatch) + expectError(t, "true == \"true\";", errors.TypeMismatch) + expectError(t, "false == 0;", errors.TypeMismatch) + expectError(t, "false == \"\";", errors.TypeMismatch) + expectError(t, "false == \"false\";", errors.TypeMismatch) +} + +func TestNot(t *testing.T) { + expectBool(t, "!true;", false) + expectBool(t, "!false;", true) + + expectBool(t, "!0;", true) + expectBool(t, "!1;", false) + expectBool(t, "!\"\";", true) + expectBool(t, "!\"a\";", false) + + expectBool(t, "!!true;", true) +} + +func TestLogical(t *testing.T) { + expectBool(t, "true && true;", true) + expectBool(t, "true && false;", false) + expectBool(t, "false && true;", false) + + expectBool(t, "true || true;", true) + expectBool(t, "true || false;", true) + expectBool(t, "false || true;", true) + + expectBool(t, "1 == 1 && 1 != 2;", true) + expectBool(t, "1 != 1 || 1 != 2;", true) + + expectBool(t, "1 & 1 && 0 | 1;", true) + expectBool(t, "1 & 1 && 2 & 1;", false) + expectBool(t, "1 & 1 || 0 | 1;", true) + expectBool(t, "2 & 1 || 1 & 2;", false) +} + +func TestIfElse(t *testing.T) { + expectBool(t, "if 1 == 1 { true; }", true) + + val, err := run("if 1 != 1 { true; }") + if err != nil { + t.Errorf("unexpected error %s", err) + } + if val != nil { + t.Errorf("none expected, %T found", val) + } + + expectBool(t, "if 1 != 1 { false; } else { true; }", true) + expectBool(t, "if 1 == 1 { false; } else { true; }", false) + + expectBool(t, "if 1 != 1 { false; } else { if 1 == 1 { true; }}", true) + + val, err = run("if 1 != 1 { true; } else { if 1 != 1 { false; }}") + if err != nil { + t.Errorf("unexpected error %s", err) + } + if val != nil { + t.Errorf("none expected, %T found", val) + } +} + +func TestPanic(t *testing.T) { + _, err := run("panic(\"error\");") + if err == nil { + t.Error("error expected and didn't happen") + } +} + +func TestFunc(t *testing.T) { + expectString(t, ` + def wrap() func () string { + def inFn() string { + return "output"; + } + return inFn; + } + var fn func () string = wrap(); + fn(); + `, "output") + + expectNumber(t, ` + def makeCounter() func () number { + var i number; + def count() number { + i = i + 1; + return i; + } + return count; + } + makeCounter()(); + `, 1) + + expectBool(t, ` + // why not? + def loop(from number, to number, fn func (number)) { + if from > to { + return; + } + fn(from); + // return is needed to trigger the tail-call optimization + return loop(from + 1, to, fn); + } + + def wilsonPrime(n number) bool { + var acc number = 1; + def fact(i number) { + acc = (acc * i) % n; + } + loop(2, n - 1, fact); + return acc == n - 1; + } + + wilsonPrime(1789); + `, true) + + expectNumber(t, ` + var a number = 1; + { + var c number; + def add() { + c = c + a; + println(c, " ", a); + } + add(); // c is 1 + var a string; + add(); // c is 2 + println(c); + c; + } + `, 2) + + expectBool(t, ` + var a bool; + def fn() func () { + def inFn() { + println("works"); + a = true; + } + return inFn; + } + fn()(); + a;`, true) + + expectBool(t, ` + var a bool; + def fn(a bool) bool { + return a; + } + fn(true);`, true) + + expectString(t, ` + def fn() { } + { + def fn() { } + } + "OK";`, "OK") + +} + +func TestFuncError(t *testing.T) { + expectString(t, ` + var fn func () bool; + "OK";`, "OK") + + expectError(t, ` + var fn func () bool; + fn();`, errors.NotCallable) + + expectError(t, ` + def a() bool { return true; } + var b func () number = a; + `, errors.TypeMismatch) + + expectError(t, ` + // return with no function + return false; + `, errors.InvalidOperation) + + expectError(t, ` + def fn() bool { return 0; } + fn(); + `, errors.TypeMismatch) + + expectError(t, ` + def fn() bool { } + fn(); + `, errors.NoReturn) + + expectError(t, "true();", errors.NotCallable) + expectError(t, "\"bad\"();", errors.NotCallable) + expectError(t, "10();", errors.NotCallable) + expectError(t, "true(1);", errors.NotCallable) + expectError(t, "\"bad\"(1);", errors.NotCallable) + expectError(t, "10(1);", errors.NotCallable) + + expectError(t, ` + var fn func() number = println; + fn("hello");`, errors.CallError) +} + +func TestFor(t *testing.T) { + expectString(t, ` + for false { + panic("oops"); + } + "OK"; + `, "OK") + + // the value emited by the loops is the last + // evaluated value inside the loop + expectNumber(t, ` + var i number; + for i < 10 { + i = i + 1; + } + `, 10) + + expectNumber(t, ` + def mkGen() func () number { + var n number; + def gen() number { + n = n + 1; + return n; + } + return gen; + } + var c func() number = mkGen(); + var acc number; + for c() < 10 { + acc = acc + 1; + } + acc; + `, 9) + + expectNumber(t, ` + var acc number; + var j number; + for j < 10 { + var i number; + for i < 10 { + i = i + 1; + acc = acc + 1; + } + j = j + 1; + } + acc; + `, 100) + + expectError(t, ` + var a number; + for { + a = 1 / a; + } + `, errors.InvalidOperation) +} + +func TestForIn(t *testing.T) { + // the value emited by the loops is the last + // evaluated value inside the loop + expectNumber(t, ` + var arr [3]number = [1, 2, 10]; + for i in arr { + i; + } + `, 10) + + // shadow the array variable + expectNumber(t, ` + var a [3]number = [1, 2, 10]; + for a in a { + println(a); + a; + } + `, 10) + + expectNumber(t, ` + var acc number; + var j [10]number; + for i in j { + for i in j { + acc = acc + 1; + } + } + acc; + `, 100) +} + +func TestContinue(t *testing.T) { + expectString(t, ` + var a number; + for a < 10 { + a = a + 1; + if a & 1 { + continue; + panic("oops"); + } + println(a); + } + "OK"; + `, "OK") + + expectString(t, ` + var arr [5]number = [1, 2, 3, 4, 5]; + for a in arr { + if a & 1 { + continue; + panic("oops"); + } + println(a); + } + "OK"; + `, "OK") + + expectNumber(t, ` + var acc number; + var a number; + for a < 10 { + var b number; + for b < 10 { + b = b + 1; + if b & 1 { + continue; + panic("oops"); + } + acc = acc + b; + } + a = a + 1; + } + acc; + `, (2+4+6+8+10)*10) +} + +func TestBreak(t *testing.T) { + expectString(t, ` + for { + break; + panic("oops"); + } + "OK"; + `, "OK") + + expectString(t, ` + var a [10] number; + for i in a { + break; + panic("oops"); + } + "OK"; + `, "OK") + + expectNumber(t, ` + var a number; + for { + if a == 20 { + break; + } + a = a + 1; + for true { + a = a + 1; + break; + a = a + 1; + } + } + a; + `, 20) +} + +func TestArray(t *testing.T) { + expectError(t, "var a [0]number;", errors.InvalidLength) + expectError(t, "var a [1]number; a[-1];", errors.InvalidIndex) + expectError(t, "var a [1]number; a[2];", errors.InvalidIndex) + expectError(t, "var a [1]number; a[-1] = 1;", errors.InvalidIndex) + expectError(t, "var a [1]number; a[2] = 1;", errors.InvalidIndex) + expectError(t, "var a [5][5]number;", errors.InvalidType) + expectError(t, "var a number; len(a);", errors.TypeMismatch) + + expectNumber(t, ` + var arr [5]number = [1, 2, 3, 4, 5]; + + var i number = 1; + for i < len(arr) { + arr[0] = arr[0] + arr[i]; + i = i + 1; + } + arr[0]; + `, 1+2+3+4+5) + + expectString(t, ` + var arr [3]string = ["one", "two", "three"]; + var arr2 [3]string; + + var i number; + for i < len(arr) { + arr2[i] = arr[i]; + i = i + 1; + } + i = 0; + for i < len(arr) { + if arr2[i] != arr[i] { + panic("failed"); + } + i = i + 1; + } + "OK"; + `, "OK") + + expectBool(t, ` + var a [5]bool; + def fn(v [5]bool) { + v[2] = true; + } + fn(a); + println(a); + a[2]; + `, true) + + expectNumber(t, ` + const a [3]number = [1, 2, 3]; + const b [3]number = [3, 2, 1]; + const c [3]number = a; + c = b; + c[0]; + `, 3) + + expectNumber(t, ` + var acc number; + + def a() { + println("this is a"); + acc = acc + 10; + } + + def b() { + println("this is b"); + acc = acc + 10; + } + + var arr [2]func () = [a, b]; + + for i in arr { + i(); + } + acc;`, 20) +} + +func TestErrorTag(t *testing.T) { + expectString(t, ` + def fn() { return !?; } + if ? fn() { + "KO"; + } + `, "KO") + expectString(t, ` + def fn() { return; } + if ? fn() { + "KO"; + } else { + "OK"; + } + `, "OK") + expectString(t, ` + def fn() number { return !? 0; } + if ? fn() { + "KO"; + } + `, "KO") + expectString(t, ` + def fn() number { return 0; } + if !(? fn()) { + "OK"; + } + `, "OK") + expectNumber(t, ` + def fn() number { return !? 99; } + var n number; + if ? (n = fn()) { + n; + } + `, 99) + expectString(t, ` + def fn() { return !?; } + fn(); + // ? resets the error flag before evaluating the right side + if ? true { + "KO"; + } else { + "OK"; + } + `, "OK") +} + +func TestStruct(t *testing.T) { + expectString(t, ` + def A { var p string = "OK"; } + var a A; + a.p; + `, "OK") + expectNumber(t, ` + def A { var p number = 10; } + var a A; + a.p; + `, 10) + expectNumber(t, ` + def A { var p number; } + var a A; + a.p = 10; + a.p; + `, 10) + expectNumber(t, ` + def A { var p number = 10; } + var a A; + a.p = 100; + var b A; + b.p; + `, 10) + expectNumber(t, ` + def A { var p number = 10; } + var a [3]A; + a[0].p + a[1].p + a[2].p; + `, 30) + expectNumber(t, ` + def A { + var arr [3]number = [1, 2, 3]; + + def p(index number) number { + return arr[index]; + } + } + var a A; + a.p(0) + a.p(1) + a.p(2); + `, 6) + expectNumber(t, ` + def A { + var p number; + } + def B { + var arr [3]A; + var i number; + + def new() A { + var r A = arr[i]; + r.p = i + 1; + i = (i + 1) % 3; + return r; + } + } + var b B; + b.new().p + b.new().p + b.new().p; + // new should have returned a reference + b.arr[0].p + b.arr[1].p + b.arr[2].p; + `, 6) + + expectError(t, ` + def A { + var a A; + } + `, errors.RecursiveStruct) +} |