(*-----------------------------------------------------------------------
** Copyright (C) - Verimag.
** This file may only be copied under the terms of the GNU Library General
** Public License
**-----------------------------------------------------------------------
**
** File: oracle.ml
** Author: jahier@imag.fr
*)

open Var
open Value
open Ocaml2c


(* NOTE(review): plain alias for textual type names; presumably the same
   strings ("bool", "int", "float") that appear as the second component of
   the (name, type) pairs handled below. Not referenced in the visible part
   of this file. *)
type var_type = string

(* ------------------------------------------------------------------------ *)
(* XXX duplicated from sut.ml, but I don't want to pollute the Sut interface
   as it is used by end-users.
*)

(* Turns an array of Ocaml2c name/type records into the list of
   (variable name, variable type) pairs, keeping the array order. *)
let (struct_array_to_pair_list : Ocaml2c.vn_t array -> (string * string) list) =
  fun a ->
    let pairs = Array.map (fun entry -> (entry.var_name, entry.var_type)) a in
    Array.to_list pairs


(*********************************************************************)

(* exported *)
(* Thin wrapper around [Ocaml2c.lurette__oracle_init]
   (presumably initialises the C-side oracle). *)
let (init : unit -> unit) =
  fun () -> Ocaml2c.lurette__oracle_init ()

(* exported *)
(* Returns the oracle's input variables as (name, type) pairs, in the order
   of the array returned by the Ocaml2c binding. The number of inputs is
   queried first and then used to fetch the name/type array. *)
let (get_input_var_name_and_type : unit ->  (string * string) list) =
  fun () ->
    let arg_nb = Ocaml2c.lurette__oracle_input_arg_nb () in
    let vn_array =
      Ocaml2c.lurette__oracle_input_var_name_and_type_array arg_nb
    in
    struct_array_to_pair_list vn_array

(* exported *)
(* Returns the oracle's output variables as (name, type) pairs, in the order
   of the array returned by the Ocaml2c binding. The number of outputs is
   queried first and then used to fetch the name/type array. *)
let (get_output_var_name_and_type : unit ->  (string * string) list) =
  fun () ->
    let arg_nb = Ocaml2c.lurette__oracle_output_arg_nb () in
    let vn_array =
      Ocaml2c.lurette__oracle_output_var_name_and_type_array arg_nb
    in
    struct_array_to_pair_list vn_array


(*********************************************************************)

(* (name, type) lists of the oracle's variables, computed once when this
   module is initialised. The calls go through the C bindings, so keep them
   in this order.
   NOTE(review): oracle_i_vntl is not referenced in the visible part of this
   file. *)
let oracle_i_vntl = get_input_var_name_and_type ()
let oracle_o_vntl = get_output_var_name_and_type ()

(* (name, type) lists of the SUT's variables, also computed at module
   initialisation. set_oracle_input walks these two lists to decide which
   values are fed to the oracle, and at which index. *)
let sut_i_vntl = Sut.get_input_var_name_and_type ()
let sut_o_vntl = Sut.get_output_var_name_and_type ()

(* [set_oracle_input input output] feeds the current SUT values to the
   oracle through the Ocaml2c setter functions.

   Values are written at consecutive indices starting at 0:
   - first every variable of [sut_i_vntl], looked up in [input];
   - then every variable of [sut_o_vntl], looked up in [output].

   If a variable is missing from its environment, an internal error is
   printed on stderr and the process exits with code 13.

   NOTE(review): the lists used are the SUT's (sut_i_vntl / sut_o_vntl),
   not oracle_i_vntl. This presumably relies on the oracle's inputs being
   exactly the SUT inputs followed by the SUT outputs -- confirm. *)
let (set_oracle_input: env_out -> env_in -> unit) =
  fun input output ->
    (* Writes the value (looked up in [env]) of each variable of [vntl] at
       indices [start], [start+1], ...; returns the next free index.
       [what] only names the environment in the error message. *)
    let set_vals what env start vntl =
      List.fold_left
        (fun j (vn, _t) ->
           let value =
             try Value.OfIdent.get env vn
             with Not_found -> (
               Printf.fprintf stderr
                 "internal error in Oracle.set_oracle_input\n  while searching %s \"%s\" value\n"
                 what vn;
               exit 13
             )
           in
           (match value with
                B(b) -> Ocaml2c.lurette__oracle_set_val_bool j b
              | N(I(i)) -> Ocaml2c.lurette__oracle_set_val_int j (Util.int_of_num i)
              | N(F(f)) -> Ocaml2c.lurette__oracle_set_val_float j f
           );
           j + 1)
        start
        vntl
    in
    (* SUT outputs are placed right after the SUT inputs. *)
    let n = set_vals "input" input 0 sut_i_vntl in
    ignore (set_vals "output" output n sut_o_vntl : int)


(* [trie input output] runs one tentative oracle step:
   1. calls Ocaml2c.lurette__oracle_restore_state (presumably going back to
      the last saved oracle state, so trie can be retried);
   2. feeds the SUT inputs/outputs to the oracle with set_oracle_input;
   3. calls Ocaml2c.lurette__oracle_step;
   4. reads back every oracle output listed in oracle_o_vntl, at index
      0, 1, ... in list order.

   Returns the value of the first oracle output (index 0), which must be
   declared "bool", together with all oracle outputs bound by name.
   The oracle state is NOT saved here (step does that).

   @raise Failure if the first oracle output is not declared "bool" (or if
   there are no oracle outputs).
   An output declared with a type other than bool/int/float triggers
   [assert false]. *)
let (trie : env_out -> env_in -> bool * Value.OfIdent.t) =
  fun input output ->
    Ocaml2c.lurette__oracle_restore_state ();
    set_oracle_input input output;
    Ocaml2c.lurette__oracle_step ();

    (* Decodes oracle output number [j], declared with type name [t]. *)
    let read_output j t =
      match t with
          "bool" -> B(Ocaml2c.lurette__oracle_get_val_bool j)
        | "int" -> N(I(Num.num_of_int (Ocaml2c.lurette__oracle_get_val_int j)))
        | "float" -> N(F(Ocaml2c.lurette__oracle_get_val_float j))
        | _ -> assert false
    in
    (* The output index is threaded through the accumulator instead of a
       mutable counter. *)
    let (_, outs) =
      List.fold_left
        (fun (j, acc) (vn, t) ->
           let value = read_output j t in
           (j + 1, Value.OfIdent.add acc (vn, value)))
        (0, Value.OfIdent.empty)
        oracle_o_vntl
    in
    match oracle_o_vntl with
        (_, "bool") :: _ ->
          let res = Ocaml2c.lurette__oracle_get_val_bool 0 in
          (res, outs)
      | _ -> failwith "*** The oracle first output ought to be a Boolean.\n"
	
	
(* [step input output] does what [trie input output] does, then calls
   Ocaml2c.lurette__oracle_save_state (presumably committing the oracle
   state reached). If trie raises, the state is not saved.
   Returns trie's result unchanged. *)
let step (input : env_out) (output : env_in) : bool * Value.OfIdent.t =
  let verdict_and_outputs = trie input output in
  Ocaml2c.lurette__oracle_save_state ();
  verdict_and_outputs



