package interpreter

import (
	"fmt"
	"sort"
	"strings"

	"usebox.net/lang/ast"
	"usebox.net/lang/errors"
	"usebox.net/lang/tokens"
)

// Var is a value stored in an environment together with the location of its
// declaration.
type Var struct {
	Value any
	Loc   tokens.Location
}

// Env is a scope mapping names to variables. Non-local lookups fall back to
// the parent scope.
type Env struct {
	env    map[string]Var
	parent *Env
}

// NewEnv returns an empty environment whose lookups fall back to parent
// (nil for the outermost scope).
func NewEnv(parent *Env) *Env {
	return &Env{make(map[string]Var), parent}
}

// Get looks key up in e and, unless local is true, in each ancestor in turn.
// The boolean result reports whether the key was found.
func (e *Env) Get(key string, local bool) (Var, bool) {
	for p := e; p != nil; p = p.parent {
		if v, ok := p.env[key]; ok {
			return v, true
		}
		if local {
			break
		}
	}
	return Var{}, false
}

// GetArray is like Get but only succeeds when the value found is an array
// ([]any).
func (e *Env) GetArray(key string, local bool) ([]any, bool) {
	v, ok := e.Get(key, local)
	if !ok {
		return nil, false
	}
	arr, ok := v.Value.([]any)
	return arr, ok
}

// Set binds key to value and returns the environment that now holds the
// binding. If key is already visible from e (in e or any ancestor), a new
// child environment of e is created so the existing binding is shadowed
// instead of overwritten; callers must keep the returned environment.
func (e *Env) Set(key string, value Var) *Env {
	p := e
	if _, ok := e.Get(key, false); ok {
		p = NewEnv(e)
	}
	p.env[key] = value
	return p
}

// Update assigns value to the nearest existing binding of key (searching e
// and then its ancestors), keeping the binding's declaration location. It
// returns an error if key is not bound anywhere.
func (e *Env) Update(key string, value any) error {
	for p := e; p != nil; p = p.parent {
		if v, ok := p.env[key]; ok {
			v.Value = value
			p.env[key] = v
			return nil
		}
	}
	// should not happen
	return fmt.Errorf("variable not found '%s'", key)
}

// string renders e followed by its ancestors for debugging, each ancestor
// indented one step further. Keys are sorted so the output is deterministic.
func (e *Env) string(pad string) string {
	b := strings.Builder{}
	fmt.Fprintf(&b, "%senv {\n", pad)

	keys := make([]string, 0, len(e.env))
	for k := range e.env {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		fmt.Fprintf(&b, "%s%s = %v\n", pad+" ", k, e.env[k])
	}

	if e.parent != nil {
		b.WriteString(e.parent.string(pad + " "))
	}
	fmt.Fprintf(&b, "%s}", pad)
	if pad != "" {
		b.WriteRune('\n')
	}
	return b.String()
}

func (e Env) String() string {
	return e.string("")
}

// Copy returns a new environment with the same parent and a copy of e's own
// bindings; ancestors are shared, not copied. Functions are rebound so their
// closure is the new environment. Arrays and struct instances are copied so
// that mutating the copy does not affect e.
func (e *Env) Copy() *Env {
	ne := NewEnv(e.parent)
	for k, v := range e.env {
		var val any
		if fn, ok := v.Value.(Function); ok {
			// update the closures to the new environment
			fn.closure = ne
			val = fn
		} else {
			val = copyEnvValue(v.Value)
		}
		ne.env[k] = Var{Value: val, Loc: v.Loc}
	}
	return ne
}

// copyEnvValue copies the mutable parts of a stored value for Env.Copy.
// Arrays are copied element by element (recursively). Struct instances get
// their own field environment via Instance, so nested instances are no longer
// shared between copies. Any other value is returned unchanged.
func copyEnvValue(val any) any {
	switch v := val.(type) {
	case []any:
		arr := make([]any, len(v))
		for n, elem := range v {
			arr[n] = copyEnvValue(elem)
		}
		return arr
	case Struct:
		if v.inst {
			return v.Instance()
		}
	}
	return val
}

// Struct is either a struct definition (inst == false) or an instance of one
// (inst == true); env holds its fields.
type Struct struct {
	name string
	env  *Env
	inst bool
}

// String returns "struct NAME" for a definition and "<struct NAME>" for an
// instance.
func (s Struct) String() string {
	if s.inst {
		return fmt.Sprintf("<struct %s>", s.name)
	}
	return fmt.Sprintf("struct %s", s.name)
}

