
open LutExe
open LtopArg

(* A list of variables, each given as a (name, type) pair. *)
type vars = (string * Data.t) list


(* [list_minus a b] returns a \ b: the elements of [a], in their original
   order and with their duplicates, that do not occur in [b]. *)
let list_minus a b =
  let absent_from_b x = not (List.mem x b) in
  List.filter absent_from_b a

(* [list_union a b] returns a U b: every element of [b] that is not already
   present is pushed in front of the accumulator (initially [a]), so the new
   elements appear before [a], in reverse order of first occurrence. *)
let list_union a b =
  let rec add acc = function
    | [] -> acc
    | x :: rest ->
      let acc = if List.mem x acc then acc else x :: acc in
      add acc rest
  in
  add a b

(* Draws a seed from a self-initialised generator. Relying on
   [Random.self_init] alone would not tell us which seed was used, and the
   seed has to be known so that it can be reported and replayed. *)
let random_seed () =
  Random.self_init ();
  Random.int 10000000

(* Global mutable state, kept at top level so that [clean_terminate] can
   dump the coverage information and kill gnuplot even if the execution is
   stopped by a ctrl-c. *)
(* Latest coverage state, updated by [start] when coverage is computed. *)
let cov_ref = ref None
(* Pid of the gnuplot process launched by [start], if any. *)
let gnuplot_pid_ref = ref None
(* Channel on which [start] sends replot commands to gnuplot, if any. *)
let gnuplot_oc = ref None

(* Checks that the variables exchanged by the environment, the SUT and the
   oracle fit together, and returns [(code, luciole_io)]:
   - if some SUT or environment inputs are produced by nobody, they are
     reported on stdout and become the luciole outputs; the luciole inputs
     are all the other env/SUT outputs: [(0, Some (ins, outs))];
   - else, if some oracle inputs are produced by nobody: [(2, None)];
   - else, if a luciole input is named Step: [(2, None)];
   - else [(0, None)], or, in luciole mode, [(0, Some (ins, [Step:bool]))].
   [oracle_out] is not used.
   cf lurette.set_luciole_mode_if_necessary to add a call to luciole. *)
let (check_compat : vars -> vars -> vars -> vars -> vars -> vars ->
                    int * (vars * vars) option) =
  fun env_in env_out sut_in sut_out oracle_in oracle_out ->
    let show_vars vs =
      String.concat ","
        (List.map (fun (name, ty) -> name ^ ":" ^ Data.type_to_string ty) vs)
    in
    (* Everything produced by the environment or the SUT. *)
    let produced = env_out @ sut_out in
    let missing_sut_in = list_minus sut_in env_out in
    let missing_env_in = list_minus env_in sut_out in
    let missing_oracle_in = list_minus oracle_in produced in
    let luciole_out = list_union missing_sut_in missing_env_in in
    let luciole_in = list_minus produced luciole_out in
    if missing_sut_in <> [] then
      Printf.printf "Some variables are missing in input of the SUT: %s\n"
        (show_vars missing_sut_in);
    if missing_env_in <> [] then
      Printf.printf "Some variables are missing in input of lutin: %s\n"
        (show_vars missing_env_in);
    match luciole_out, missing_oracle_in with
    | _ :: _, _ ->
      Printf.printf "try with luciole!\n";
      (0, Some (luciole_in, luciole_out))
    | [], _ :: _ ->
      Printf.printf "Some variables are missing in input of the oracle: %s\n"
        (show_vars missing_oracle_in);
      (2, None)
    | [], [] ->
      if List.mem "Step" (List.map fst luciole_in) then begin
        Printf.printf
          "*** You cannot use the name 'Step' for a variable with lurette, sorry.\n";
        flush stdout;
        (2, None)
      end else begin
        Printf.eprintf "RP Variables are compatible.\n";
        flush stderr;
        (0, if args.luciole_mode then Some (luciole_in, ["Step", Data.Bool]) else None)
      end

(* Aliases for rdbg events: [ctx] is the context handed to (and returned
   through the continuation of) the debug step functions. *)
