(*-----------------------------------------------------------------------
** Copyright (C) - Verimag.
** This file may only be copied under the terms of the GNU Library General
** Public License
**-----------------------------------------------------------------------
**
** File: lurette_ocaml.ml
** Main author: jahier@imag.fr
*)

(** Main module of the Lurette variant that tests OCaml programs. *)

(*------------------------------------------------------------------------*)

open Util
open Exp
open Var
open List
open Command_line
open Value
open Prog

(*------------------------------------------------------------------------*)



(* Default values of the command-line options.  The fields are mutated in
   place by get_lurette_options while the command line is scanned. *)
let (options:Command_line.optionsT) = {
  (* None: no interactive prompt; Some i: step-by-step mode (see main_loop) *)
  step_by_step = None ;
  (* forwarded to the Sim2chro output functions *)
  display_local_var = false ;
  (* whether Sim2chro.call_sim2chro is invoked on the output file *)
  display_sim2chro = true ;
  (* None: the seed is drawn by random_seed; Some s: s is used as the seed *)
  user_seed = None ;
  (* copied into the verbose field of the initial state by main2 *)
  verb = false ;
  (* when set, main_loop prints a header at each step *)
  show_step = false ;
  (* when set, main2 skips the main loop *)
  help = false ;
  (* name of the file opened for writing by main2 *)
  output = "lurette.rif" ;
  (* forwarded to Lucky.env_try *)
  draw_all_formula = false;
  (* selects Thickness.All instead of Thickness.AtMost k3 in main_loop *)
  draw_all_vertices = false;
  (* selects Solver.set_fair_mode instead of Solver.set_efficient_mode *)
  compute_volume = false;
  (* forwarded to Lucky.env_step *)
  step_mode = Lucky.StepInside;
  (* when false, check_oracle always answers true *)
  oracle = true ;
  (* forwarded to LucProg.make_state *)
  pp = None;
  (* reading from stdin is reported as not implemented by test_manager *)
  stdin = false
}

(*------------------------------------------------------------------------*)

(** [print_vn_str oc vl] writes every (name, type) pair of [vl] on [oc],
  in list order.  Each pair is rendered as a newline, a tab, the quoted
  name, a tab, the words "of type", the type, and a trailing comma.
  The channel is not flushed. *)
let (print_vn_str : out_channel -> (string * string) list -> unit) =
  fun oc vl ->
    (* Textual form of one declaration *)
    let describe (name, typ) =
      "\n\t\"" ^ name ^ "\"\t of type " ^ typ ^ ","
    in
    let rec emit = function
      | [] -> ()
      | decl :: rest ->
          output_string oc (describe decl);
          emit rest
    in
      emit vl