// Instance returns a new instance of s with its own copy of the field
// environment.
func (s Struct) Instance() Struct {
	return Struct{name: s.name, env: s.env.Copy(), inst: true}
}

// PanicJumpType identifies the kind of non-local control transfer carried by
// a PanicJump.
type PanicJumpType int

const (
	PanicJumpReturn PanicJumpType = iota
	PanicJumpTailCall
	PanicJumpContinue
	PanicJumpBreak
)

// PanicJump is the value passed to panic by return/continue/break statements;
// it is recovered by the enclosing call or loop evaluation.
type PanicJump struct {
	typ   PanicJumpType
	value any
	loc   tokens.Location
}

// Interpreter evaluates AST nodes against a chain of environments.
type Interpreter struct {
	env *Env
	// location of the last node evaluated
	last tokens.Location
	// set by return statements (from ast.Return.Error); reset and read by
	// the TestE unary operator
	fError bool
}

// NewInterpreter returns an interpreter whose global environment contains
// the built-in functions.
func NewInterpreter() Interpreter {
	global := NewEnv(nil)
	for k, v := range builtInFuncs {
		// not shadowing any value
		global.Set(k, Var{Value: v})
	}
	return Interpreter{env: global}
}

// Interp evaluates expr and returns its value, or nil and the error that
// stopped evaluation.
func (i *Interpreter) Interp(expr any) (val any, err error) {
	val, err = i.evaluate(expr)
	if err != nil {
		return nil, err
	}
	return val, nil
}

// isTrue reports the truthiness of a runtime value. Non-zero numbers,
// non-empty strings, true and callables are true. Everything else (including
// nil, arrays and structs) is false.
func (i *Interpreter) isTrue(expr any) bool {
	switch v := expr.(type) {
	case ast.Number:
		return v != 0
	case string:
		return v != ""
	case bool:
		return v
	case Function, Callable:
		return true
	default:
		return false
	}
}

// unaryNumeric extracts the numeric operand of a unary operator.
func (i *Interpreter) unaryNumeric(op tokens.Token, right any) (ast.Number, error) {
	if vr, ok := right.(ast.Number); ok {
		return vr, nil
	}
	// shouldn't happen
	return 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operand")
}

// binaryNumeric extracts both numeric operands of a binary operator.
func (i *Interpreter) binaryNumeric(left any, op tokens.Token, right any) (ast.Number, ast.Number, error) {
	vl, okl := left.(ast.Number)
	vr, okr := right.(ast.Number)
	if okl && okr {
		return vl, vr, nil
	}
	// shouldn't happen
	return 0, 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operands")
}

// evaluate dispatches expr to the handler for its node type. For node types
// that carry a location, it also records that location in i.last.
func (i *Interpreter) evaluate(expr any) (any, error) {
	switch v := expr.(type) {
	case ast.Var:
		i.last = v.Name.Loc
		return i.declare(v)
	case ast.Variable:
		i.last = v.Name.Loc
		return i.variable(v)
	case ast.Assign:
		return i.assignment(v)
	case ast.Binary:
		i.last = v.Op.Loc
		return i.binaryExpr(v)
	case ast.Unary:
		i.last = v.Op.Loc
		return i.unaryExpr(v)
	case ast.Literal:
		i.last = v.Value.Loc
		return i.literalExpr(v)
	case ast.Group:
		return i.groupExpr(v)
	case ast.Block:
		return i.block(v)
	case ast.IfElse:
		return i.ifElse(v)
	case ast.For:
		return i.forStmt(v)
	case ast.ForIn:
		return i.forInStmt(v)
	case ast.Call:
		i.last = v.Loc
		return i.call(v)
	case ast.Func:
		i.last = v.Name.Loc
		return i.fun(v)
	case ast.Return:
		return i.returnStmt(v)
	case ast.Continue:
		return i.continueStmt(v)
	case ast.Break:
		return i.breakStmt(v)
	case ast.ExprList:
		return i.exprList(v)
	case ast.Struct:
		return i.strct(v)
	case ast.GetExpr:
		return i.getExpr(v)
	default:
		// XXX
		return nil, errors.Error{
			Code:    errors.Unimplemented,
			Message: fmt.Sprintf("evaluation unimplemented: %s", v),
		}
	}
}