type ctx = Event.t
type e = Event.t
open RdbgPlugin
(* Instantiates every reactive program of [rpl] as an rdbg plugin (the
   constructor used depends on the kind of program). It then returns the
   result of Util.list_split7 on the per-program 7-tuples
   (inputs, outputs, kill, step, step_dbg, init_inputs, init_outputs).
   When args.debug_ltop is set, each plain step function is wrapped so that
   every call logs the program, its inputs and its outputs on stderr. *)
let (make_rp_list : reactive_program list -> 
      vars list * vars list * (string -> unit) list * 
        (Data.subst list -> Data.subst list) list * 
        (Data.subst list -> ctx -> (Data.subst list -> ctx -> Event.t) ->
         Event.t) list * Data.subst list list * Data.subst list list) =
  fun rpl -> 
    (* Builds the rdbg plugin of a single reactive program. *)
    let plugin_of rp =
      match rp with
        | LustreV6(args) -> Lv6Run.make args
        | LustreV4(prog,node) -> LustreRun.make_v4 prog node
        | LustreEc(prog)      -> LustreRun.make_ec prog
        | LustreEcExe(prog)   -> LustreRun.make_ec_exe prog
        | Socket(addr, port)  -> LustreRun.make_socket addr port
        | SocketInit(addr, port) -> LustreRun.make_socket_init addr port 
        | Ocaml(cmxs) -> OcamlRun.make_ocaml cmxs
        | Lutin(args) -> LutinRun.make args
    in
    (* Wraps [step] (the plain step of [rp]) so that each call is traced on
       stderr. *)
    let trace_step rp step =
      let (string_of_subst : Data.subst -> string) =
        fun (str, v) -> str ^ "<-" ^ (Data.val_to_string Util.my_string_of_float v)
      in
      let sl2str sl = String.concat "," (List.map string_of_subst sl) in
      (fun sli ->
         let slo = step sli in
         Printf.eprintf "[%s] step(%s) = (%s) \n"
           (reactive_program_to_string rp) (sl2str sli) (sl2str slo);
         flush stderr;
         slo)
    in
    let aux rp =
      let plugin = plugin_of rp in
      let raw_step = plugin.step in
      let step = if args.debug_ltop then trace_step rp raw_step else raw_step in
      plugin.inputs, plugin.outputs, plugin.kill, step, plugin.step_dbg,
      plugin.init_inputs, plugin.init_outputs
    in
    Util.list_split7 (List.map aux rpl)
    

(* Coverage state of a run, as computed by [start]. *)
type cov_opt = 
    NO (* NoOracle: no oracle produces any output *) 
  | OO (* OracleOnly: oracles only produce their verdict; no coverage *) 
  | OC of Coverage.t (* coverage over the extra Bool outputs of the oracles *)
(* Raised when an oracle returns false and args.stop_on_oracle_error is set;
   carries the message that was printed. *)
exception OracleError of string
(* Signals that the SUT stopped, with the current coverage state.
   NOTE(review): not raised in the code visible here; only handled in
   [start]. *)
exception SutStop of cov_opt

(* Transform a map on a function list into CPS.
   [step_dbg_sl step_dbg_sl_l sl ctx cont] runs every debug step function
   of [step_dbg_sl_l], in order, on the same inputs [sl] and context [ctx],
   then calls [cont] with the concatenation of their outputs. The
   concatenation is in list order, as in the non-debug path
   [List.flatten (List.map (fun f -> f sl) ...)].
   (Translated from the original French note: this one is hairy; the
   difficulty is turning a [List.map step] into CPS.)
   NOTE(review): the context returned by each step through its
   continuation is ignored; every step receives the original [ctx]. *)