(** [check_var_decl_consistency state in_sut out_sut in_oracle out_oracle]
  checks the consistency of the variable declarations made in the
  environment (read from [state]), in the Sut and in the oracle.

  Three comparisons are made with [list_are_equals], in this order:
  - the environment outputs against the Sut inputs [in_sut];
  - the environment inputs against the Sut outputs [out_sut];
  - the Sut inputs and outputs (sorted by name) against the oracle
    inputs [in_oracle].

  On the first comparison that fails, the differences are printed on
  stdout and the program stops with exit code 2; no exception is raised.
  When all comparisons succeed, nothing is printed.

  [out_oracle] is accepted for interface compatibility but is not
  checked here.
*)
let (check_var_decl_consistency : Prog.state -> (Var.name * string) list ->
       (Var.name * string) list -> (Var.name * string) list ->
	 (Var.name * string) list -> unit) =
  fun state in_sut out_sut in_oracle _out_oracle ->

    (* Retrieving the environment variables from the state, sorted by
       name and turned into (name, type-as-string) pairs *)
    let sorted_decls vars =
      List.map
        (fun v -> ((Var.name v), (Type.to_string (Var.typ v))))
        (List.sort (fun v1 v2 -> compare (Var.name v1) (Var.name v2)) vars)
    in
    let in_env = sorted_decls state.s.in_vars in
    let out_env = sorted_decls state.s.out_vars in

    (* In order to check the oracle variable declarations *)
    let all_sut_var =
      List.sort (fun v1 v2 -> compare (fst v1) (fst v2)) (in_sut @ out_sut)
    in

    (* Prints [header], then each non-empty one-sided difference between
       [l1] and [l2] under its own title, and stops with exit code 2. *)
    let report_mismatch_and_exit l1 l2 header title1 title2 =
      let diff1 = diff_list_as_set l1 l2 in
      let diff2 = diff_list_as_set l2 l1 in
        output_string stdout header;
        if diff1 <> [] then
          (
            print_string title1;
            print_vn_str stdout diff1
          );
        if diff2 <> [] then
          (
            print_string title2;
            print_vn_str stdout diff2
          );
        flush stdout ;
        exit 2
    in

      if not (list_are_equals out_env in_sut) then
        report_mismatch_and_exit out_env in_sut
          "\n*** env outputs and sut inputs should be the same.\n"
          "\n\nAppears in the sut inputs, but not in the env outputs:"
          "\n\nAppears in the env outputs, but not in the sut inputs:"
      else if not (list_are_equals in_env out_sut) then
        report_mismatch_and_exit in_env out_sut
          "\n*** env inputs and sut outputs should be the same.\n"
          "\n\nAppears in the sut outputs, but not in the env inputs:"
          "\n\nAppears in the env inputs, but not in the sut outputs:"
      else if not (list_are_equals all_sut_var in_oracle) then
        report_mismatch_and_exit all_sut_var in_oracle
          "\n*** sut inputs and outputs should be the same as oracle inputs.\n"
          "\n\nAppears in the oracle input, but not in the sut inputs and outputs:"
          "\n\nAppears in the sut inputs and outputs, but not in the oracle inputs:"
      else
        ()
	

(*------------------------------------------------------------------------*)
	
(** [struct_array_to_pair_list a] turns an array of name/type records
  into the list of their [(var_name, var_type)] pairs, keeping the
  order of the array. *)
let (struct_array_to_pair_list : Ocaml2c.vn_t array -> (string * string) list) =
  fun a ->
    let as_pair entry = (entry.var_name, entry.var_type) in
      Array.to_list (Array.map as_pair a)


(* Numbers of input and output variables of the Sut and of the oracle,
   as returned by the Ocaml2c functions.  These calls are made once,
   when this module is initialised. *)

let sut_input_arg_nb = Ocaml2c.lurette__sut_input_arg_nb ()
let sut_output_arg_nb = Ocaml2c.lurette__sut_output_arg_nb ()
let oracle_input_arg_nb = Ocaml2c.lurette__oracle_input_arg_nb ()
let oracle_output_arg_nb = Ocaml2c.lurette__oracle_output_arg_nb ()


(* Lists of (variable name, variable type) pairs for the Sut and the
   oracle, built from the arrays returned by Ocaml2c. *)
let sut_i_vntl =
  struct_array_to_pair_list
    (Ocaml2c.lurette__sut_input_var_name_and_type_array sut_input_arg_nb)

let sut_o_vntl =
  struct_array_to_pair_list
    (Ocaml2c.lurette__sut_output_var_name_and_type_array sut_output_arg_nb)

let oracle_i_vntl =
  struct_array_to_pair_list
    (Ocaml2c.lurette__oracle_input_var_name_and_type_array oracle_input_arg_nb)

let oracle_o_vntl =
  struct_array_to_pair_list
    (Ocaml2c.lurette__oracle_output_var_name_and_type_array oracle_output_arg_nb)

(* Number of words on the command line, program name included *)
let arg_nb = (Array.length Sys.argv)

(* Names of the Sut input and output variables, i.e. the first
   components of the pairs above (assuming split is List.split, brought
   in scope by the open List at the top of the file -- TODO confirm). *)
let input_list_ref = fst (split sut_i_vntl)
let output_list_ref = fst (split sut_o_vntl)


(* Home-made seed generator: the random engine is self-initialised, then
   an integer below 1073741823 is drawn from it.  main2 uses that integer
   as an explicit seed, so that the seed of a run is known and can be
   printed. *)
let random_seed () =
  Random.self_init ();
  Random.int 1073741823