// defVal returns the default value for a variable of type typ.
//   - Numbers default to 0, bools to false and strings to "".
//   - Arrays get *typ.Len default elements.
//   - Structs get a fresh instance of the named struct definition.
//   - Any other type defaults to nil.
func (i *Interpreter) defVal(typ ast.Type) any {
	switch typ.Value.Id {
	case tokens.TNumber:
		return ast.Number(0)
	case tokens.TBool:
		return false
	case tokens.TString:
		return ""
	case tokens.TArray:
		arr := make([]any, *typ.Len)
		for n := range arr {
			arr[n] = i.defVal(*typ.Ret)
		}
		return arr
	case tokens.TStruct:
		val, _ := i.env.Get(typ.Value.Value, false)
		strct := val.Value.(Struct)
		return strct.Instance()
	default:
		return nil
	}
}

// declare evaluates a variable declaration and binds the result in the
// current scope. The initializer is used if present, otherwise the type's
// default value. It returns the bound value.
func (i *Interpreter) declare(v ast.Var) (any, error) {
	if v.Type.Value.Id == tokens.TArray && *v.Type.Len <= 0 {
		return nil, errors.NewError(errors.InvalidLength, v.Name.Loc, "invalid array length")
	}

	var initv any
	if v.Initv != nil {
		var err error
		initv, err = i.evaluate(v.Initv)
		if err != nil {
			return nil, err
		}
	} else {
		// defaults
		initv = i.defVal(v.Type)
	}

	i.env = i.env.Set(v.Name.Value, Var{Value: initv, Loc: v.Name.Loc})
	return initv, nil
}

// variable reads a variable, or one element of it when it is indexed.
func (i *Interpreter) variable(v ast.Variable) (any, error) {
	val, ok := i.env.Get(v.Name.Value, false)
	if !ok {
		// shouldn't happen
		return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "undefined variable", v.String())
	}
	if v.Index == nil {
		return val.Value, nil
	}

	index, err := i.evaluate(v.Index)
	if err != nil {
		return nil, err
	}
	idx, ok := index.(ast.Number)
	if !ok || idx < 0 || int(idx) >= *v.Type.Len {
		return nil, errors.NewError(errors.InvalidIndex, v.Name.Loc, "invalid index for", v.Type.String())
	}
	arr, ok := val.Value.([]any)
	if !ok {
		return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "expected array")
	}
	return arr[idx], nil
}

// assignment evaluates "target = value". The target is a variable, an
// indexed array element, or a struct field (obj.field / obj.field[i]).
// Assigning an element updates the array in place and then stores the whole
// array back into the variable; in that case the whole array is returned.
// Otherwise the assigned value is returned.
func (i *Interpreter) assignment(v ast.Assign) (any, error) {
	// struct field assignments target the instance's environment
	env := i.env
	vLeft := v.Left
	if getExpr, ok := v.Left.(ast.GetExpr); ok {
		obj, err := i.evaluate(getExpr.Object)
		if err != nil {
			return nil, err
		}
		strct, ok := obj.(Struct)
		if !ok {
			return nil, errors.NewError(errors.Unexpected, v.Loc, "expected struct in assignation")
		}
		env = strct.env
		vLeft = getExpr.Expr
	}

	left, ok := vLeft.(ast.Variable)
	if !ok {
		return nil, errors.NewError(errors.Unexpected, v.Loc, "expected variable in assignation")
	}

	right, err := i.evaluate(v.Right)
	if err != nil {
		return nil, err
	}

	if left.Type.Value.Id == tokens.TArray && v.Right.Resolve().Value.Id != tokens.TArray {
		if left.Index == nil {
			return nil, errors.NewError(errors.Unexpected, v.Loc, "expected index in assignation")
		}
		index, err := i.evaluate(left.Index)
		if err != nil {
			return nil, err
		}
		idx, ok := index.(ast.Number)
		if !ok || idx < 0 || int(idx) >= *left.Type.Len {
			return nil, errors.NewError(errors.InvalidIndex, v.Loc, "invalid index for", left.Type.String())
		}
		// look the array up in the same environment we are going to update
		arr, ok := env.GetArray(left.Name.Value, false)
		if !ok {
			return nil, errors.NewError(errors.Unexpected, v.Loc, "expected array")
		}
		arr[idx] = right
		right = arr
	}

	if err := env.Update(left.Name.Value, right); err != nil {
		return nil, errors.NewError(errors.InvalidValue, left.Name.Loc, err.Error())
	}
	return right, nil
}

