aboutsummaryrefslogtreecommitdiff
path: root/src/Compiler.hs
blob: 0adc4add48c32c435f159b6e84c52a2d82023812 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
module Compiler where

import qualified Ast as A
import Control.Monad.State
import Data.Either (rights)
import Data.List (foldl')
import Env
import Error
import System.Environment (getEnv, getEnvironment)
import Text.Parsec (ParseError, SourcePos)
import Text.Read (Lexeme (String))

-- | State threaded through compilation: the symbol environment paired with
-- the errors reported so far, most recent first.
type CompState = (Env, [Error])

-- | Outcome of compiling: @Left@ carries the collected errors, while @Right@
-- carries the type of the compiled expression when it has one.
type CompResult = Either [Error] (Maybe A.Type)

-- | Initial compiler state: an empty environment and no errors yet.
startState :: CompState
startState = (emptyEnv, mempty)

-- | @foldlEither fn seed xs@ folds @xs@ from the left with @fn@. For each
-- element, @fn@ either produces a new accumulator ('Right') or rejects the
-- element ('Left').
--
-- The result pairs the final accumulator with every rejection. Rejections are
-- prepended (newest first) onto the list already held in @seed@. A rejected
-- element leaves the accumulator unchanged, so the fold continues from the
-- last accepted value. The fold is strict (@foldl'@), so long inputs do not
-- build up a chain of unevaluated thunks.
foldlEither :: (accr -> b -> Either accl accr) -> (accr, [accl]) -> [b] -> (accr, [accl])
foldlEither fn seed xs = foldl' step seed xs
  where
    step (acc, errs) item = case fn acc item of
      Left err -> (acc, err : errs)
      Right acc' -> (acc', errs)

-- | @addError e@ records the error @e@ in the compiler state, in front of the
-- errors already collected. It yields a successful but typeless result, so
-- the caller can keep compiling past the error instead of stopping.
addError :: Error -> State CompState CompResult
addError e = do
  modify $ \(env, errors) -> (env, e : errors)
  pure (Right Nothing)

-- | @typecheckCall args params@ compiles the call arguments @args@ to resolve
-- their types, then compares them with the declared parameter types
-- @params@.
--
-- Returns @Just msg@ describing the problem on an arity or type mismatch, and
-- @Nothing@ when the call is well typed. If some argument's type cannot be
-- resolved, the result is also @Nothing@. Compiling that argument is assumed
-- to have already recorded its own error, so no second error is reported.
typecheckCall :: [A.Expr] -> [A.Type] -> State CompState (Maybe String)
typecheckCall args params
  | length args /= length params = return $ Just "invalid number of arguments in function call"
  | null params = return Nothing
  | otherwise = do
      -- resolve all argument types; 'rights' drops any Left results
      resolved <- rights <$> traverse compile args
      return $ case sequence resolved of
        -- an argument has no type; its error was presumably already recorded
        Nothing -> Nothing
        Just targs
          -- some argument produced Left, so fewer types than parameters remain
          | length targs /= length params -> Nothing
          -- all argument types match the parameter types
          | and (zipWith (==) targs params) -> Nothing
          | otherwise -> Just "type mismatch in function call" -- TODO: type description

-- | @compile expr@ type-checks a single expression. It threads the symbol
-- environment and the collected errors through 'CompState'.
--
-- Returns @Right (Just t)@ when the expression has type @t@. Returns
-- @Right Nothing@ when the expression has no type, or when an error was
-- recorded with 'addError' so that compilation can continue. A @Left@ coming
-- from a compiled callee is passed through unchanged.
compile :: A.Expr -> State CompState CompResult
compile expr = case expr of
  A.Module _ _ -> return $ Right Nothing
  A.Num _ _ -> return $ Right $ Just $ A.Type "u8" -- TODO: placeholder
  A.BinOp _ a b -> do
    l <- compile a
    _ <- compile b
    return l -- TODO: placeholder
  A.Func ident params ret body _ pos -> do
    (ev, errs) <- get
    let ftype = A.toFuncType params ret
        -- register the function in the current scope; a duplicate name is
        -- reported and the environment is left unchanged
        (fnEnv, fnErrs) = case addSymUniq ev (ident, ftype, pos) of
          Left err -> (ev, err : errs)
          Right ev' -> (ev', errs)
        -- open a nested scope holding the parameters for the body
        scope = foldlEither addSymUniq (addEnv fnEnv, fnErrs) params
    put scope
    _ <- compileAll body
    -- leave the body scope: keep the errors it produced, but restore the
    -- environment that contains the function itself
    (_, bodyErrs) <- get
    put (fnEnv, bodyErrs)
    return $ Right $ Just ftype
  A.Call ident args pos -> do
    callee <- compile ident
    case callee of
      Right (Just (A.FuncType params rtyp)) -> do
        mismatch <- typecheckCall args params
        case mismatch of
          Just err -> addError $ Error err pos
          Nothing -> return $ Right rtyp
      -- the callee has a type, but it is not a function type
      Right (Just _) -> addError $ Error "call to a value that is not a function" pos
      -- the callee has no type (e.g. an undefined name whose error was
      -- already recorded); do not report a second error
      Right Nothing -> return $ Right Nothing
      Left errs -> return $ Left errs
  A.Return value _ -> maybe (return $ Right Nothing) compile value
  A.Var ident pos -> do
    (ev, _) <- get
    case getSym ev ident of
      Just (_, t, _) -> return $ Right $ Just t
      Nothing -> addError $ Error ("undefined variable \"" ++ ident ++ "\"") pos

-- | @compileAll exprs@ compiles each expression of @exprs@ in order. It then
-- inspects the error list held in the state. The result is @Left errs@ if any
-- error has been recorded so far (not only errors from @exprs@), and
-- @Right Nothing@ otherwise.
compileAll :: [A.Expr] -> State CompState CompResult
compileAll exprs = do
  mapM_ compile exprs
  (_, collected) <- get
  pure $ if null collected then Right Nothing else Left collected