(********************************************************************************)


(** [print_failure i o l t rif] reports one tuple rejected by the oracle
  and stops the program.

  [i], [o] and [l] are printed on stdout under the titles sut inputs,
  sut outputs and env locals.  An oracle_failure marker followed by the
  values of step [t] is then written on [rif], and the program exits
  with code 1.  This function never returns. *)
let print_failure i o l t rif =
  let say = output_string stdout in
    say "\n* sut inputs:\n\t" ;
    print_subst_list i stdout;
    say "\n* sut outputs:\n\t" ;
    print_env_in o stdout;
    say "\n* env locals:\n\t" ;
    print_subst_list l stdout;

    output_string rif "\n#oracle_failure at\n";
    (* HERE: sanity check kept from the original code *)
    assert (o <> Value.OfIdent.empty);
    Sim2chro.put_current_step_values
      rif t i o l
      options.display_local_var
      sut_o_vntl sut_i_vntl ;
    exit 1

(** [check_oracle inputs sut_outputs locals memory t rif state] submits
  every pair taken from [inputs] and [sut_outputs] to [Oracle.trie] and
  answers [true] when no pair is rejected.  It also answers [true],
  without calling the oracle, when the oracle is disabled in [options].

  When a pair is rejected, the first rejected tuple is reported through
  [print_failure].

  NOTE(review): [print_failure] ends with [exit 1], so the statements
  that follow the call to [print_failures] (memory dump, sim2chro call,
  final message and the [false] result) look unreachable -- confirm. *)
let check_oracle inputs sut_outputs locals memory t rif state =
  if not options.oracle then true
  else
    (* One verdict per (input, sut_output) pair *)
    let (verdicts: bool list) =
      List.map2 (fun x y -> Oracle.trie x y) inputs sut_outputs
    in
      if not (List.mem false verdicts) then true
      else
        begin
          output_string stdout (
            "\n*** The oracle returned false" ^
            " at step "  ^ (string_of_int t) ^
            " with the following values:\n  ");

          (* Walks the four lists in lockstep up to the first rejected
             tuple; running out of elements fails like List.hd does. *)
          let rec print_failures li lo ll lr =
            match li, lo, ll, lr with
            | i :: li_rest, o :: lo_rest, l :: ll_rest, r :: lr_rest ->
                if r then print_failures li_rest lo_rest ll_rest lr_rest
                else print_failure i o l t rif
            | _ -> failwith "hd"
          in
            print_failures inputs sut_outputs locals verdicts ;

            output_string stdout "\n* pre:\n\t" ;
            print_subst_list memory stdout;
            flush stdout;
            flush rif;
            if options.display_sim2chro
            then Sim2chro.call_sim2chro state options.output
            else () ;
            output_string stdout (
              "\n*** Lurette stops because the oracle returned "
              ^ "false at step " ^ (string_of_int t) ^ ".\n");
            false
        end

(**************************************************************************)


(* Identity: hands its argument back unchanged.  main_loop applies it to
   the control-state nodes before printing them. *)
let string_of_node node = node

(* Running sum of the number of Sut trials made at each step; main_loop
   reports an average computed from it when the test ends. *)
let l_average = ref 0.

(* Number of steps counted since the last prompt in step-by-step mode *)
let step_cpt = ref 0

(** [test_manager ()] is the entry point of a test session.

  The command line is echoed on stderr, then checked:
  - with fewer than 7 words (program name included), the usage message
    is printed; the exit code is 0 when the first argument is a help
    flag (--help, -help or -h) and 1 otherwise;
  - otherwise arguments 1 to 5 are read as the integers [s], [p], [k1],
    [k2] and [k3] and handed to [main2], which scans the remaining
    arguments.

  After [main2] returns, sim2chro is called on the output file unless
  step-by-step mode is on or sim2chro display is off.

  Any exception is reported on stdout and ends the program with exit
  code 2; for [Failure] only the message is printed. *)