// binaryExpr evaluates a binary operator.
//   - And/Or short-circuit and produce a bool.
//   - Eq/Ne compare any two values.
//   - All other operators require numeric operands.
//   - Division and modulo by zero are reported as errors.
func (i *Interpreter) binaryExpr(expr ast.Binary) (any, error) {
	// first lazy operators
	if expr.Op.Id == tokens.And || expr.Op.Id == tokens.Or {
		left, err := i.evaluate(expr.Left)
		if err != nil {
			return nil, err
		}
		if expr.Op.Id == tokens.And {
			if !i.isTrue(left) {
				return false, nil
			}
		} else {
			if i.isTrue(left) {
				return true, nil
			}
		}
		right, err := i.evaluate(expr.Right)
		if err != nil {
			return nil, err
		}
		return i.isTrue(right), nil
	}

	left, err := i.evaluate(expr.Left)
	if err != nil {
		return nil, err
	}
	right, err := i.evaluate(expr.Right)
	if err != nil {
		return nil, err
	}

	// equality compares two values as long as they are the same type
	// NOTE(review): == on interfaces holding uncomparable values (e.g. []any)
	// panics at runtime; presumably the type checker rejects that — confirm.
	switch expr.Op.Id {
	case tokens.Ne:
		return right != left, nil
	case tokens.Eq:
		return right == left, nil
	}

	vl, vr, err := i.binaryNumeric(left, expr.Op, right)
	if err != nil {
		return nil, err
	}

	switch expr.Op.Id {
	case tokens.Sub:
		return vl - vr, nil
	case tokens.Add:
		return vl + vr, nil
	case tokens.Mul:
		return vl * vr, nil
	case tokens.Div:
		if vr == 0 {
			return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero")
		}
		return vl / vr, nil
	case tokens.Mod:
		if vr == 0 {
			return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero")
		}
		return vl % vr, nil
	case tokens.BitAnd:
		return vl & vr, nil
	case tokens.BitShl:
		return vl << vr, nil
	case tokens.BitShr:
		return vl >> vr, nil
	case tokens.BitOr:
		return vl | vr, nil
	case tokens.BitXor:
		return vl ^ vr, nil
	case tokens.Lt:
		return vl < vr, nil
	case tokens.Le:
		return vl <= vr, nil
	case tokens.Gt:
		return vl > vr, nil
	case tokens.Ge:
		return vl >= vr, nil
	default:
		return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value)
	}
}

// unaryExpr evaluates a unary operator. TestE clears i.fError before
// evaluating its operand and then returns the flag's value.
func (i *Interpreter) unaryExpr(expr ast.Unary) (any, error) {
	if expr.Op.Id == tokens.TestE {
		i.fError = false
	}

	right, err := i.evaluate(expr.Right)
	if err != nil {
		return nil, err
	}

	switch expr.Op.Id {
	case tokens.TestE:
		return i.fError, nil
	case tokens.Sub:
		vr, err := i.unaryNumeric(expr.Op, right)
		if err != nil {
			return nil, err
		}
		return -vr, nil
	case tokens.Neg:
		vr, err := i.unaryNumeric(expr.Op, right)
		if err != nil {
			return nil, err
		}
		return ^vr, nil
	case tokens.Not:
		vr := i.isTrue(right)
		return !vr, nil
	default:
		return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value)
	}
}

// literalExpr converts a literal token into its runtime value.
func (i *Interpreter) literalExpr(expr ast.Literal) (any, error) {
	switch expr.Value.Id {
	case tokens.Number:
		return expr.Numval, nil
	case tokens.String:
		return expr.Value.Value, nil
	case tokens.True:
		return true, nil
	case tokens.False:
		return false, nil
	case tokens.None:
		return nil, nil
	default:
		return nil, errors.NewError(errors.Unimplemented, expr.Value.Loc, "unimplemented type", expr.Value.String())
	}
}

func (i *Interpreter) groupExpr(expr ast.Group) (any, error) {
	return i.evaluate(expr.Expr)
}

// blockNoEnv evaluates the statements of block in the current scope and
// returns the value of the last one.
func (i *Interpreter) blockNoEnv(block ast.Block) (any, error) {
	var v any
	var err error
	for _, expr := range block.Stmts {
		v, err = i.evaluate(expr)
		if err != nil {
			return nil, err
		}
	}
	return v, nil
}

// block evaluates block in a new child scope, restoring the previous scope
// afterwards (also when a PanicJump unwinds through it).
func (i *Interpreter) block(block ast.Block) (any, error) {
	pEnv := i.env
	i.env = NewEnv(pEnv)
	defer func() {
		i.env = pEnv
	}()
	return i.blockNoEnv(block)
}

