From 8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642 Mon Sep 17 00:00:00 2001 From: "Juan J. Martinez" Date: Mon, 18 Jul 2022 07:45:58 +0100 Subject: First public release --- interpreter/interpreter.go | 801 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 801 insertions(+) create mode 100644 interpreter/interpreter.go (limited to 'interpreter/interpreter.go') diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go new file mode 100644 index 0000000..8d5b7a8 --- /dev/null +++ b/interpreter/interpreter.go @@ -0,0 +1,801 @@ +package interpreter + +import ( + "fmt" + "strings" + + "usebox.net/lang/ast" + "usebox.net/lang/errors" + "usebox.net/lang/tokens" +) + +type Var struct { + Value any + Loc tokens.Location +} + +type Env struct { + env map[string]Var + parent *Env +} + +func NewEnv(parent *Env) *Env { + return &Env{make(map[string]Var), parent} +} + +func (e *Env) Get(key string, local bool) (Var, bool) { + p := e + for { + if v, ok := p.env[key]; ok { + return v, true + } + + if !local && p.parent != nil { + p = p.parent + continue + } + + return Var{}, false + } +} + +func (e *Env) GetArray(key string, local bool) ([]any, bool) { + v, ok := e.Get(key, local) + if !ok { + return nil, false + } + arr, ok := v.Value.([]any) + return arr, ok +} + +// Set will create a new enviroment if this new value shadows an existing one. +func (e *Env) Set(key string, value Var) *Env { + p := e + if _, ok := e.Get(key, false); ok { + p = NewEnv(e) + } + p.env[key] = value + return p +} + +func (e *Env) Update(key string, value any) error { + p := e + for { + if v, ok := p.env[key]; ok { + v.Value = value + p.env[key] = v + return nil + } + + if p.parent != nil { + p = p.parent + continue + } + + // should not happen + return fmt.Errorf("variable not found '%s'", key) + } +} + +func (e *Env) string(pad string) string { + b := strings.Builder{} + + b.WriteString(fmt.Sprintf("%senv {\n", pad)) + for k, v := range e.env { + b.WriteString(fmt.Sprintf("%s%s = %v\n", pad+" ", k, v)) + } + if e.parent != nil { + b.WriteString(e.parent.string(pad + " ")) + } + b.WriteString(fmt.Sprintf("%s}", pad)) + + if pad != "" { + b.WriteRune('\n') + } + + return b.String() +} + +func (e Env) String() string { + return e.string("") +} + +func (e *Env) Copy() *Env { + ne := NewEnv(e.parent) + for k, v := range e.env { + val := Var{Value: v.Value, Loc: v.Loc} + // FIXME: structs + switch v := val.Value.(type) { + // update the closures to the new environment + case Function: + v.closure = ne + val.Value = v + // copy the array + case []any: + arr := make([]any, len(v)) + copy(arr, v) + val.Value = arr + } + ne.env[k] = val + } + + return ne +} + +type Struct struct { + name string + env *Env + inst bool +} + +func (s Struct) String() string { + if s.inst { + return fmt.Sprintf("", s.name) + } else { + return fmt.Sprintf("struct %s", s.name) + } +} + +func (s Struct) Instance() Struct { + return Struct{name: s.name, env: s.env.Copy(), inst: true} +} + +type PanicJumpType int + +const ( + PanicJumpReturn PanicJumpType = iota + PanicJumpTailCall + PanicJumpContinue + PanicJumpBreak +) + +type PanicJump struct { + typ PanicJumpType + value any + loc tokens.Location +} + +type Interpreter struct { + env *Env + // location of the last node evaluated + last tokens.Location + fError bool +} + +func NewInterpreter() Interpreter { + global := NewEnv(nil) + for k, v := range builtInFuncs { + // not shadowing any value + global.Set(k, Var{Value: v}) + } + return Interpreter{env: global} +} + +func (i *Interpreter) Interp(expr any) (val any, err error) { + val, err = i.evaluate(expr) + if err != nil { + return nil, err + } + return val, nil +} + +func (i *Interpreter) isTrue(expr any) bool { + switch v := expr.(type) { + case ast.Number: + return v != 0 + case string: + return v != "" + case bool: + return v + case Function, Callable: + return true + default: + return false + } +} + +func (i *Interpreter) unaryNumeric(op tokens.Token, right any) (ast.Number, error) { + if vr, ok := right.(ast.Number); ok { + return vr, nil + } + + // shouldn't happen + return 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operand") +} + +func (i *Interpreter) binaryNumeric(left any, op tokens.Token, right any) (ast.Number, ast.Number, error) { + vl, okl := left.(ast.Number) + vr, okr := right.(ast.Number) + if okl && okr { + return vl, vr, nil + } + + // shouldn't happen + return 0, 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operands") +} + +func (i *Interpreter) evaluate(expr any) (any, error) { + switch v := expr.(type) { + case ast.Var: + i.last = v.Name.Loc + return i.declare(v) + case ast.Variable: + i.last = v.Name.Loc + return i.variable(v) + case ast.Assign: + return i.assignment(v) + case ast.Binary: + i.last = v.Op.Loc + return i.binaryExpr(v) + case ast.Unary: + i.last = v.Op.Loc + return i.unaryExpr(v) + case ast.Literal: + i.last = v.Value.Loc + return i.literalExpr(v) + case ast.Group: + return i.groupExpr(v) + case ast.Block: + return i.block(v) + case ast.IfElse: + return i.ifElse(v) + case ast.For: + return i.forStmt(v) + case ast.ForIn: + return i.forInStmt(v) + case ast.Call: + i.last = v.Loc + return i.call(v) + case ast.Func: + i.last = v.Name.Loc + return i.fun(v) + case ast.Return: + return i.returnStmt(v) + case ast.Continue: + return i.continueStmt(v) + case ast.Break: + return i.breakStmt(v) + case ast.ExprList: + return i.exprList(v) + case ast.Struct: + return i.strct(v) + case ast.GetExpr: + return i.getExpr(v) + default: + // XXX + return nil, errors.Error{ + Code: errors.Unimplemented, + Message: fmt.Sprintf("evaluation unimplemented: %s", v), + } + } +} + +func (i *Interpreter) defVal(typ ast.Type) any { + switch typ.Value.Id { + case tokens.TNumber: + return ast.Number(0) + case tokens.TBool: + return false + case tokens.TString: + return "" + case tokens.TArray: + arr := make([]any, *typ.Len) + for n := range arr { + arr[n] = i.defVal(*typ.Ret) + } + return arr + case tokens.TStruct: + val, _ := i.env.Get(typ.Value.Value, false) + strct := val.Value.(Struct) + return strct.Instance() + default: + return nil + } +} + +func (i *Interpreter) declare(v ast.Var) (any, error) { + var initv any + var err error + + if v.Type.Value.Id == tokens.TArray && *v.Type.Len <= 0 { + return nil, errors.NewError(errors.InvalidLength, v.Name.Loc, "invalid array length") + } + + if v.Initv != nil { + initv, err = i.evaluate(v.Initv) + if err != nil { + return 0, err + } + + i.env = i.env.Set(v.Name.Value, Var{initv, v.Name.Loc}) + return initv, nil + } else { + // defaults + initv = i.defVal(v.Type) + value := Var{initv, v.Name.Loc} + i.env = i.env.Set(v.Name.Value, value) + return initv, nil + } +} + +func (i *Interpreter) variable(v ast.Variable) (any, error) { + if val, ok := i.env.Get(v.Name.Value, false); ok { + value := val.Value + if v.Index != nil { + index, err := i.evaluate(v.Index) + if err != nil { + return nil, err + } + idx, ok := index.(ast.Number) + if !ok || idx < 0 || int(idx) >= *v.Type.Len { + return nil, errors.NewError(errors.InvalidIndex, v.Name.Loc, "invalid index for", v.Type.String()) + } + arr, ok := i.env.GetArray(v.Name.Value, false) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "expected array") + } + value = arr[idx] + } + return value, nil + } + // shouldn't happen + return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "undefined variable", v.String()) +} + +func (i *Interpreter) assignment(v ast.Assign) (any, error) { + env := i.env + vLeft := v.Left + + if getExpr, ok := v.Left.(ast.GetExpr); ok { + obj, err := i.evaluate(getExpr.Object) + if err != nil { + return nil, err + } + strct := obj.(Struct) + env = strct.env + vLeft = getExpr.Expr + } + + left, ok := vLeft.(ast.Variable) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected variable in assignation") + } + right, err := i.evaluate(v.Right) + if err != nil { + return nil, err + } + + if left.Type.Value.Id == tokens.TArray && v.Right.Resolve().Value.Id != tokens.TArray { + if left.Index == nil { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected index in assignation") + } + index, err := i.evaluate(left.Index) + if err != nil { + return nil, err + } + idx, ok := index.(ast.Number) + if !ok || idx < 0 || int(idx) >= *left.Type.Len { + return nil, errors.NewError(errors.InvalidIndex, v.Loc, "invalid index for", left.Type.String()) + } + arr, ok := i.env.GetArray(left.Name.Value, false) + if !ok { + return nil, errors.NewError(errors.Unexpected, v.Loc, "expected array") + } + arr[idx] = right + right = arr + } + + err = env.Update(left.Name.Value, right) + if err != nil { + return nil, errors.NewError(errors.InvalidValue, left.Name.Loc, err.Error()) + } + return right, nil +} + +func (i *Interpreter) binaryExpr(expr ast.Binary) (any, error) { + // first lazy operators + if expr.Op.Id == tokens.And || expr.Op.Id == tokens.Or { + left, err := i.evaluate(expr.Left) + if err != nil { + return nil, err + } + if expr.Op.Id == tokens.And { + if !i.isTrue(left) { + return false, nil + } + } else { + if i.isTrue(left) { + return true, nil + } + } + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + return i.isTrue(right), nil + } + + left, err := i.evaluate(expr.Left) + if err != nil { + return nil, err + } + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + + // equality compares two values as long as they are the same type + switch expr.Op.Id { + case tokens.Ne: + return right != left, nil + case tokens.Eq: + return right == left, nil + } + + vl, vr, err := i.binaryNumeric(left, expr.Op, right) + if err != nil { + return 0, err + } + + switch expr.Op.Id { + case tokens.Sub: + return vl - vr, nil + case tokens.Add: + return vl + vr, nil + case tokens.Mul: + return vl * vr, nil + case tokens.Div: + if vr == 0 { + return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero") + } + return vl / vr, nil + case tokens.Mod: + if vr == 0 { + return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero") + } + return vl % vr, nil + case tokens.BitAnd: + return vl & vr, nil + case tokens.BitShl: + return vl << vr, nil + case tokens.BitShr: + return vl >> vr, nil + case tokens.BitOr: + return vl | vr, nil + case tokens.BitXor: + return vl ^ vr, nil + case tokens.Lt: + return vl < vr, nil + case tokens.Le: + return vl <= vr, nil + case tokens.Gt: + return vl > vr, nil + case tokens.Ge: + return vl >= vr, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value) + } + +} + +func (i *Interpreter) unaryExpr(expr ast.Unary) (any, error) { + if expr.Op.Id == tokens.TestE { + i.fError = false + } + + right, err := i.evaluate(expr.Right) + if err != nil { + return nil, err + } + + switch expr.Op.Id { + case tokens.TestE: + return i.fError, nil + case tokens.Sub: + vr, err := i.unaryNumeric(expr.Op, right) + if err != nil { + return nil, err + } + return -vr, nil + case tokens.Neg: + vr, err := i.unaryNumeric(expr.Op, right) + if err != nil { + return nil, err + } + return ^vr, nil + case tokens.Not: + vr := i.isTrue(right) + return !vr, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value) + } +} + +func (i *Interpreter) literalExpr(expr ast.Literal) (any, error) { + switch expr.Value.Id { + case tokens.Number: + return expr.Numval, nil + case tokens.String: + return expr.Value.Value, nil + case tokens.True: + return true, nil + case tokens.False: + return false, nil + case tokens.None: + return nil, nil + default: + return nil, errors.NewError(errors.Unimplemented, expr.Value.Loc, "unimplemented type", expr.Value.String()) + } +} + +func (i *Interpreter) groupExpr(expr ast.Group) (any, error) { + return i.evaluate(expr.Expr) +} + +func (i *Interpreter) blockNoEnv(block ast.Block) (any, error) { + var v any + var err error + for _, expr := range block.Stmts { + v, err = i.evaluate(expr) + if err != nil { + return nil, err + } + } + return v, nil +} + +func (i *Interpreter) block(block ast.Block) (any, error) { + pEnv := i.env + i.env = NewEnv(pEnv) + defer func() { + i.env = pEnv + }() + + return i.blockNoEnv(block) +} + +func (i *Interpreter) ifElse(ifElse ast.IfElse) (any, error) { + cond, err := i.evaluate(ifElse.Cond) + if err != nil { + return nil, err + } + if i.isTrue(cond) { + return i.evaluate(ifElse.True) + } else { + return i.evaluate(ifElse.False) + } +} + +func (i *Interpreter) forEval(fcond func() (bool, error), stmts ast.Expr) (any, error) { + var last any + + for { + end, result, err := func() (end bool, result any, err error) { + // handle "interruptions" + defer func() { + if r := recover(); r != nil { + if val, ok := r.(*PanicJump); ok { + if val.typ == PanicJumpContinue { + end = false + return + } + if val.typ == PanicJumpBreak { + end = true + result = last + err = nil + return + } + } + panic(r) + } + }() + for { + if fcond != nil { + cond, err := fcond() + if err != nil { + return true, nil, err + } + if !cond { + return true, last, nil + } + } + + last, err = i.evaluate(stmts) + if err != nil { + return true, nil, err + } + } + }() + if end { + return result, err + } + } +} + +func (i *Interpreter) forStmt(forStmt ast.For) (any, error) { + if forStmt.Cond == nil { + return i.forEval(nil, forStmt.Stmts) + } else { + return i.forEval(func() (bool, error) { + cond, err := i.evaluate(forStmt.Cond) + if err != nil { + return false, err + } + return i.isTrue(cond), nil + }, forStmt.Stmts) + } +} + +func (i *Interpreter) forInStmt(forInStmt ast.ForIn) (any, error) { + expr, err := i.evaluate(forInStmt.Expr) + if err != nil { + return nil, err + } + + index := 0 + arr := expr.([]any) + + pEnv := i.env + i.env = NewEnv(pEnv) + defer func() { + i.env = pEnv + }() + i.env = i.env.Set(forInStmt.Name.Value, Var{Loc: forInStmt.Name.Loc}) + + return i.forEval(func() (bool, error) { + if index == len(arr) { + return false, nil + } + i.env.Update(forInStmt.Name.Value, arr[index]) + index++ + return true, nil + }, forInStmt.Stmts) +} + +func (i *Interpreter) evalArgs(name string, loc tokens.Location, params []ast.Type, args []ast.Expr) ([]any, error) { + vals := make([]any, 0, 16) + for _, a := range args { + arg, err := i.evaluate(a) + if err != nil { + return nil, err + } + vals = append(vals, arg) + } + + return vals, nil +} + +func (i *Interpreter) call(call ast.Call) (result any, err error) { + callee, err := i.evaluate(call.Callee) + if err != nil { + return nil, err + } + + fun, ok := callee.(Callable) + if !ok { + return nil, errors.NewError(errors.NotCallable, call.Loc, "value is not callable") + } + + args, err := i.evalArgs(fun.Name(), call.Loc, fun.Params(), call.Args) + if err != nil { + return nil, err + } + + // handle return via panic call + defer func() { + if r := recover(); r != nil { + if val, ok := r.(*PanicJump); ok && val.typ == PanicJumpReturn { + result = val.value + err = nil + } else { + // won't be handled here + panic(r) + } + } + }() + + _, err = fun.Call(i, args, call.Loc) + if err != nil { + e := errors.NewError(errors.CallError, call.Loc, "error calling", fun.Name()).(errors.Error) + e.Err = err + return nil, e + } + + // clear return if there's no return value + if fun.Ret() == nil { + return nil, nil + } + + return nil, errors.NewError(errors.NoReturn, i.last, "no return value in", fun.Name(), "expected", fun.Ret().String()) +} + +func (i *Interpreter) fun(v ast.Func) (any, error) { + callable := Function{fun: v, closure: i.env} + i.env = i.env.Set(v.Name.Value, Var{callable, v.Name.Loc}) + return nil, nil +} + +func (i *Interpreter) strct(v ast.Struct) (any, error) { + pEnv := i.env + sEnv := NewEnv(pEnv) + i.env = sEnv + defer func() { + i.env = pEnv + strct := Struct{name: v.Name.Value, env: sEnv} + i.env = i.env.Set(v.Name.Value, Var{Value: strct, Loc: v.Name.Loc}) + }() + + _, err := i.blockNoEnv(v.Body) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (i *Interpreter) getExpr(v ast.GetExpr) (any, error) { + obj, err := i.evaluate(v.Object) + if err != nil { + return nil, err + } + + if val, ok := obj.(Struct); ok { + name := v.Expr.(ast.Variable) + if name.Index == nil { + expr, _ := val.env.Get(name.Name.Value, true) + return expr.Value, nil + } + + arr, _ := val.env.GetArray(name.Name.Value, true) + index, err := i.evaluate(name.Index) + if err != nil { + return nil, err + } + return arr[index.(ast.Number)], nil + } + + // shouldn't happen + return nil, errors.NewError(errors.Unimplemented, v.Resolve().Value.Loc, "unimplemented get-expr") +} + +func (i *Interpreter) returnStmt(v ast.Return) (any, error) { + var val any + + if v.Value != nil { + // could be a tail call we could optimize + if call, ok := v.Value.(ast.Call); ok { + i.fError = v.Error + panic(&PanicJump{typ: PanicJumpTailCall, value: call, loc: v.Loc}) + } + + var err error + val, err = i.evaluate(v.Value) + if err != nil { + return nil, err + } + } + + i.fError = v.Error + panic(&PanicJump{typ: PanicJumpReturn, value: val, loc: v.Loc}) +} + +func (i *Interpreter) continueStmt(v ast.Continue) (any, error) { + panic(&PanicJump{typ: PanicJumpContinue, loc: v.Loc}) +} + +func (i *Interpreter) breakStmt(v ast.Break) (any, error) { + panic(&PanicJump{typ: PanicJumpBreak, loc: v.Loc}) +} + +func (i *Interpreter) exprList(v ast.ExprList) (any, error) { + // XXX: should we set last? + vals := make([]any, len(v.Exprs)) + for n, expr := range v.Exprs { + val, err := i.evaluate(expr) + if err != nil { + return nil, err + } + vals[n] = val + } + return vals, nil +} -- cgit v1.2.3