let rec (test_manager : unit -> unit) =
  fun _ ->
    (* Echo the command line on stderr *)
    Array.iter (fun x -> output_string stderr (x ^ " ")) Sys.argv;
    output_string stderr "\n";
    flush stderr;
    try
      if arg_nb < 7 then
        begin
          let help_requested =
            arg_nb >= 2 && ((    (Sys.argv.(1) = "--help")
                              || (Sys.argv.(1) = "-help")
                              || (Sys.argv.(1) = "-h")   ))
          in
            if help_requested then
              (
                output_string stdout usage ;
                exit 0
              )
            else
              (
                output_string stderr usage ;
                exit 1
              )
        end
      else
        (* Reads argument number [pos] as an integer; [ordinal] is the
           rank of that argument as spelled in the error message. *)
        let int_arg pos ordinal =
          cmd_line_string_to_int Sys.argv.(pos)
            ("*** int expected as " ^ ordinal ^ " argument. " ^
             Sys.argv.(pos) ^ " is not an int.")
        in
        let s = int_arg 1 "first" in
        let p = int_arg 2 "second" in
        let k1 = int_arg 3 "third" in
        let k2 = int_arg 4 "fourth" in
        let k3 = int_arg 5 "fifth" in
          (* NOTE(review): within this file the options are only updated
             while main2 scans the command line, i.e. after this test;
             this branch seems to depend on options.stdin being set
             elsewhere -- confirm. *)
          if options.stdin then
            (
              print_string "\nNot yet implemented, sorry\n";
              flush stdout;
              exit 0
            )
          else
            let state = main2 s p k1 k2 k3 in
              if (options.step_by_step = None && options.display_sim2chro)
              then Sim2chro.call_sim2chro state options.output
    with
        Failure(errmsg) ->
          print_string errmsg;
          flush stdout;
          exit 2
      | e ->
          print_string (Printexc.to_string e);
          flush stdout;
          exit 2

and
  (** [get_lurette_options n] looks at the command-line word number [n].

    When that word is a known option (a key of
    [Command_line.string_to_option]), [options] is updated accordingly
    and the index of the next word to process is returned: [n+1] for a
    flag, [n+2] for an option that takes a value.  When the word is not
    an option, [n] is returned and nothing is changed.

    Fails with [Failure] when an option that takes a value is the last
    word of the command line.
    *)
  (get_lurette_options: int -> int) =
  fun n ->
    (* Word following the option.  The explicit bound check replaces the
       out-of-bounds array access that a trailing option used to cause. *)
    let option_arg () =
      if n + 1 < arg_nb then Sys.argv.(n+1)
      else
        failwith
          ("*** Error when calling lurette: an argument is expected " ^
           "after the option " ^ Sys.argv.(n) ^ "\n")
    in
    try
      begin
        let opt = List.assoc Sys.argv.(n) Command_line.string_to_option in
          match opt with
              Step ->
                let str = option_arg () in
                  options.step_by_step <- Some (cmd_line_string_to_int str
                    ("*** Error when calling lurette: an " ^
                     "integer is expected after the " ^
                     "option -step\n")) ;
                  n+2
            | NoStep -> options.step_by_step <- None ; (n+1)

            | DisplayLocalVar   -> options.display_local_var <- true ; (n+1)
            | NoDisplayLocalVar -> options.display_local_var <- false ; (n+1)

            | Sim2chro  -> options.display_sim2chro <- true ; (n+1)
            | NoSim2chro -> options.display_sim2chro <- false ; (n+1)

            | StepInside -> options.step_mode <- Lucky.StepInside ; (n+1)
            | StepEdges -> options.step_mode <- Lucky.StepEdges ; (n+1)
            | StepVertices -> options.step_mode <- Lucky.StepVertices ; (n+1)

            | Stdin -> options.stdin <- true ; (n+1)

            | NoOracle -> options.oracle <- false ; (n+1)
            | Verbose  -> options.verb <- true; (n+1)
            | ShowStep  -> options.show_step <- true ; (n+1)
            | AllFormula -> options.draw_all_formula <- true ; (n+1)
            | AllVertices -> options.draw_all_vertices <- true ; (n+1)
            | Help   -> options.help <- true ; (n+1)
            | Output ->
                let str = option_arg () in
                  options.output <- str ;
                  n+2

            | Seed ->
                let str = option_arg () in
                  options.user_seed <- Some (cmd_line_string_to_int str
                    ("*** Error when calling lurette: an " ^
                     "integer is expected after the " ^
                     "option --with-seed\n")) ;
                  n+2
            | Precision ->
                let str = option_arg () in
                  Util.precision := (cmd_line_string_to_int str
                    ("*** Error when calling lurette: an " ^
                     "integer is expected after the " ^
                     "option --precision\n")) ;
                  Util.update_eps ();
                  n+2

            | PP  ->
                let pp = option_arg () in
                  (* an empty value leaves the current setting untouched *)
                  if pp <> "" then options.pp <- Some pp ;
                  (n+2)

            | ComputeVolume ->
                options.compute_volume <- true ; (n+1)
      end
    with Not_found -> n