func (i *Interpreter) ifElse(ifElse ast.IfElse) (any, error) {
	cond, err := i.evaluate(ifElse.Cond)
	if err != nil {
		return nil, err
	}
	if i.isTrue(cond) {
		return i.evaluate(ifElse.True)
	}
	return i.evaluate(ifElse.False)
}

// forEval runs stmts repeatedly while fcond reports true; a nil fcond loops
// until a break. It returns the value of the last completed iteration.
// Continue and break statements arrive as *PanicJump panics. Continue
// restarts the inner loop (re-checking fcond) and break ends the loop. Any
// other panic is re-raised.
func (i *Interpreter) forEval(fcond func() (bool, error), stmts ast.Expr) (any, error) {
	var last any
	for {
		end, result, err := func() (end bool, result any, err error) {
			// handle "interruptions"
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				if val, ok := r.(*PanicJump); ok {
					switch val.typ {
					case PanicJumpContinue:
						end = false
						return
					case PanicJumpBreak:
						end = true
						result = last
						err = nil
						return
					}
				}
				panic(r)
			}()

			for {
				if fcond != nil {
					cond, err := fcond()
					if err != nil {
						return true, nil, err
					}
					if !cond {
						return true, last, nil
					}
				}
				last, err = i.evaluate(stmts)
				if err != nil {
					return true, nil, err
				}
			}
		}()
		if end {
			return result, err
		}
	}
}

func (i *Interpreter) forStmt(forStmt ast.For) (any, error) {
	if forStmt.Cond == nil {
		return i.forEval(nil, forStmt.Stmts)
	}
	return i.forEval(func() (bool, error) {
		cond, err := i.evaluate(forStmt.Cond)
		if err != nil {
			return false, err
		}
		return i.isTrue(cond), nil
	}, forStmt.Stmts)
}

// forInStmt iterates over the elements of an array. The loop variable is
// bound in a new scope that lives for the whole loop, and is updated with
// each element before every iteration.
func (i *Interpreter) forInStmt(forInStmt ast.ForIn) (any, error) {
	expr, err := i.evaluate(forInStmt.Expr)
	if err != nil {
		return nil, err
	}
	arr, ok := expr.([]any)
	if !ok {
		return nil, errors.NewError(errors.Unexpected, forInStmt.Name.Loc, "expected array")
	}

	pEnv := i.env
	i.env = NewEnv(pEnv)
	defer func() {
		i.env = pEnv
	}()
	i.env = i.env.Set(forInStmt.Name.Value, Var{Loc: forInStmt.Name.Loc})

	index := 0
	return i.forEval(func() (bool, error) {
		if index == len(arr) {
			return false, nil
		}
		if err := i.env.Update(forInStmt.Name.Value, arr[index]); err != nil {
			return false, errors.NewError(errors.InvalidValue, forInStmt.Name.Loc, err.Error())
		}
		index++
		return true, nil
	}, forInStmt.Stmts)
}

// evalArgs evaluates the call arguments left to right. name, loc and params
// are currently unused.
func (i *Interpreter) evalArgs(name string, loc tokens.Location, params []ast.Type, args []ast.Expr) ([]any, error) {
	vals := make([]any, 0, len(args))
	for _, a := range args {
		arg, err := i.evaluate(a)
		if err != nil {
			return nil, err
		}
		vals = append(vals, arg)
	}
	return vals, nil
}

// call evaluates the callee and its arguments and invokes the callable.
// A return statement in the callee unwinds with a *PanicJump of type
// PanicJumpReturn, which is recovered here and becomes the call's result.
// Other panics propagate. If the callee completes without returning, the
// result is nil when the callable has no return type and a NoReturn error
// otherwise.
func (i *Interpreter) call(call ast.Call) (result any, err error) {
	callee, err := i.evaluate(call.Callee)
	if err != nil {
		return nil, err
	}

	fun, ok := callee.(Callable)
	if !ok {
		return nil, errors.NewError(errors.NotCallable, call.Loc, "value is not callable")
	}

	args, err := i.evalArgs(fun.Name(), call.Loc, fun.Params(), call.Args)
	if err != nil {
		return nil, err
	}

	// handle return via panic call
	defer func() {
		if r := recover(); r != nil {
			if val, ok := r.(*PanicJump); ok && val.typ == PanicJumpReturn {
				result = val.value
				err = nil
			} else {
				// won't be handled here
				panic(r)
			}
		}
	}()

	_, err = fun.Call(i, args, call.Loc)
	if err != nil {
		e := errors.NewError(errors.CallError, call.Loc, "error calling", fun.Name()).(errors.Error)
		e.Err = err
		return nil, e
	}

	// clear return if there's no return value
	if fun.Ret() == nil {
		return nil, nil
	}

	return nil, errors.NewError(errors.NoReturn, i.last, "no return value in", fun.Name(), "expected", fun.Ret().String())
}