let (step_dbg_sl : 
  (Data.subst list -> ctx -> 
   (Data.subst list -> ctx -> Event.t) -> Event.t) list -> 
      's list -> 'ctx  -> ('s list -> 'e) -> 'e) = 
  fun step_dbg_sl_l sl ctx cont -> 
    (* [res_stepl] accumulates the per-step outputs in reverse order. *)
    let rec (iter_step  : 
               ('s list -> 'ctx  -> ('s list -> 'ctx  -> 'e) -> 'e) list -> 
              's list list -> 's list -> 'e) = 
      fun stepl res_stepl sl ->
        match stepl with
          | [] -> cont (List.flatten (List.rev res_stepl))
          | step::stepl ->
             step sl ctx (fun res_sl _ctx -> iter_step stepl (res_sl::res_stepl) sl)
    in
      iter_step step_dbg_sl_l [] sl 



(* Entry point of a lurette run.
   - Instantiates the SUTs, oracles and environments listed in [args].
   - Checks their interfaces with [check_compat]; when some inputs are
     produced by nobody, a luciole program is created to provide them.
   - Writes the RIF header to [args.output] and to the sim2chro channel.
   - Runs the simulation loop ([loop], then [loop2], then [loop3]):
     environment step, SUT step, then oracle step and RIF logging.
   In [args.ldbg] mode the loop is driven lazily: the returned event carries
   the rest of the run in its [Event.next] field. Otherwise the whole run
   executes inside this call. Inside the handler below, a code carried by
   [Event.End] is multiplied by 10, so a compatibility error (code 2)
   ends as 20. SUT stops end as 1; oracle errors, failures and other
   exceptions end as 2. *)
let (start : unit -> Event.t) =
  fun () ->
    (* Get sut info (var names, step func, etc.) *)
    let _ = if args.debug_ltop then LustreRun.debug := args.debug_ltop in
    let sut_in_l, sut_out_l, sut_kill_l, sut_step_sl_l, sut_step_dbg_sl_l, 
      sut_init_in_l, sut_init_out_l = make_rp_list args.suts 
    in
    (* Sends [msg] to the kill function of every SUT. *)
    let sut_kill msg = List.iter (fun f -> f msg) sut_kill_l in
    let sut_init_in = List.flatten sut_init_in_l in
    let sut_init_out = List.flatten sut_init_out_l in

    (* Get oracle info (var names, step func, etc.)*)
    (* NOTE(review): oracle_step_dbg_sl_l is not used below; oracles are
       always run through their plain step functions, even in ldbg mode. *)
    let oracle_in_l, oracle_out_l, oracle_kill_l, oracle_step_sl_l, 
        oracle_step_dbg_sl_l, _, _ =
      make_rp_list args.oracles
    in
    let oracle_kill msg = List.iter (fun f -> f msg) oracle_kill_l in
    
    (* Get env info (var names, step func, etc.)*)
    let env_in_l, env_out_l, env_kill_l, env_step_sl_l, env_step_dbg_sl_l, 
      env_init_in_l, env_init_out_l = make_rp_list args.envs 
    in
    let env_kill msg = List.iter (fun f -> f msg) env_kill_l in
    (* The environment initial values are computed but not used below. *)
    let _env_init_in = Util.rm_dup (List.flatten env_init_in_l) in
    let _env_init_out = Util.rm_dup (List.flatten env_init_out_l) in

    (* One tab-indented name:type line per variable (verbose output). *)
    let vars_to_string l = 
      String.concat "\n" (List.map (fun (vn,vt) -> 
        let vt = Data.type_to_string vt in
        Printf.sprintf "\t%s:%s" vn vt) l)
    in
    (* Interface of each kind of program: the variables of all programs of
       that kind, concatenated and passed through Util.rm_dup (presumably
       removing duplicates). *)
    let flat_sut_in  = Util.rm_dup (List.flatten sut_in_l)
    and flat_sut_out = Util.rm_dup (List.flatten sut_out_l)
    and flat_env_in  = Util.rm_dup (List.flatten env_in_l)
    and flat_env_out = Util.rm_dup (List.flatten env_out_l)
    and flat_oracle_in  = Util.rm_dup (List.flatten  oracle_in_l)
    and flat_oracle_out = Util.rm_dup (List.flatten  oracle_out_l)
    in
    let _ = if args.verbose > 0 then
        let sut_input_str = vars_to_string flat_sut_in in
        let sut_output_str = vars_to_string flat_sut_out in
        let env_input_str = vars_to_string  flat_env_in in
        let env_output_str = vars_to_string flat_env_out in
        let oracle_input_str = vars_to_string flat_oracle_in in
        let oracle_output_str_l = List.map vars_to_string oracle_out_l in
        Printf.printf "sut input : \n%s\n"  sut_input_str;
        Printf.printf "sut output : \n%s\n" sut_output_str;
        Printf.printf "env input : \n%s\n"  env_input_str;
        Printf.printf "env output : \n%s\n"  env_output_str;
        Printf.printf "oracle(s) input : \n%s\n"  oracle_input_str;
        List.iter (fun str -> Printf.printf "oracle output : \n%s\n" str) 
                  oracle_output_str_l;
        flush stdout
    in
    (* Check var names and types compat. *)
    let res_compat, luciole_io_opt = 
      check_compat flat_env_in flat_env_out flat_sut_in flat_sut_out 
                   flat_oracle_in flat_oracle_out 
    in
    (* Without luciole, its kill and step functions are no-ops (the step
       returns no value). *)
    let (luciole_kill, luciole_step), luciole_outputs_vars =
      match luciole_io_opt with
        | None -> ((fun _ -> ()),(fun _ -> [])),[]
        | Some (luciole_in, luciole_out) -> 
          (LustreRun.make_luciole "./lurette_luciole.dro" luciole_in luciole_out), 
          luciole_out
    in
    (* Seed given in the arguments, or a freshly drawn one. *)
    let seed =
      match args.seed with
        | None -> random_seed ()
        | Some seed -> seed
    in
    (* Coverage mode: NO when no oracle has any output; OO when the oracles
       only have their first (verdict) output; OC otherwise, with coverage
       computed over the Bool outputs that follow each verdict.
       NOTE(review): List.tl raises Failure if some oracle has no output
       at all while another one has some. *)
    let cov_init = (* XXX should the oracle outputs be renamed, or should
                     we complain in case of a name clash? *)
      if List.flatten oracle_out_l = [] then NO else 
        let oracle_out = List.flatten (List.map List.tl oracle_out_l) in
        if List.length oracle_out < 1 then OO else 
          let is_bool (_,t) = (t = Data.Bool) in
          let names = List.filter is_bool oracle_out in
          let names = fst (List.split names) in
          OC (Coverage.init names args.cov_file args.reset_cov_file)
    in
    let oc = open_out args.output in
    (* sim2chro display channel, or /dev/null when the display is off. *)
    let sim2chro_oc = 
      if args.display_sim2chro then Util2.sim2chro_dyn () else open_out "/dev/null" 
    in
    (* [filter vals vars] looks up by name, in [vals], the value of each
       variable of [vars] (in the order of [vars]); fails with a message
       listing the available names when one is missing. *)
    let filter vals vars = 
      List.map (fun (n,t) -> n, 
        try List.assoc n vals
        with Not_found -> 
          let vars_str = String.concat ", " (List.map (fun (n,_) -> n) vals) in
          let msg = Printf.sprintf "Don't find %s in %s\n" n vars_str in
          failwith msg
      ) vars 
    in

    (* Checks, at step [i], the verdict of each oracle: the first element of
       each list of [oracle_out_vals_l] must be a Bool. A false verdict is
       printed on stdout, and raises OracleError when
       args.stop_on_oracle_error is set. In OC mode the remaining outputs
       update the coverage, which is also recorded in cov_ref. Nothing is
       checked in NO mode. The [oracle_out_l] parameter is unused, and the
       function is not actually recursive. *)
    let rec check_oracles oracle_in_vals i oracle_out_l oracle_out_vals_l cov =
      let check_one_oracle = function 
        | [] -> assert false
        | (_, Data.B true)::tail -> tail
          
        | (_, Data.B false)::tail ->
          let msg = 
            match cov with 
                OC cov -> Coverage.dump_oracle_io oracle_in_vals tail cov 
              | _ -> ""
          in
          let msg = 
            Printf.sprintf "\n*** The oracle returned false at step %i\n%s" i msg 
          in
          print_string msg;
          flush stdout;
          if args.stop_on_oracle_error then raise (OracleError msg) else tail

        | (vn, vv)::_  -> 
          let vv = Data.val_to_string_type vv in
          let msg = Printf.sprintf 
            "The oracle first output should be a bool; but %s has type %s" vn vv in
          failwith msg
      in
      match cov with 
          NO -> NO
        | OO -> ignore (List.map check_one_oracle oracle_out_vals_l); OO
        | OC cov -> 
          let ll = List.map check_one_oracle oracle_out_vals_l in
          let cov =
            List.fold_left
              (fun cov other_oracle_out_vals ->
                Coverage.update_cov other_oracle_out_vals cov) cov ll
          in
          cov_ref := Some cov;
          OC cov
    in
    (* Despite its name, dumps the coverage information through
       Coverage.dump (OC mode only). *)
    let update_cov cov = 
      match cov with 
        | NO -> ()
        | OO -> ()
        | OC cov -> 
          let str =
            String.concat ", " (List.map reactive_program_to_string args.oracles) 
          in
          Coverage.dump str args.output cov
    in
    (* The main loop *)
    (* Sends quit to every program, closes the RIF and sim2chro channels,
       and dumps the coverage. *)
    let killem_all cov = 
      env_kill "quit\n";
      sut_kill "quit\n";
      luciole_kill "quit\n";
      oracle_kill "quit\n";
      close_out oc;
      close_out sim2chro_oc;
      update_cov cov;
    in
    (* Step [i], first part: luciole and environment steps.
       [env_in_vals] are the environment inputs: the SUT outputs of the
       previous step, or the SUT initial outputs at step 0.
       [pre_env_out_vals] are the environment outputs of the previous step,
       or the SUT initial inputs at step 0. The luciole outputs are added to
       the environment inputs. The run stops once [i] exceeds
       args.step_nb; as the first call uses i = 0, args.step_nb + 1 steps
       are performed. NOTE(review): confirm this step count is intended. *)
    let rec loop cov env_in_vals pre_env_out_vals ctx i () =
      if i > args.step_nb then (killem_all cov; raise (Event.End 0) ) 
      else
        let luciole_outs = luciole_step (env_in_vals@pre_env_out_vals) in
        let env_in_vals =  List.rev_append luciole_outs env_in_vals in
        if args.ldbg then (* XXX ideally this test would be done outside
                             the loop, by passing the function that
                             fits the mode. After all, that is one
                             of the advantages of CPS... *)
          let edata = env_in_vals@pre_env_out_vals in
          let ctx = 
            { ctx with 
              Event.step = i;
              Event.name = "ltop";
              Event.depth = 1;
              Event.data = edata;
            }
          in
          let cont = loop2 cov env_in_vals pre_env_out_vals ctx i luciole_outs in
          step_dbg_sl env_step_dbg_sl_l env_in_vals ctx cont
        else
          let env_step_sl sl = List.flatten (List.map (fun f -> f sl) env_step_sl_l) in
          let env_out_vals = env_step_sl env_in_vals in
          loop2 cov env_in_vals pre_env_out_vals ctx i luciole_outs env_out_vals
    (*
      {
      step = i;
      data = [];
      next = (fun () -> loop2 cov env_in_vals pre_env_out_vals i luciole_outs env_out_vals);
      terminate = (fun () -> killem_all cov)
      }
    *)          
    (* Second part: SUT step. The environment outputs are reordered along
       flat_env_out (kept as they are if one of them is missing). The luciole
       outputs are put in front, and the SUT inputs are selected from the
       result. *)
    and
        loop2 cov env_in_vals pre_env_out_vals ctx i luciole_outs env_out_vals =
      let env_out_vals = 
        try List.map (fun (v,vt) -> v,List.assoc v env_out_vals) flat_env_out 
        with Not_found -> env_out_vals
      in
      let env_out_vals = luciole_outs @ env_out_vals in
      let sut_in_vals = filter env_out_vals flat_sut_in in        
      if args.ldbg then       
        let edata = sut_in_vals@ env_out_vals in
        let ctx = { ctx with 
          Event.step = i;
          Event.name = "ltop";
          Event.depth = 1;
          Event.data = edata;
        } 
        in
        let cont = 
          loop3 cov env_in_vals pre_env_out_vals env_out_vals ctx i luciole_outs
        in
        step_dbg_sl sut_step_dbg_sl_l sut_in_vals ctx cont
      else
        let sut_step_sl sl = List.flatten (List.map (fun f -> f sl) sut_step_sl_l) in
        let sut_out_vals = sut_step_sl sut_in_vals in
        loop3 cov env_in_vals pre_env_out_vals env_out_vals ctx i
              luciole_outs sut_out_vals
    (* Third part: oracle step and logging. The oracle inputs come from the
       SUT outputs, the luciole outputs, and the current environment outputs.
       When args.delay_env_outputs is set, the previous environment outputs
       are used instead. The step is written in RIF to [oc] (and partly to
       [sim2chro_oc]), and gnuplot is started or refreshed. The run then
       continues with step i+1: directly, or through the [Event.next]
       field of the returned event in ldbg mode. *)
    and loop3 cov env_in_vals pre_env_out_vals env_out_vals ctx i 
              luciole_outs sut_out_vals =
      let sut_out_vals = 
        try List.map (fun (v,vt) -> v,List.assoc v sut_out_vals) flat_sut_out 
        with Not_found -> sut_out_vals
      in
      let oracle_in_vals = 
        if args.delay_env_outputs 
        then List.rev_append pre_env_out_vals sut_out_vals 
        else List.rev_append     env_out_vals sut_out_vals 
      in
      let oracle_in_vals = List.rev_append luciole_outs oracle_in_vals in
      let oracle_in_vals = filter oracle_in_vals flat_oracle_in in
      let oracle_out_vals_l = List.map (fun f -> f oracle_in_vals) oracle_step_sl_l in

      (*       let oracle_out_vals = List.flatten oracle_out_vals_l in *)
      (* Reorder each oracle outputs along its declared interface (kept as
         they are if a variable is missing). *)
      let oracle_out_vals_l = 
        try List.map2 
              (fun oracle_out oracle_out_vals -> 
                List.map (fun (v,vt) -> v,List.assoc v oracle_out_vals) oracle_out
              ) 
              oracle_out_l
              oracle_out_vals_l
        with Not_found -> oracle_out_vals_l
      in
      let print_val (vn,vv) = Data.val_to_string Util.my_string_of_float vv in
      Printf.fprintf oc "#step %d\n" i;

      if args.delay_env_outputs then (
        output_string oc (String.concat " " (List.map print_val (pre_env_out_vals)));
        output_string 
          sim2chro_oc (String.concat " " (List.map print_val (pre_env_out_vals)));
      )
      else (
        output_string oc (String.concat " " (List.map print_val env_out_vals));
        output_string sim2chro_oc (String.concat " " (List.map print_val env_out_vals));
      );
      (* NOTE(review): in delay mode the values written above are those of
         pre_env_out_vals, but the separator test below looks at
         env_out_vals. On sim2chro_oc, no space is written between those
         values and the #outs marker. *)
      output_string oc (if env_out_vals <> [] then " #outs " else "#outs ");
      output_string oc (String.concat " " (List.map print_val sut_out_vals));
      output_string oc "\n";
      List.iter (fun l -> 
        output_string oc "#oracle_outs ";
        output_string oc (String.concat " " (List.map print_val l));
        output_string oc "\n";
      ) oracle_out_vals_l;
      flush oc;

      output_string sim2chro_oc "#outs ";
      output_string sim2chro_oc (String.concat " " (List.map print_val sut_out_vals));
      output_string sim2chro_oc "\n";
      flush sim2chro_oc;
      
      (* Unless args.go is set, gnuplot is launched on the RIF file at step 0
         and asked to replot at each later step. *)
      if not args.go && args.display_gnuplot then (
        if  i = 0 then (
          let oc, pid = 
            GnuplotRif.terminal := GnuplotRif.Wxt;
            GnuplotRif.verbose := args.verbose>1;
            GnuplotRif.dynamic := true;
            GnuplotRif.rif_file := args.output;
            GnuplotRif.f ()
          in
          gnuplot_pid_ref := Some pid;
          gnuplot_oc := Some oc
        ) 
        else 
          (match !gnuplot_oc with
            | None -> ()
            | Some oc -> output_string oc "replot\n"; flush oc)
      );
      (* In ldbg mode, return an Ltop event whose next field checks the
         oracles and runs the following step; [term] kills gnuplot (if it
         is running) and stops everything. *)
      if args.ldbg then (
        let edata = sut_out_vals@env_out_vals@(List.flatten oracle_out_vals_l) in
        let term () = 
          (match !gnuplot_pid_ref with
            | None -> ()
            | Some pid ->
              print_string "Killing gnuplot...\n"; flush stdout;
              Unix.kill pid Sys.sigkill;
              gnuplot_oc := None;
              gnuplot_pid_ref := None); 
          killem_all cov
        in
        let enb = ctx.Event.nb in
        let ctx = { ctx with 
          Event.nb = ctx.Event.nb+1;
          Event.step = i;
          Event.name = "ltop";
          Event.depth = 1;
          Event.data = edata;
          Event.terminate = term;
        } 
        in
        {
          Event.nb = enb;
          Event.step = i;
          Event.kind = Event.Ltop;
          Event.depth = 1;
          Event.data = edata;
          Event.name = "rdbg";
          Event.lang = "";
          Event.inputs=[];
          Event.outputs=[];
          Event.locals = [];
          Event.sinfo=None;
          Event.next = 
            (fun () ->
              loop (check_oracles oracle_in_vals i oracle_out_l oracle_out_vals_l cov) 
                sut_out_vals env_out_vals ctx (i+1) () 
            );
          Event.terminate = term
        }
      )
      else
        loop (check_oracles oracle_in_vals i oracle_out_l oracle_out_vals_l cov) 
          sut_out_vals env_out_vals ctx (i+1) () 
    in

    let loc = None in
    (* Solver mode and initialisation, RNG seeding, then the RIF headers
       (version, seed, interface) on both output channels. *)
    let _ = 
      if args.compute_volume then Solver.set_fair_mode () 
      else Solver.set_efficient_mode ();
      !Solver.init_snt ();
      Random.init seed;

      Rif.write oc ("# This is lurette Version " ^ Version.str ^
                      " (\"" ^Version.sha^"\")\n");
      Rif.write oc ("#seed "^(string_of_int seed)^"\n");

      RifIO.write_interface oc 
        (luciole_outputs_vars@flat_env_out) flat_sut_out loc (Some oracle_out_l);
      Rif.flush oc;

      RifIO.write_interface sim2chro_oc 
        (luciole_outputs_vars@flat_env_out) flat_sut_out loc (Some oracle_out_l);
      Rif.flush sim2chro_oc;
    in
    (* Initial rdbg context. Its next field is only a placeholder: the
       events built in loop3 set their own. *)
    let ctx =
      {
        Event.nb = 1;
        Event.step = 1;
        Event.name = "ltop";
        Event.depth = 1;
        Event.inputs = [];
        Event.outputs = [];
        Event.locals = [];
        Event.data = [];
        Event.terminate = (fun () -> killem_all cov_init);
        Event.lang = "";
        Event.next = (fun () -> assert false);
        Event.kind = Event.Ltop;
        Event.sinfo = None;
      }
    in
    (* Runs the loop (or, in ldbg mode, builds its first event), translating
       exceptions into Event.End codes.
       NOTE(review): in ldbg mode, exceptions raised by later calls to
       Event.next are not translated by this handler. *)
    let (first_event : Event.t) =
      let res = 
        try 
          if res_compat = 0 then
            loop cov_init sut_init_out sut_init_in ctx 0 ()
          else 
            raise(Event.End res_compat) 
        with 
          | SutStop cov -> 
            print_string "The SUT stopped\n";
            flush stdout;
            update_cov cov;
            raise(Event.End 1)

          | OracleError str -> 
            print_string str;
            flush stdout;
            raise(Event.End 2)

          | Failure str -> 
            print_string ("Failure occured in lurette: "^str);
            flush stdout;
            raise(Event.End 2) 
          | Event.End i -> raise(Event.End (10*i))
          | e -> 
            print_string (Printexc.to_string e);
            flush stdout;
            raise(Event.End 2)
      in
      res
    in
    first_event

(* exported *)
(* Best-effort cleanup, usable when the execution is stopped by a ctrl-c:
   - dumps the coverage information recorded in cov_ref, if any;
   - kills the gnuplot process launched by [start], if any, and forgets
     both its pid and its channel.
   If gnuplot has already exited, the resulting ESRCH error is ignored. *)
let (clean_terminate : unit -> unit) =
  fun () -> 
    let str = String.concat ", " (List.map reactive_program_to_string args.oracles) in
    (match !cov_ref with 
      | None -> ()
      | Some cov -> Coverage.dump str args.output cov);
    (match !gnuplot_pid_ref with
      | None -> ()
      | Some pid ->
        print_string "Killing gnuplot...\n"; 
        flush stdout;
        (* The process may already be gone; that is not an error here. *)
        (try Unix.kill pid Sys.sigkill
         with Unix.Unix_error (Unix.ESRCH, _, _) -> ());
        gnuplot_pid_ref := None;
        gnuplot_oc := None
    )

    
