(*-----------------------------------------------------------------------
** Copyright (C) 2002 - Verimag.
** This file may only be copied under the terms of the GNU Library General
** Public License
**-----------------------------------------------------------------------
**
** File: control.ml
** Author: jahier@imag.fr
*)

(* Maps keyed by control-variable names. [String.compare] gives the same
   ordering as polymorphic [compare] on strings, without the generic
   structural comparison. *)
module StringMap = Map.Make (String)


(** Contains all the necessary information to say whether a variable
  is recoverable or not.

  More precisely, if a control variable [v] comes from a draw between
  2 values [min] and [max], then we store [Some (min, max)] and apply
  to [min] and [max] the same computation as to [v] (e.g. [dec]
  decrements all three). A variable is recoverable if other values
  would have been possible; hence, we have:

    [v] is recoverable <=> [(v <= 0 && max >= 1) || (v > 0 && min <= 0)].

  [None] means that no interval is known (e.g. after [set]); such a
  variable is never recoverable.

  Normal (Gaussian) draws are not handled exactly: [draw_gauss] records
  the wide interval [m -. 100. *. std, m +. 100. *. std] as an
  approximation.

  NOTE(review): a variant [Inter of int * int | Never | Always] was
  sketched here. Every function in this module still builds and matches
  [None] / [Some (min, max)], so that variant did not type-check against
  them. If the variant is wanted, migrate all users together.
*)
type recovery_info = (int * int) option

(** [string_of_recov r] renders a recovery interval as [" [min, max]"]
    (with a leading space), or as the empty string when there is none. *)
let (string_of_recov : recovery_info -> string) = function
  | None -> ""
  | Some (lo, hi) -> Printf.sprintf " [%d, %d]" lo hi

(* exported *)
(** A control state: each control-variable name is mapped to its current
    integer value together with its recovery information. *)
type state = (int * recovery_info) StringMap.t

(* exported *)
(** [compose_state st1 st2] merges two states. When a variable is bound in
    both, the binding from [st1] wins, because each binding of [st1] is
    added on top of [st2]. *)
let (compose_state : state -> state -> state) =
  fun st1 st2 -> StringMap.fold StringMap.add st1 st2


(* fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b *)

(** [state_size st] is the number of control variables bound in [st]. *)
let (state_size : state -> int) =
  fun st ->
    let count _id _binding n = n + 1 in
    StringMap.fold count st 0

(* exported *)
(** [print_state st] writes a ["Control state: "] header to stderr. It then
    writes one line per binding: a tab, [id = value], and the recovery
    interval in parentheses (empty when there is none). Finally it flushes
    stderr. *)
let (print_state : state -> unit) =
  fun st ->
    output_string stderr "Control state: \n";
    StringMap.iter
      (fun id (v, recov) ->
        Printf.fprintf stderr "\t%s = %d (%s)\n" id v (string_of_recov recov))
      st;
    flush stderr

(* exported *)
(** [new_state ()] is a state with no control variable bound. *)
let (new_state : unit -> state) =
  fun () -> StringMap.empty

(* exported *)
(** [set id i st] binds control variable [id] to [i] in [st] with no
    recovery interval, so [is_recoverable] reports it as not recoverable. *)
let (set : string -> int -> state -> state) =
  fun id i st ->
    let binding = (i, None) in
    StringMap.add id binding st

(* exported *)
(** [set_between id i min max st] binds [id] to [i] in [st] and records
    [(min, max)] as its recovery interval. No check is made that [i] lies
    within that interval. *)
let (set_between : string -> int -> int -> int -> state -> state) =
  fun id i min max st ->
    let interval = Some (min, max) in
    StringMap.add id (i, interval) st


(* exported *)
(** [draw_between id min max st] binds [id] to an integer drawn with
    [Random.int] from [min..max] (both inclusive). It records [(min, max)]
    as the recovery interval.
    @raise Assert_failure if [max < min] (when assertions are enabled). *)
let (draw_between : string -> int -> int -> state -> state) =
  fun id min max st ->
    let width = max - min + 1 in
    assert (width > 0);
    let value = min + Random.int width in
    StringMap.add id (value, Some (min, max)) st

(* exported *)
(** [draw_gauss id m std st] binds [id] to the float returned by
    [Util.gauss_draw m std] (presumably a normal sample with mean [m] and
    standard deviation [std]; confirm in Util). The value is rounded to the
    nearest integer, with halves rounding up.

    The recovery interval is [m -. 100. *. std, m +. 100. *. std], truncated
    to integers. This is a wide window approximating an unbounded
    distribution. *)