// fun binds a named function in the current scope, capturing the current
// environment as its closure.
// NOTE(review): the closure is captured before Set; if the name shadows an
// existing binding, Set creates a child scope the closure does not see.
func (i *Interpreter) fun(v ast.Func) (any, error) {
	callable := Function{fun: v, closure: i.env}
	i.env = i.env.Set(v.Name.Value, Var{Value: callable, Loc: v.Name.Loc})
	return nil, nil
}

// strct evaluates a struct declaration body in a fresh child scope, which
// becomes the definition's field environment. It then binds the Struct
// definition to its name in the enclosing scope. The binding is done in a
// deferred function, so it happens even if evaluating the body fails.
func (i *Interpreter) strct(v ast.Struct) (any, error) {
	pEnv := i.env
	sEnv := NewEnv(pEnv)
	i.env = sEnv
	defer func() {
		i.env = pEnv
		strct := Struct{name: v.Name.Value, env: sEnv}
		i.env = i.env.Set(v.Name.Value, Var{Value: strct, Loc: v.Name.Loc})
	}()

	if _, err := i.blockNoEnv(v.Body); err != nil {
		return nil, err
	}
	return nil, nil
}

// getExpr reads a field (or one element of an array field) from a struct
// value. Fields are looked up only in the struct's own environment.
func (i *Interpreter) getExpr(v ast.GetExpr) (any, error) {
	obj, err := i.evaluate(v.Object)
	if err != nil {
		return nil, err
	}

	val, ok := obj.(Struct)
	if !ok {
		// shouldn't happen
		return nil, errors.NewError(errors.Unimplemented, v.Resolve().Value.Loc, "unimplemented get-expr")
	}

	name := v.Expr.(ast.Variable)
	if name.Index == nil {
		field, _ := val.env.Get(name.Name.Value, true)
		return field.Value, nil
	}

	index, err := i.evaluate(name.Index)
	if err != nil {
		return nil, err
	}
	arr, ok := val.env.GetArray(name.Name.Value, true)
	if !ok {
		return nil, errors.NewError(errors.Unexpected, name.Name.Loc, "expected array")
	}
	idx, ok := index.(ast.Number)
	if !ok || idx < 0 || int(idx) >= len(arr) {
		return nil, errors.NewError(errors.InvalidIndex, name.Name.Loc, "invalid index for", name.Type.String())
	}
	return arr[idx], nil
}

// returnStmt unwinds to the enclosing call by panicking with a *PanicJump,
// after recording v.Error in i.fError. If the returned expression is a call,
// it is not evaluated here. The unevaluated ast.Call is carried in a
// PanicJumpTailCall, presumably for a handler outside this file to run.
func (i *Interpreter) returnStmt(v ast.Return) (any, error) {
	var val any
	if v.Value != nil {
		// could be a tail call we could optimize
		if call, ok := v.Value.(ast.Call); ok {
			i.fError = v.Error
			panic(&PanicJump{typ: PanicJumpTailCall, value: call, loc: v.Loc})
		}

		var err error
		val, err = i.evaluate(v.Value)
		if err != nil {
			return nil, err
		}
	}
	i.fError = v.Error
	panic(&PanicJump{typ: PanicJumpReturn, value: val, loc: v.Loc})
}

// continueStmt unwinds to the enclosing loop (see forEval).
func (i *Interpreter) continueStmt(v ast.Continue) (any, error) {
	panic(&PanicJump{typ: PanicJumpContinue, loc: v.Loc})
}

// breakStmt unwinds to the enclosing loop (see forEval).
func (i *Interpreter) breakStmt(v ast.Break) (any, error) {
	panic(&PanicJump{typ: PanicJumpBreak, loc: v.Loc})
}

// exprList evaluates each expression in order and returns their values as
// an array.
func (i *Interpreter) exprList(v ast.ExprList) (any, error) {
	// XXX: should we set last?
	vals := make([]any, len(v.Exprs))
	for n, expr := range v.Exprs {
		val, err := i.evaluate(expr)
		if err != nil {
			return nil, err
		}
		vals[n] = val
	}
	return vals, nil
}