aboutsummaryrefslogtreecommitdiff
path: root/interpreter/interpreter.go
diff options
context:
space:
mode:
authorJuan J. Martinez <jjm@usebox.net>2022-07-18 07:45:58 +0100
committerJuan J. Martinez <jjm@usebox.net>2022-07-18 07:45:58 +0100
commit8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642 (patch)
treec53977d1284347bb1d5963ddb4dc7723c40c6e55 /interpreter/interpreter.go
downloadmicro-lang-8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642.tar.gz
micro-lang-8bb321f8b032dfaeffbe3d1b8dfeb215c12d3642.zip
First public release
Diffstat (limited to 'interpreter/interpreter.go')
-rw-r--r--interpreter/interpreter.go801
1 files changed, 801 insertions, 0 deletions
diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go
new file mode 100644
index 0000000..8d5b7a8
--- /dev/null
+++ b/interpreter/interpreter.go
@@ -0,0 +1,801 @@
+package interpreter
+
+import (
+ "fmt"
+ "strings"
+
+ "usebox.net/lang/ast"
+ "usebox.net/lang/errors"
+ "usebox.net/lang/tokens"
+)
+
+type Var struct {
+ Value any
+ Loc tokens.Location
+}
+
+type Env struct {
+ env map[string]Var
+ parent *Env
+}
+
+func NewEnv(parent *Env) *Env {
+ return &Env{make(map[string]Var), parent}
+}
+
+func (e *Env) Get(key string, local bool) (Var, bool) {
+ p := e
+ for {
+ if v, ok := p.env[key]; ok {
+ return v, true
+ }
+
+ if !local && p.parent != nil {
+ p = p.parent
+ continue
+ }
+
+ return Var{}, false
+ }
+}
+
+func (e *Env) GetArray(key string, local bool) ([]any, bool) {
+ v, ok := e.Get(key, local)
+ if !ok {
+ return nil, false
+ }
+ arr, ok := v.Value.([]any)
+ return arr, ok
+}
+
+// Set will create a new enviroment if this new value shadows an existing one.
+func (e *Env) Set(key string, value Var) *Env {
+ p := e
+ if _, ok := e.Get(key, false); ok {
+ p = NewEnv(e)
+ }
+ p.env[key] = value
+ return p
+}
+
+func (e *Env) Update(key string, value any) error {
+ p := e
+ for {
+ if v, ok := p.env[key]; ok {
+ v.Value = value
+ p.env[key] = v
+ return nil
+ }
+
+ if p.parent != nil {
+ p = p.parent
+ continue
+ }
+
+ // should not happen
+ return fmt.Errorf("variable not found '%s'", key)
+ }
+}
+
+func (e *Env) string(pad string) string {
+ b := strings.Builder{}
+
+ b.WriteString(fmt.Sprintf("%senv {\n", pad))
+ for k, v := range e.env {
+ b.WriteString(fmt.Sprintf("%s%s = %v\n", pad+" ", k, v))
+ }
+ if e.parent != nil {
+ b.WriteString(e.parent.string(pad + " "))
+ }
+ b.WriteString(fmt.Sprintf("%s}", pad))
+
+ if pad != "" {
+ b.WriteRune('\n')
+ }
+
+ return b.String()
+}
+
+func (e Env) String() string {
+ return e.string("")
+}
+
+func (e *Env) Copy() *Env {
+ ne := NewEnv(e.parent)
+ for k, v := range e.env {
+ val := Var{Value: v.Value, Loc: v.Loc}
+ // FIXME: structs
+ switch v := val.Value.(type) {
+ // update the closures to the new environment
+ case Function:
+ v.closure = ne
+ val.Value = v
+ // copy the array
+ case []any:
+ arr := make([]any, len(v))
+ copy(arr, v)
+ val.Value = arr
+ }
+ ne.env[k] = val
+ }
+
+ return ne
+}
+
+type Struct struct {
+ name string
+ env *Env
+ inst bool
+}
+
+func (s Struct) String() string {
+ if s.inst {
+ return fmt.Sprintf("<struct %s>", s.name)
+ } else {
+ return fmt.Sprintf("struct %s", s.name)
+ }
+}
+
+func (s Struct) Instance() Struct {
+ return Struct{name: s.name, env: s.env.Copy(), inst: true}
+}
+
+type PanicJumpType int
+
+const (
+ PanicJumpReturn PanicJumpType = iota
+ PanicJumpTailCall
+ PanicJumpContinue
+ PanicJumpBreak
+)
+
+type PanicJump struct {
+ typ PanicJumpType
+ value any
+ loc tokens.Location
+}
+
+type Interpreter struct {
+ env *Env
+ // location of the last node evaluated
+ last tokens.Location
+ fError bool
+}
+
+func NewInterpreter() Interpreter {
+ global := NewEnv(nil)
+ for k, v := range builtInFuncs {
+ // not shadowing any value
+ global.Set(k, Var{Value: v})
+ }
+ return Interpreter{env: global}
+}
+
+func (i *Interpreter) Interp(expr any) (val any, err error) {
+ val, err = i.evaluate(expr)
+ if err != nil {
+ return nil, err
+ }
+ return val, nil
+}
+
+func (i *Interpreter) isTrue(expr any) bool {
+ switch v := expr.(type) {
+ case ast.Number:
+ return v != 0
+ case string:
+ return v != ""
+ case bool:
+ return v
+ case Function, Callable:
+ return true
+ default:
+ return false
+ }
+}
+
+func (i *Interpreter) unaryNumeric(op tokens.Token, right any) (ast.Number, error) {
+ if vr, ok := right.(ast.Number); ok {
+ return vr, nil
+ }
+
+ // shouldn't happen
+ return 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operand")
+}
+
+func (i *Interpreter) binaryNumeric(left any, op tokens.Token, right any) (ast.Number, ast.Number, error) {
+ vl, okl := left.(ast.Number)
+ vr, okr := right.(ast.Number)
+ if okl && okr {
+ return vl, vr, nil
+ }
+
+ // shouldn't happen
+ return 0, 0, errors.NewError(errors.Unexpected, op.Loc, "invalid operands")
+}
+
+func (i *Interpreter) evaluate(expr any) (any, error) {
+ switch v := expr.(type) {
+ case ast.Var:
+ i.last = v.Name.Loc
+ return i.declare(v)
+ case ast.Variable:
+ i.last = v.Name.Loc
+ return i.variable(v)
+ case ast.Assign:
+ return i.assignment(v)
+ case ast.Binary:
+ i.last = v.Op.Loc
+ return i.binaryExpr(v)
+ case ast.Unary:
+ i.last = v.Op.Loc
+ return i.unaryExpr(v)
+ case ast.Literal:
+ i.last = v.Value.Loc
+ return i.literalExpr(v)
+ case ast.Group:
+ return i.groupExpr(v)
+ case ast.Block:
+ return i.block(v)
+ case ast.IfElse:
+ return i.ifElse(v)
+ case ast.For:
+ return i.forStmt(v)
+ case ast.ForIn:
+ return i.forInStmt(v)
+ case ast.Call:
+ i.last = v.Loc
+ return i.call(v)
+ case ast.Func:
+ i.last = v.Name.Loc
+ return i.fun(v)
+ case ast.Return:
+ return i.returnStmt(v)
+ case ast.Continue:
+ return i.continueStmt(v)
+ case ast.Break:
+ return i.breakStmt(v)
+ case ast.ExprList:
+ return i.exprList(v)
+ case ast.Struct:
+ return i.strct(v)
+ case ast.GetExpr:
+ return i.getExpr(v)
+ default:
+ // XXX
+ return nil, errors.Error{
+ Code: errors.Unimplemented,
+ Message: fmt.Sprintf("evaluation unimplemented: %s", v),
+ }
+ }
+}
+
+func (i *Interpreter) defVal(typ ast.Type) any {
+ switch typ.Value.Id {
+ case tokens.TNumber:
+ return ast.Number(0)
+ case tokens.TBool:
+ return false
+ case tokens.TString:
+ return ""
+ case tokens.TArray:
+ arr := make([]any, *typ.Len)
+ for n := range arr {
+ arr[n] = i.defVal(*typ.Ret)
+ }
+ return arr
+ case tokens.TStruct:
+ val, _ := i.env.Get(typ.Value.Value, false)
+ strct := val.Value.(Struct)
+ return strct.Instance()
+ default:
+ return nil
+ }
+}
+
+func (i *Interpreter) declare(v ast.Var) (any, error) {
+ var initv any
+ var err error
+
+ if v.Type.Value.Id == tokens.TArray && *v.Type.Len <= 0 {
+ return nil, errors.NewError(errors.InvalidLength, v.Name.Loc, "invalid array length")
+ }
+
+ if v.Initv != nil {
+ initv, err = i.evaluate(v.Initv)
+ if err != nil {
+ return 0, err
+ }
+
+ i.env = i.env.Set(v.Name.Value, Var{initv, v.Name.Loc})
+ return initv, nil
+ } else {
+ // defaults
+ initv = i.defVal(v.Type)
+ value := Var{initv, v.Name.Loc}
+ i.env = i.env.Set(v.Name.Value, value)
+ return initv, nil
+ }
+}
+
+func (i *Interpreter) variable(v ast.Variable) (any, error) {
+ if val, ok := i.env.Get(v.Name.Value, false); ok {
+ value := val.Value
+ if v.Index != nil {
+ index, err := i.evaluate(v.Index)
+ if err != nil {
+ return nil, err
+ }
+ idx, ok := index.(ast.Number)
+ if !ok || idx < 0 || int(idx) >= *v.Type.Len {
+ return nil, errors.NewError(errors.InvalidIndex, v.Name.Loc, "invalid index for", v.Type.String())
+ }
+ arr, ok := i.env.GetArray(v.Name.Value, false)
+ if !ok {
+ return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "expected array")
+ }
+ value = arr[idx]
+ }
+ return value, nil
+ }
+ // shouldn't happen
+ return nil, errors.NewError(errors.Unexpected, v.Name.Loc, "undefined variable", v.String())
+}
+
+func (i *Interpreter) assignment(v ast.Assign) (any, error) {
+ env := i.env
+ vLeft := v.Left
+
+ if getExpr, ok := v.Left.(ast.GetExpr); ok {
+ obj, err := i.evaluate(getExpr.Object)
+ if err != nil {
+ return nil, err
+ }
+ strct := obj.(Struct)
+ env = strct.env
+ vLeft = getExpr.Expr
+ }
+
+ left, ok := vLeft.(ast.Variable)
+ if !ok {
+ return nil, errors.NewError(errors.Unexpected, v.Loc, "expected variable in assignation")
+ }
+ right, err := i.evaluate(v.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ if left.Type.Value.Id == tokens.TArray && v.Right.Resolve().Value.Id != tokens.TArray {
+ if left.Index == nil {
+ return nil, errors.NewError(errors.Unexpected, v.Loc, "expected index in assignation")
+ }
+ index, err := i.evaluate(left.Index)
+ if err != nil {
+ return nil, err
+ }
+ idx, ok := index.(ast.Number)
+ if !ok || idx < 0 || int(idx) >= *left.Type.Len {
+ return nil, errors.NewError(errors.InvalidIndex, v.Loc, "invalid index for", left.Type.String())
+ }
+ arr, ok := i.env.GetArray(left.Name.Value, false)
+ if !ok {
+ return nil, errors.NewError(errors.Unexpected, v.Loc, "expected array")
+ }
+ arr[idx] = right
+ right = arr
+ }
+
+ err = env.Update(left.Name.Value, right)
+ if err != nil {
+ return nil, errors.NewError(errors.InvalidValue, left.Name.Loc, err.Error())
+ }
+ return right, nil
+}
+
+func (i *Interpreter) binaryExpr(expr ast.Binary) (any, error) {
+ // first lazy operators
+ if expr.Op.Id == tokens.And || expr.Op.Id == tokens.Or {
+ left, err := i.evaluate(expr.Left)
+ if err != nil {
+ return nil, err
+ }
+ if expr.Op.Id == tokens.And {
+ if !i.isTrue(left) {
+ return false, nil
+ }
+ } else {
+ if i.isTrue(left) {
+ return true, nil
+ }
+ }
+ right, err := i.evaluate(expr.Right)
+ if err != nil {
+ return nil, err
+ }
+ return i.isTrue(right), nil
+ }
+
+ left, err := i.evaluate(expr.Left)
+ if err != nil {
+ return nil, err
+ }
+ right, err := i.evaluate(expr.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ // equality compares two values as long as they are the same type
+ switch expr.Op.Id {
+ case tokens.Ne:
+ return right != left, nil
+ case tokens.Eq:
+ return right == left, nil
+ }
+
+ vl, vr, err := i.binaryNumeric(left, expr.Op, right)
+ if err != nil {
+ return 0, err
+ }
+
+ switch expr.Op.Id {
+ case tokens.Sub:
+ return vl - vr, nil
+ case tokens.Add:
+ return vl + vr, nil
+ case tokens.Mul:
+ return vl * vr, nil
+ case tokens.Div:
+ if vr == 0 {
+ return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero")
+ }
+ return vl / vr, nil
+ case tokens.Mod:
+ if vr == 0 {
+ return nil, errors.NewError(errors.InvalidOperation, expr.Op.Loc, "invalid operation: division by zero")
+ }
+ return vl % vr, nil
+ case tokens.BitAnd:
+ return vl & vr, nil
+ case tokens.BitShl:
+ return vl << vr, nil
+ case tokens.BitShr:
+ return vl >> vr, nil
+ case tokens.BitOr:
+ return vl | vr, nil
+ case tokens.BitXor:
+ return vl ^ vr, nil
+ case tokens.Lt:
+ return vl < vr, nil
+ case tokens.Le:
+ return vl <= vr, nil
+ case tokens.Gt:
+ return vl > vr, nil
+ case tokens.Ge:
+ return vl >= vr, nil
+ default:
+ return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value)
+ }
+
+}
+
+func (i *Interpreter) unaryExpr(expr ast.Unary) (any, error) {
+ if expr.Op.Id == tokens.TestE {
+ i.fError = false
+ }
+
+ right, err := i.evaluate(expr.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ switch expr.Op.Id {
+ case tokens.TestE:
+ return i.fError, nil
+ case tokens.Sub:
+ vr, err := i.unaryNumeric(expr.Op, right)
+ if err != nil {
+ return nil, err
+ }
+ return -vr, nil
+ case tokens.Neg:
+ vr, err := i.unaryNumeric(expr.Op, right)
+ if err != nil {
+ return nil, err
+ }
+ return ^vr, nil
+ case tokens.Not:
+ vr := i.isTrue(right)
+ return !vr, nil
+ default:
+ return nil, errors.NewError(errors.Unimplemented, expr.Op.Loc, "unimplemented operator", expr.Op.Value)
+ }
+}
+
+func (i *Interpreter) literalExpr(expr ast.Literal) (any, error) {
+ switch expr.Value.Id {
+ case tokens.Number:
+ return expr.Numval, nil
+ case tokens.String:
+ return expr.Value.Value, nil
+ case tokens.True:
+ return true, nil
+ case tokens.False:
+ return false, nil
+ case tokens.None:
+ return nil, nil
+ default:
+ return nil, errors.NewError(errors.Unimplemented, expr.Value.Loc, "unimplemented type", expr.Value.String())
+ }
+}
+
+func (i *Interpreter) groupExpr(expr ast.Group) (any, error) {
+ return i.evaluate(expr.Expr)
+}
+
+func (i *Interpreter) blockNoEnv(block ast.Block) (any, error) {
+ var v any
+ var err error
+ for _, expr := range block.Stmts {
+ v, err = i.evaluate(expr)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return v, nil
+}
+
+func (i *Interpreter) block(block ast.Block) (any, error) {
+ pEnv := i.env
+ i.env = NewEnv(pEnv)
+ defer func() {
+ i.env = pEnv
+ }()
+
+ return i.blockNoEnv(block)
+}
+
+func (i *Interpreter) ifElse(ifElse ast.IfElse) (any, error) {
+ cond, err := i.evaluate(ifElse.Cond)
+ if err != nil {
+ return nil, err
+ }
+ if i.isTrue(cond) {
+ return i.evaluate(ifElse.True)
+ } else {
+ return i.evaluate(ifElse.False)
+ }
+}
+
+func (i *Interpreter) forEval(fcond func() (bool, error), stmts ast.Expr) (any, error) {
+ var last any
+
+ for {
+ end, result, err := func() (end bool, result any, err error) {
+ // handle "interruptions"
+ defer func() {
+ if r := recover(); r != nil {
+ if val, ok := r.(*PanicJump); ok {
+ if val.typ == PanicJumpContinue {
+ end = false
+ return
+ }
+ if val.typ == PanicJumpBreak {
+ end = true
+ result = last
+ err = nil
+ return
+ }
+ }
+ panic(r)
+ }
+ }()
+ for {
+ if fcond != nil {
+ cond, err := fcond()
+ if err != nil {
+ return true, nil, err
+ }
+ if !cond {
+ return true, last, nil
+ }
+ }
+
+ last, err = i.evaluate(stmts)
+ if err != nil {
+ return true, nil, err
+ }
+ }
+ }()
+ if end {
+ return result, err
+ }
+ }
+}
+
+func (i *Interpreter) forStmt(forStmt ast.For) (any, error) {
+ if forStmt.Cond == nil {
+ return i.forEval(nil, forStmt.Stmts)
+ } else {
+ return i.forEval(func() (bool, error) {
+ cond, err := i.evaluate(forStmt.Cond)
+ if err != nil {
+ return false, err
+ }
+ return i.isTrue(cond), nil
+ }, forStmt.Stmts)
+ }
+}
+
+func (i *Interpreter) forInStmt(forInStmt ast.ForIn) (any, error) {
+ expr, err := i.evaluate(forInStmt.Expr)
+ if err != nil {
+ return nil, err
+ }
+
+ index := 0
+ arr := expr.([]any)
+
+ pEnv := i.env
+ i.env = NewEnv(pEnv)
+ defer func() {
+ i.env = pEnv
+ }()
+ i.env = i.env.Set(forInStmt.Name.Value, Var{Loc: forInStmt.Name.Loc})
+
+ return i.forEval(func() (bool, error) {
+ if index == len(arr) {
+ return false, nil
+ }
+ i.env.Update(forInStmt.Name.Value, arr[index])
+ index++
+ return true, nil
+ }, forInStmt.Stmts)
+}
+
+func (i *Interpreter) evalArgs(name string, loc tokens.Location, params []ast.Type, args []ast.Expr) ([]any, error) {
+ vals := make([]any, 0, 16)
+ for _, a := range args {
+ arg, err := i.evaluate(a)
+ if err != nil {
+ return nil, err
+ }
+ vals = append(vals, arg)
+ }
+
+ return vals, nil
+}
+
+func (i *Interpreter) call(call ast.Call) (result any, err error) {
+ callee, err := i.evaluate(call.Callee)
+ if err != nil {
+ return nil, err
+ }
+
+ fun, ok := callee.(Callable)
+ if !ok {
+ return nil, errors.NewError(errors.NotCallable, call.Loc, "value is not callable")
+ }
+
+ args, err := i.evalArgs(fun.Name(), call.Loc, fun.Params(), call.Args)
+ if err != nil {
+ return nil, err
+ }
+
+ // handle return via panic call
+ defer func() {
+ if r := recover(); r != nil {
+ if val, ok := r.(*PanicJump); ok && val.typ == PanicJumpReturn {
+ result = val.value
+ err = nil
+ } else {
+ // won't be handled here
+ panic(r)
+ }
+ }
+ }()
+
+ _, err = fun.Call(i, args, call.Loc)
+ if err != nil {
+ e := errors.NewError(errors.CallError, call.Loc, "error calling", fun.Name()).(errors.Error)
+ e.Err = err
+ return nil, e
+ }
+
+ // clear return if there's no return value
+ if fun.Ret() == nil {
+ return nil, nil
+ }
+
+ return nil, errors.NewError(errors.NoReturn, i.last, "no return value in", fun.Name(), "expected", fun.Ret().String())
+}
+
+func (i *Interpreter) fun(v ast.Func) (any, error) {
+ callable := Function{fun: v, closure: i.env}
+ i.env = i.env.Set(v.Name.Value, Var{callable, v.Name.Loc})
+ return nil, nil
+}
+
+func (i *Interpreter) strct(v ast.Struct) (any, error) {
+ pEnv := i.env
+ sEnv := NewEnv(pEnv)
+ i.env = sEnv
+ defer func() {
+ i.env = pEnv
+ strct := Struct{name: v.Name.Value, env: sEnv}
+ i.env = i.env.Set(v.Name.Value, Var{Value: strct, Loc: v.Name.Loc})
+ }()
+
+ _, err := i.blockNoEnv(v.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ return nil, nil
+}
+
+func (i *Interpreter) getExpr(v ast.GetExpr) (any, error) {
+ obj, err := i.evaluate(v.Object)
+ if err != nil {
+ return nil, err
+ }
+
+ if val, ok := obj.(Struct); ok {
+ name := v.Expr.(ast.Variable)
+ if name.Index == nil {
+ expr, _ := val.env.Get(name.Name.Value, true)
+ return expr.Value, nil
+ }
+
+ arr, _ := val.env.GetArray(name.Name.Value, true)
+ index, err := i.evaluate(name.Index)
+ if err != nil {
+ return nil, err
+ }
+ return arr[index.(ast.Number)], nil
+ }
+
+ // shouldn't happen
+ return nil, errors.NewError(errors.Unimplemented, v.Resolve().Value.Loc, "unimplemented get-expr")
+}
+
+func (i *Interpreter) returnStmt(v ast.Return) (any, error) {
+ var val any
+
+ if v.Value != nil {
+ // could be a tail call we could optimize
+ if call, ok := v.Value.(ast.Call); ok {
+ i.fError = v.Error
+ panic(&PanicJump{typ: PanicJumpTailCall, value: call, loc: v.Loc})
+ }
+
+ var err error
+ val, err = i.evaluate(v.Value)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ i.fError = v.Error
+ panic(&PanicJump{typ: PanicJumpReturn, value: val, loc: v.Loc})
+}
+
+func (i *Interpreter) continueStmt(v ast.Continue) (any, error) {
+ panic(&PanicJump{typ: PanicJumpContinue, loc: v.Loc})
+}
+
+func (i *Interpreter) breakStmt(v ast.Break) (any, error) {
+ panic(&PanicJump{typ: PanicJumpBreak, loc: v.Loc})
+}
+
+func (i *Interpreter) exprList(v ast.ExprList) (any, error) {
+ // XXX: should we set last?
+ vals := make([]any, len(v.Exprs))
+ for n, expr := range v.Exprs {
+ val, err := i.evaluate(expr)
+ if err != nil {
+ return nil, err
+ }
+ vals[n] = val
+ }
+ return vals, nil
+}