let (draw_gauss : string -> float -> float -> state -> state) =
  fun id m std st ->
    let f = Util.gauss_draw m std in
    let max_recovering = int_of_float (m +. 100. *. std) in
    let min_recovering = int_of_float (m -. 100. *. std) in
    (* Round to nearest. A test such as [f -. ceil f < 0.5] is always true
       (since f <= ceil f), which silently truncates toward zero instead;
       [floor (f +. 0.5)] is correct for negative draws too. *)
    let draw = int_of_float (floor (f +. 0.5)) in
    StringMap.add id (draw, Some (min_recovering, max_recovering)) st

(* exported *)
(** [dec id st] decrements control variable [id] by one. If [id] has a
    recovery interval, both bounds are shifted down by one as well, so the
    interval keeps following the same computation as the value.
    @raise Assert_failure if [id] is unbound (when assertions are enabled). *)
let (dec : string -> state -> state) =
  fun id st ->
    assert (StringMap.mem id st);
    let (v, recov) = StringMap.find id st in
    let shifted =
      match recov with
      | None -> None
      | Some (lo, hi) -> Some (lo - 1, hi - 1)
    in
    StringMap.add id (v - 1, shifted) st

(* exported *)
(** [return id st] is the current integer value of control variable [id]
    in [st].
    @raise Failure if [id] is not bound in [st]. The whole state is printed
    to stderr first to help diagnose the missing variable. *)
let (return : string -> state -> int) =
  fun id st ->
    (* Only a missing key is expected here. Catching everything would also
       turn Out_of_memory, Stack_overflow or Sys.Break into this message. *)
    try fst (StringMap.find id st)
    with Not_found ->
      print_state st;
      failwith ("*** " ^ id ^ " is an undefined control expression.\n")
(* exported *)
(** [return_comp id st] is [0] when the value of [id] in [st] is strictly
    positive, and [1] otherwise.
    @raise Not_found if [id] is not bound in [st]. *)
let (return_comp : string -> state -> int) =
  fun id st ->
    let (v, _) = StringMap.find id st in
    if v <= 0 then 1 else 0


(* exported *)
(** [is_recoverable id st] tells whether another outcome of the draw that
    produced [id] could have put its value on the other side of zero.
    With value [v] and interval [(min, max)], it holds iff
    [(v <= 0 && max >= 1) || (v > 0 && min <= 0)]. Variables without an
    interval ([None]) are never recoverable.
    @raise Not_found if [id] is not bound in [st]. *)
let (is_recoverable : string -> state -> bool) =
  fun id st ->
    (* A single lookup is enough; the previous unused [let (v, r) = ...]
       duplicated it and triggered unused-variable warnings. *)
    match StringMap.find id st with
    | (_, None) -> false
    | (v, Some (min, max)) -> (v <= 0 && max >= 1) || (v > 0 && min <= 0)



(** A control expression: a state transformer. Expressions are sequenced
    with [compose_expr] and branched on with [ite]. *)
type expr = state -> state

(** [compose_expr e1 e2] runs [e1] on the state, then runs [e2] on the
    result. *)
let (compose_expr : expr -> expr -> expr) =
  fun e1 e2 st -> e2 (e1 st)


(** Integer comparisons evaluated by [ite]: [EqExpr (a, b)] is [a = b],
    [GtExpr (a, b)] is [a > b], and [GeExpr (a, b)] is [a >= b]. *)
type test =
  | EqExpr of number * number
  | GtExpr of number * number
  | GeExpr of number * number
(** An operand: either a control-variable name, whose value is read from
    the current state via [return], or an integer literal. *)
and number =
  | VarExpr of string
  | IntExpr of int

(** [ite test e1 e2 st] evaluates [test] against [st], then applies [e1]
    to [st] if the test holds and [e2] otherwise.
    @raise Failure (via [return]) if a [VarExpr] names an unbound
    variable. *)
let (ite : test -> expr -> expr -> state -> state) =
  fun test e1 e2 st ->
    let eval = function
      | VarExpr id -> return id st
      | IntExpr n -> n
    in
    let holds =
      match test with
      | EqExpr (a, b) -> eval a = eval b
      | GtExpr (a, b) -> eval a > eval b
      | GeExpr (a, b) -> eval a >= eval b
    in
    (if holds then e1 else e2) st
