open InfoSyntax
open Batteries

(* A typing context: a map from variable names to their types. *)
type t = (string, datatype) Map.t

(* The typing context in which no variable is bound. *)
let empty_ctxt = Map.empty

(* Looks up the type bound to [var] in the context [ctxt].
   Returns [Some ty] if [var] has an associated type, [None] otherwise.
   A single [Map.find] is performed (rather than [Map.mem] followed by
   [Map.find]) so the map is traversed only once. *)
let lookup_ctxt ctxt (var:string) : datatype option =
  match Map.find var ctxt with
  | ty -> Some ty
  | exception Not_found -> None

(* Returns a context that agrees with [ctxt] except that [var] is bound to
   [ty]; an earlier binding of [var], if any, is replaced in the result. *)
let update_ctxt ctxt (var:string) (ty:datatype) =
  ctxt |> Map.add var ty


(* Returns the string representation of [ty], in OCaml-like type syntax.
   Sub-terms are parenthesised where needed so the output is unambiguous:
   the postfix [list] binds tightest, then [*], then [->], which is
   right-associative. For example, a list of pairs prints as (int*int) list
   and a function taking a function prints as (int->int)->int.
   The placeholder type [Temp] is printed as the type variable 'a. *)
let rec string_of_ty (ty:datatype) : string =
  (* Prints [t], wrapped in parentheses when [needs_parens t] holds. *)
  let wrap_if needs_parens t =
    let s = string_of_ty t in
    if needs_parens t then "(" ^ s ^ ")" else s
  in
  (* Pairs and arrows bind looser than [list] and than [*]. *)
  let is_compound = function Pair_T _ | Rec_T _ -> true | _ -> false in
  (* Only an arrow needs parentheses on the left of another arrow. *)
  let is_arrow = function Rec_T _ -> true | _ -> false in
  match ty with
  | Int_T -> "int"
  | Bool_T -> "bool"
  | Temp -> "'a"
  | List_T d -> wrap_if is_compound d ^ " list"
  | Pair_T (d1, d2) -> wrap_if is_compound d1 ^ "*" ^ wrap_if is_compound d2
  | Rec_T (d1, d2) -> wrap_if is_arrow d1 ^ "->" ^ string_of_ty d2

(* Reports a type error on expression [e]: [ty] is the type that was
   expected and [bad_ty] the type that was actually found. Both types and
   the expression are rendered as strings, and the report is delegated to
   [Error.type_error] together with [e]'s source span. *)
let handle_error (ty: datatype) (bad_ty: datatype) (e: span_exp) =
  let expected_str = string_of_ty ty in
  let found_str = string_of_ty bad_ty in
  let exp_str = e |> to_exp |> Printing.string_of_exp in
  Error.type_error expected_str found_str exp_str e.info

(* Returns [true] when the types [t1] and [t2] are compatible. Types are
   compared structurally, and the placeholder [Temp] is compatible with any
   type; this is what lets the empty list's type, List_T Temp, combine with
   any concrete list type. Every constructor is matched explicitly instead of
   relying on physical equality ([==]), which only works by accident for
   constant constructors. *)
let rec compare_type (t1 : datatype) (t2 : datatype) : bool =
  match t1, t2 with
  | Temp, _ | _, Temp -> true
  | Int_T, Int_T | Bool_T, Bool_T -> true
  | List_T a, List_T b -> compare_type a b
  | Pair_T (a1, b1), Pair_T (a2, b2)
  | Rec_T (a1, b1), Rec_T (a2, b2) ->
    compare_type a1 a2 && compare_type b1 b2
  | _ -> false

(* Checks that [op] typechecks when applied to [e1] and [e2], and returns
   the type of the whole operation. Both operands are checked with [loop]
   under [ctxt] before the operator is inspected ([check_loop] passes
   itself as [loop]).
   - [Plus], [Minus], [Times]: both operands must be [int]; result [int].
   - [Less], [LessEq]: both operands must be [int], or both [bool];
     result [bool].
   - Any other operator is reported with [Error.op_error] over the span
     covering both operands.
   When an operand has the wrong type, the error is reported on that
   operand. *)
let check_binary (op:Syntax.operator) (e1:span_exp) (e2:span_exp)
    loop ctxt : datatype =
  let ty1 = loop ctxt e1 in
  let ty2 = loop ctxt e2 in
  match op with
  | Plus | Minus | Times ->
    (match ty1, ty2 with
     | Int_T, Int_T -> Int_T
     | Int_T, _ -> handle_error Int_T ty2 e2
     | _ -> handle_error Int_T ty1 e1)
  | Less | LessEq ->
    (match ty1, ty2 with
     | Int_T, Int_T | Bool_T, Bool_T -> Bool_T
     (* A valid left operand fixes the type the right one must have, so
        the mismatch is blamed on the right operand. *)
     | Int_T, _ -> handle_error Int_T ty2 e2
     | Bool_T, _ -> handle_error Bool_T ty2 e2
     | _ -> handle_error Int_T ty1 e1)
  | _ -> Error.op_error (Span.extend e1.info e2.info)

(* Given two types already accepted by [compare_type], returns the more
   informative one: wherever one side is the placeholder [Temp] and the
   other is concrete, the concrete type is kept. This way [1 :: []] gets
   type int list rather than 'a list. *)
let rec most_specific (t1 : datatype) (t2 : datatype) : datatype =
  match t1, t2 with
  | Temp, t | t, Temp -> t
  | List_T a, List_T b -> List_T (most_specific a b)
  | Pair_T (a1, b1), Pair_T (a2, b2) ->
    Pair_T (most_specific a1 a2, most_specific b1 b2)
  | Rec_T (a1, b1), Rec_T (a2, b2) ->
    Rec_T (most_specific a1 a2, most_specific b1 b2)
  | _ -> t1

(* Recursively typechecks [e] under the context [ctxt] (variable name ->
   type) and returns its type. Mismatches are reported through
   [handle_error], [Error.unbound_error] or [check_binary]. [Temp] acts as
   a wildcard (see [compare_type]); whenever two compatible types are
   merged, the more specific one is kept (see [most_specific]). *)
let rec check_loop ctxt (e:span_exp) : datatype =
  match e.e with
  | Var x ->
    (match lookup_ctxt ctxt x with
     | None -> Error.unbound_error x e.info
     | Some ty -> ty)
  | Constant (Int _) -> Int_T
  | Constant (Bool _) -> Bool_T
  | Op (e1, op, e2) -> check_binary op e1 e2 check_loop ctxt
  | If (e1, e2, e3) ->
    let ty1 = check_loop ctxt e1 in
    let ty2 = check_loop ctxt e2 in
    let ty3 = check_loop ctxt e3 in
    (match ty1 with
     | Bool_T ->
       (* A branch typed with [Temp] (e.g. []) must not hide the other
          branch's concrete type. *)
       if compare_type ty2 ty3 then most_specific ty2 ty3
       else handle_error ty2 ty3 e3
     | _ -> handle_error Bool_T ty1 e1)
  | Let (x, e1, e2) ->
    let ty = check_loop ctxt e1 in
    check_loop (update_ctxt ctxt x ty) e2
  | Pair (e1, e2) ->
    let ty1 = check_loop ctxt e1 in
    let ty2 = check_loop ctxt e2 in
    Pair_T (ty1, ty2)
  | Fst e1 ->
    let ty = check_loop ctxt e1 in
    (match ty with
     | Pair_T (a, _) -> a
     | _ -> handle_error (Pair_T (Temp, Temp)) ty e1)
  | Snd e1 ->
    let ty = check_loop ctxt e1 in
    (match ty with
     | Pair_T (_, b) -> b
     | _ -> handle_error (Pair_T (Temp, Temp)) ty e1)
  | EmptyList -> List_T Temp
  | Cons (e1, e2) ->
    let ty = check_loop ctxt e1 in
    let ty2 = check_loop ctxt e2 in
    (match ty2 with
     | List_T elt ->
       (* Keep the concrete element type, so that a later head of a
          different type is rejected. *)
       if compare_type elt ty then List_T (most_specific elt ty)
       else handle_error elt ty e1
     | _ ->
       (* The tail was expected to be a list of the head's type. *)
       handle_error (List_T ty) ty2 e2)
  | Match (e1, e2, v1, v2, e3) ->
    let ty = check_loop ctxt e1 in
    let ty2 = check_loop ctxt e2 in
    (match ty with
     | List_T t ->
       (* In the second branch [v1] has the element type and [v2] the
          list type. *)
       let updated = update_ctxt (update_ctxt ctxt v1 t) v2 ty in
       let ty3 = check_loop updated e3 in
       if compare_type ty2 ty3 then most_specific ty2 ty3
       else handle_error ty2 ty3 e3
     | _ -> handle_error (List_T Temp) ty e1)
  | Rec (d1, v1, v2, e1) ->
    (match d1 with
     | Rec_T (d2, d3) ->
       (* [v1] is bound to the whole declared function type [d1]
          (presumably the recursive self-reference) and [v2] to the
          parameter type [d2]. *)
       let updated = update_ctxt (update_ctxt ctxt v1 d1) v2 d2 in
       let ty = check_loop updated e1 in
       (* The body must produce the declared return type [d3]. *)
       if compare_type ty d3 then d1 else handle_error d3 ty e1
     | _ -> handle_error (Rec_T (Temp, Temp)) d1 e)
  | App (fn_exp, arg_exp) ->
    let arg_ty = check_loop ctxt arg_exp in
    let fn_ty = check_loop ctxt fn_exp in
    (match fn_ty with
     | Rec_T (param_ty, ret_ty) ->
       if compare_type arg_ty param_ty then ret_ty
       else handle_error param_ty arg_ty arg_exp
     | _ -> handle_error (Rec_T (Temp, Temp)) fn_ty fn_exp)

(* Typechecks [e] starting from the empty context and returns unit; the
   inferred type is discarded, and any type error is reported from within
   [check_loop]. *)
let typecheck_exp (e:span_exp) =
  ignore (check_loop empty_ctxt e : datatype)