and
  (** [get_env_from_args n file_l] scans the command line from word
    number [n] up to its end and returns the environment file names it
    finds, pushed in front of [file_l] (so the last file of the command
    line comes first).

    The lurette options met on the way are applied through
    [get_lurette_options].  A word equal to x is accepted for backward
    compatibility: it is dropped and the word after it is taken as a
    file name.
    *)
  (get_env_from_args: int -> string list -> string list) =
  fun n file_l ->
    if n <> arg_nb then
      let m = get_lurette_options n in
        (* m > n iff Sys.argv.(n) is an option *)
        if m <= n then
          let arg = Sys.argv.(m) in
            if arg <> "x" then
              get_env_from_args (n+1) (arg::file_l)
            else
              (* ignore x for backward compatibility *)
              let env = Sys.argv.(n+1) in
                get_env_from_args (n+2) (env::file_l)
        else
          get_env_from_args m file_l
    else
      file_l

and
  (* [main2 s p k1 k2 k3] sets a test session up, runs it and returns the
     final state.

     In order: the remaining command-line words are scanned (environment
     files and options), the initial state is built, the output file is
     opened, the variable declarations are checked, the random engine is
     seeded, the Sut and the oracle are initialised, the variable
     declarations are written out, the solver is configured, and
     main_loop is run from step 1 (unless the help option is set).  The
     output file is closed before returning.

     [s], [p], [k1], [k2] and [k3] are only forwarded to main_loop. *)
  (main2 : int -> int -> int -> int -> int -> Prog.state) =
  fun s p k1 k2 k3 ->
      (* Clean up tables as non-reg assert stuff might have filled them *)
    let _ =
      Formula_to_bdd.clear_all ()
    in
    (* Words 1 to 5 are the integer arguments; scanning starts at word 6.
       This call also updates [options]. *)
    let env_list = (get_env_from_args 6 []) in
      (* XXX LUTIN *)
    let state0 = LucProg.make_state options.pp env_list in

    (* Same dynamic state as state0, except for the verbose flag which is
       taken from the options *)
    let init_state_dyn =
      {
	memory = state0.d.memory;
	ctrl_state = state0.d.ctrl_state;
	input = state0.d.input;
	verbose = options.verb
      }
    in
    let init_state = {
      d = init_state_dyn ;
      s = state0.s
    }
    in

    (* Output file of the session *)
    let rif = open_out options.output in

    let local_var_name_and_type_list_unsorted0 =
      (* remove aliases *)
      fst (List.partition
	     (fun v -> (Var.alias v) = None) init_state.s.loc_vars)
    in
    let local_var_name_and_type_list_unsorted =
      List.map (fun v -> ((Var.name v), (Type.to_string (Var.typ v))))
	local_var_name_and_type_list_unsorted0
    in
    let local_var_name_and_type_list =
      Util.sort_list_string_pair local_var_name_and_type_list_unsorted
    in

    (* Initialisation of the random engine: the seed given by the user
       if any, a freshly drawn one otherwise *)
    let seed =
      match (options.user_seed) with
	  None ->  (random_seed ())
	| Some seed -> seed
    in

      (* Stops the program with exit code 2 on inconsistent declarations *)
      check_var_decl_consistency
	init_state sut_i_vntl sut_o_vntl oracle_i_vntl oracle_o_vntl;

      Random.init seed ;
      output_string stdout
	("\nThe random engine was initialized with the seed " ^
	 (string_of_int seed) ^  "\n ");
      flush stdout ;

      (* Initialisation of the sut and the oracle *)
      Ocaml2c.lurette__sut_init ();
      Ocaml2c.lurette__oracle_init ();

      (* Sim2chro *)
      output_string rif ("# seed = " ^ (string_of_int seed) ^ "\n");
      (match options.step_by_step with
	 | Some i -> 
	     step_cpt := i; (* so that it stops at the first step *)
	     (* The environment graph is displayed only when luc_to_dot
	        returns 0 *)
	     if 
	       (Show_env.luc_to_dot init_state.d.ctrl_state []
		  ("environment" ^ (string_of_int (Hashtbl.hash Sys.argv)))
		  init_state.s.graph) = 0
	     then
	       Util.gv ("environment" ^ (string_of_int (Hashtbl.hash Sys.argv)) ^ ".ps");
	     
	     (* In step-by-step mode the declarations go to stderr, not to
	        the output file.  NOTE(review): hd and tl fail when
	        env_list is empty -- confirm this cannot happen. *)
	     (Sim2chro.put_var_decl
		("lurette chronogram (" ^
		 (fold_left
		    (fun acc str ->
		       (acc ^ " " ^ str)) (hd env_list) (tl env_list)) ^ " )"
		)
		sut_i_vntl
		sut_o_vntl
		local_var_name_and_type_list
		stderr options.display_local_var);
	     flush stderr
	       
	 | None -> 
	     (
	       (* Otherwise the declarations are written in the output
	          file *)
	       Sim2chro.put_var_decl
		("lurette chronogram (" ^
		 (fold_left (fun acc str -> (acc ^ " " ^ str)) "" env_list) ^ ")")
		sut_i_vntl
		sut_o_vntl
		local_var_name_and_type_list
		rif
		options.display_local_var
	     )
      );
      (* Initializing Dd's libs. *)
      
      (* selecting the draw mode *)
      if
	options.compute_volume
      then
	Solver.set_fair_mode ()
      else
	Solver.set_efficient_mode ();
      
      (* Initializing the solution number table *)
      !Solver.init_snt ();
      
      (* Runs the test from step 1 with an empty table as first input,
         unless the help option was given *)
      let
	final_state =
	if
	  not (options.help)
	then
	  main_loop 1 s p k1 k2 k3 rif (Hashtbl.create 0) init_state
	else
	  init_state
      in
	flush stdout;
	flush rif;
	close_out rif;
	final_state
and
  (* [main_loop t s p k1 k2 k3 rif input state] performs step number [t]
     of the test, then loops on step [t+1] until [t] reaches [s] or the
     user asks to stop in step-by-step mode.  It returns the state
     reached after the last step.

     [p], [k1], [k2] and [k3] parametrise the trial phase made through
     Lucky.env_try; [rif] is the output channel of the session; [input]
     holds the values fed to the environment for this step. *)
  main_loop t s p k1 k2 k3 rif input state =

  (* Optional traces: step header and current control nodes *)
  let _ =
    if options.show_step then
      output_string stdout ("\n--- step " ^ (string_of_int t) ^ ":\n");
    if state.d.verbose then
      List.iter
        (fun n ->
	   output_string stdout ("current nodes:" ^ (string_of_node n) ^ "\n "))
	(List.flatten state.d.ctrl_state);
    flush stdout
  in
  (* Trial phase, skipped when the three thickness parameters are zero *)
  let _ =
    if (k1 = 0 && k2 = 0 && k3 = 0) then () else
      (
	let num_thickness =
	  (k1, k2,
	   if options.draw_all_vertices then Thickness.All else Thickness.AtMost k3)
	in
	let (outputs_loc: (env_out * env_loc) list)
	  = Lucky.env_try ((options.draw_all_formula, p), num_thickness) input state
	in
	
	(* Extracts the outputs and locals from the couple *)
	let (outputs, locals) = List.split outputs_loc in
	
	(* Tries the sut on each of the environment outputs *)
	let (inputs: env_in list) =
	  List.map (Sut.trie) outputs in
	
	(* The program stops with exit code 2 when check_oracle answers
	   false *)
	let _ =
	  if
	    not (check_oracle outputs inputs locals state.d.memory t rif state)
	  then
	    exit 2;
	in
	(* Accumulates the number of trials for the final average *)
	let l = (List.length inputs) in
	  l_average := !l_average +. (float_of_int l)
      )
  in
    (* Performs the steps *)

  let (next_state, (output, loc)) = Lucky.env_step (options.step_mode) input state in
  let new_input = Sut.step output in

  (* The boolean result is dropped here *)
  let _ = check_oracle [output] [new_input] [loc] state.d.memory t rif next_state in

  (* Steps the oracle itself; print_failure ends the program *)
  let _ =
    if (options.oracle) then 
      let r = Oracle.step output new_input in
	if (not r) then
	  print_failure output new_input loc t rif
  in
 
  (* Writes the values of the step and computes the answer that decides
     whether to loop once more *)
  let str =
    match (options.step_by_step) with
	Some i -> 
	  (* skip is true while fewer than i steps were counted since the
	     last prompt *)
	  let skip = (i - !step_cpt > 0) in
	    if (not skip 
(* 		&& state.d.ctrl_state <> next_state.d.ctrl_state *)
	       ) then
	      (
		(* NOTE(review): the code returned by luc_to_dot is bound
		   but never used *)
		let err_code = (Show_env.luc_to_dot 
				  next_state.d.ctrl_state
				  state.d.ctrl_state
				  ("environment" ^ (string_of_int (Hashtbl.hash Sys.argv)))
				  state.s.graph
			       )
		in
		  ()
	      );
(* HERE *) assert (new_input <> Value.OfIdent.empty);
	    (* In step-by-step mode the values go to stdout *)
	    Sim2chro.put_current_step_values
	      stdout
	      t
	      output
	      new_input
	      loc
	      options.display_local_var
	      sut_o_vntl
	      sut_i_vntl ;
	    if 
	      skip 
	    then
	      (
		incr(step_cpt); 
		" "
	      )
	    else
	      (
		step_cpt := 1;
		output_string stdout
		  (*  ZZZ this string is matched in xlurette *)
		  "\nOne more loop ? [type 's' to stop, `CR' to continue, or an integer to change the number of steps to skip.]\n";
		let str = read_line () in
		  (* An answer that reads as an integer becomes the new
		     number of steps to skip; any other answer is returned
		     as it is *)
		  try 
		    let i = Util.my_int_of_string str in
		      options.step_by_step <- Some i;
		      str
		  with _ -> 
		    str
	      )
	      
      | None -> 
	  (
(* HERE *) assert (new_input <> Value.OfIdent.empty);
	    (* Otherwise the values go to the output file *)
	    Sim2chro.put_current_step_values
	     rif t output new_input loc
	     options.display_local_var
	     sut_o_vntl sut_i_vntl;
	    ""
	  )
  in

    flush rif;
    (* Decides whether to loop once more *)
    if
      ((str <> "s") && (s > t))
(*       ((str = "" || str = " ") && (s > t))   *)
    then
      main_loop (t+1) s p k1 k2 k3 rif new_input next_state
    else
      (
	if options.oracle then
	  output_string stdout
	    "\n ==> The test completed; no property has been violated.\n\n";
	(* Reported average: 1 plus the accumulated number of trials
	   divided by the number of the last step *)
	print_string (
	  "      The Test Thickness average was " ^
	  (string_of_float (1.0 +. !l_average /. (float_of_int t))) ^ "\n");
	flush stdout;
	next_state
      )



(* Entry point of the program: whatever its argument is, it is dropped
   and the test manager is run. *)
let lurette_main _ignored =
  test_manager ()

(* Registers the entry point under the name lurette_main, so that it can
   be looked up by that name from outside OCaml code. *)
let () =
  Callback.register "lurette_main" lurette_main
;;



(* To be able to use ocamldebug *)
(* let _ = lurette_main (); print_string "** WARNING : Debug mode\n";;  *)
