Module ExtValues


Require Import Coqlib.
Require Import Integers.
Require Import Values.
Require Import Floats ExtFloats.
Require Import Lia.

Open Scope Z_scope.

(** [abs_diff x y] is the absolute difference [Z.abs (x - y)]. *)
Definition abs_diff (x y : Z) := Z.abs (x - y).
(** [abs_diff2 x y] computes the same quantity without [Z.abs]: it subtracts
    the smaller argument from the larger one, as decided by [x <=? y]. *)
Definition abs_diff2 (x y : Z) :=
  if x <=? y then y - x else x - y.
(** Both formulations of the absolute difference agree on every pair of
    integers. *)
Lemma abs_diff2_correct :
  forall x y : Z, (abs_diff x y) = (abs_diff2 x y).
Proof.
  intros.
  unfold abs_diff, abs_diff2.
  (* Unfolding [Z.leb] exposes a match on [Z.compare x y], so the three
     cases of [Z.compare_spec] decide the [if]. *)
  unfold Z.leb.
  pose proof (Z.compare_spec x y) as Hspec.
  inv Hspec.
  (* Case x = y: x - y is 0, so [Z.abs_eq] applies. *)
  - rewrite Z.abs_eq; lia.
  (* Case x < y: x - y is negative, so [Z.abs_neq] applies. *)
  - rewrite Z.abs_neq; lia.
  (* Case y < x: x - y is positive, so [Z.abs_eq] applies. *)
  - rewrite Z.abs_eq; lia.
Qed.

(** The four values 1, 2, 3 and 4 as an enumerated type. The name suggests
    they encode small shift amounts (NOTE(review): confirm against users). *)
Inductive shift1_4 : Type :=
| SHIFT1 | SHIFT2 | SHIFT3 | SHIFT4.

(** Integer denoted by each constructor: [SHIFTk] maps to [k]. *)
Definition z_of_shift1_4 (x : shift1_4) :=
  match x with
  | SHIFT1 => 1
  | SHIFT2 => 2
  | SHIFT3 => 3
  | SHIFT4 => 4
  end.

(** Partial inverse of [z_of_shift1_4]: returns [Some SHIFTk] when [x] equals
    [k] for some k in 1..4, and [None] for every other integer. *)
Definition shift1_4_of_z (x : Z) :=
  if Z.eq_dec x 1 then Some SHIFT1
  else if Z.eq_dec x 2 then Some SHIFT2
  else if Z.eq_dec x 3 then Some SHIFT3
  else if Z.eq_dec x 4 then Some SHIFT4
  else None.

(** Soundness of [shift1_4_of_z]: when it returns [Some x], mapping [x] back
    with [z_of_shift1_4] yields the original integer. *)
Lemma shift1_4_of_z_correct :
  forall z,
    match shift1_4_of_z z with
    | Some x => z_of_shift1_4 x = z
    | None => True
    end.
Proof.
  intro. unfold shift1_4_of_z.
  (* One case split per [Z.eq_dec] test. In each success branch, the
     equality hypothesis closes the goal through [congruence]. *)
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  (* Remaining [None] branch: the goal is [True]. *)
  trivial.
Qed.

(** [z_of_shift1_4] as a machine [int], built with [Int.repr]. *)
Definition int_of_shift1_4 (x : shift1_4) :=
  Int.repr (z_of_shift1_4 x).

(** [is_bitfield stop start] holds when the bit positions [start] to [stop]
    (both inclusive) form a non-empty range inside a 32-bit word, i.e.
    0 <= start <= stop < Int.zwordsize. *)
Definition is_bitfield stop start :=
  (Z.leb start stop)
    && (Z.geb start Z.zero)
    && (Z.ltb stop Int.zwordsize).

(** Unsigned bitfield extraction: the bits [start..stop] of a [Vint], moved
    down to bit 0 and zero-extended. The result is [Vundef] when the range is
    not a valid bitfield or [v] is not a [Vint]. *)
Definition extfz stop start v :=
  if is_bitfield stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vint w =>
      (* Shift left so that bit [stop] becomes the top bit. Then shift right
         logically by the word size minus the field width, which is
         stop + 1 - start. *)
      Vint (Int.shru (Int.shl w (Int.repr (Z.sub Int.zwordsize stop'))) (Int.repr (Z.sub Int.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.


(** Signed bitfield extraction: same as [extfz], except that the final shift
    is arithmetic ([Int.shr]). The result is therefore sign-extended from
    bit [stop]. *)
Definition extfs stop start v :=
  if is_bitfield stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vint w =>
      Vint (Int.shr (Int.shl w (Int.repr (Z.sub Int.zwordsize stop'))) (Int.repr (Z.sub Int.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.

(** [zbitfield_mask stop start] is 2^(stop+1) - 2^start. When
    0 <= start <= stop, this is the integer whose bits [start..stop] are set
    and all other bits are clear. No range check is performed here. *)
Definition zbitfield_mask stop start :=
  (Z.shiftl 1 (Z.succ stop)) - (Z.shiftl 1 start).

(** [zbitfield_mask] as a [Vint] value, reduced modulo 2^32 by [Int.repr]. *)
Definition bitfield_mask stop start :=
  Vint(Int.repr (zbitfield_mask stop start)).

(** [zbitfield_mask] as a [Vlong] value, reduced modulo 2^64 by
    [Int64.repr]. *)
Definition bitfield_maskl stop start :=
  Vlong(Int64.repr (zbitfield_mask stop start)).

(** Bitfield insertion: replaces bits [start..stop] of [prev] with the low
    bits of [fld], computing
    (prev AND NOT mask) OR ((fld SHL start) AND mask).
    The result is [Vundef] when the range is not a valid bitfield; for other
    arguments the [Val] operations determine the result. *)
Definition insf stop start prev fld :=
  (* [mask] is only consulted when [is_bitfield] holds. *)
  let mask := bitfield_mask stop start in
  if is_bitfield stop start
  then
    Val.or (Val.and prev (Val.notint mask))
           (Val.and (Val.shl fld (Vint (Int.repr start))) mask)
  else Vundef.

(** 64-bit counterpart of [is_bitfield]: holds when
    0 <= start <= stop < Int64.zwordsize. *)
Definition is_bitfieldl stop start :=
  (Z.leb start stop)
    && (Z.geb start Z.zero)
    && (Z.ltb stop Int64.zwordsize).

(** Unsigned bitfield extraction on a [Vlong]: bits [start..stop] are moved
    down to bit 0 and zero-extended. The shift amounts are built with
    [Int.repr], i.e. as 32-bit [int]s, which is what the primed 64-bit shifts
    take. The result is [Vundef] when the range is invalid or [v] is not a
    [Vlong]. *)
Definition extfzl stop start v :=
  if is_bitfieldl stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vlong w =>
      (* Shift left so that bit [stop] becomes bit 63, then shift right
         logically by 64 minus the field width. *)
      Vlong (Int64.shru' (Int64.shl' w (Int.repr (Z.sub Int64.zwordsize stop'))) (Int.repr (Z.sub Int64.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.


(** Signed bitfield extraction on a [Vlong]: same as [extfzl], except that
    the final shift is arithmetic. The result is therefore sign-extended from
    bit [stop]. *)
Definition extfsl stop start v :=
  if is_bitfieldl stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vlong w =>
      Vlong (Int64.shr' (Int64.shl' w (Int.repr (Z.sub Int64.zwordsize stop'))) (Int.repr (Z.sub Int64.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.

(** 64-bit bitfield insertion: replaces bits [start..stop] of [prev] with
    the low bits of [fld], computing
    (prev AND NOT mask) OR ((fld SHL start) AND mask),
    where mask is the 64-bit [bitfield_maskl stop start]. The shift amount is
    passed as a [Vint]. The result is [Vundef] when the range is invalid. *)
Definition insfl stop start prev fld :=
  (* [mask] is only consulted when [is_bitfieldl] holds. *)
  let mask := bitfield_maskl stop start in
  if is_bitfieldl stop start
  then
    Val.orl (Val.andl prev (Val.notl mask))
            (Val.andl (Val.shll fld (Vint (Int.repr start))) mask)
  else Vundef.

(** [highest_bit x n] scans the bit positions n, n-1, ..., 1 of [x] from the
    top down and returns the first position whose bit is set. It returns 0
    when none of the bits 1..n is set, so for example both 0 and 1 yield 0.
    Bit 0 itself is never tested. *)
Fixpoint highest_bit (x : Z) (n : nat) : Z :=
  match n with
  | O => 0
  | S n1 =>
    (* The position tested at this step is [n] itself, not [n1].
       NOTE(review): [N_of_nat] is an old name for [N.of_nat]; check that the
       Coq version in use still accepts it. *)
    let n' := Z.of_N (N_of_nat n) in
    if Z.testbit x n'
    then n'
    else highest_bit x n1
  end.

(** Highest set bit of a 32-bit [int] read as unsigned, among positions 31
    down to 1. Returns 0 if none of those bits is set. *)
Definition int_highest_bit (x : int) : Z :=
  highest_bit (Int.unsigned x) (31%nat).


(** Highest set bit of an [int64] read as unsigned, among positions 63 down
    to 1. Returns 0 if none of those bits is set. *)
Definition int64_highest_bit (x : int64) : Z :=
  highest_bit (Int64.unsigned x) (63%nat).

(** [Int.shrx] lifted to values. The result is defined only for two [Vint]s
    whose shift amount is, by unsigned comparison, below 31; otherwise it is
    [Vundef]. NOTE(review): CompCert describes [Int.shrx] as signed division
    by a power of two, rounding toward zero. *)
Definition val_shrx (v1 v2: val): val :=
  match v1, v2 with
  | Vint n1, Vint n2 =>
     if Int.ltu n2 (Int.repr 31)
     then Vint(Int.shrx n1 n2)
     else Vundef
  | _, _ => Vundef
  end.

(** 64-bit version of [val_shrx]: [Int64.shrx'] of a [Vlong] by a [Vint]
    shift amount. The result is defined only when the amount is, by unsigned
    comparison, below 63; otherwise it is [Vundef]. *)
Definition val_shrxl (v1 v2: val): val :=
  match v1, v2 with
  | Vlong n1, Vint n2 =>
     if Int.ltu n2 (Int.repr 63)
     then Vlong(Int64.shrx' n1 n2)
     else Vundef
  | _, _ => Vundef
  end.

(** The 32-bit modulus is strictly below the largest unsigned 64-bit value,
    so every unsigned 32-bit value is representable as an unsigned 64-bit
    value. Proved by evaluating both constants. *)
Remark modulus_fits_64: Int.modulus < Int64.max_unsigned.
Proof.
  compute.
  trivial.
Qed.

(** Any [i] with 0 <= i < Int.modulus (the lower bound is written as
    -1 < i) survives a round trip through [Int64.repr] and
    [Int64.unsigned] unchanged. *)
Remark unsigned64_repr :
  forall i,
    -1 < i < Int.modulus ->
    Int64.unsigned (Int64.repr i) = i.
Proof.
  intros i H.
  destruct H as [Hlow Hhigh].
  apply Int64.unsigned_repr.
  split. { lia. }
  (* Upper bound: Int.modulus fits below Int64.max_unsigned. *)
  pose proof modulus_fits_64.
  lia.
Qed.
  
(** Unsigned 32-bit division on values agrees with unsigned 64-bit division
    of the zero-extended operands, truncated back to the low word. This
    includes the cases where the division is undefined. *)
Theorem divu_is_divlu: forall v1 v2 : val,
    Val.divu v1 v2 =
    match Val.divlu (Val.longofintu v1) (Val.longofintu v2) with
    | None => None
    | Some q => Some (Val.loword q)
    end.
Proof.
  intros.
  (* Operands that are not both [Vint]: both sides compute to the same
     result. *)
  destruct v1; cbn; trivial.
  destruct v2; cbn; trivial.
  (* Work with the raw integer values and their range hypotheses. *)
  destruct i as [i_val i_range].
  destruct i0 as [i0_val i0_range].
  cbn.
  (* Turn the zero-divisor tests on both sides into comparisons of the raw
     divisor with 0. *)
  unfold Int.eq, Int64.eq, Int.zero, Int64.zero.
  cbn.
  rewrite Int.unsigned_repr by (compute; split; discriminate).
  rewrite (Int64.unsigned_repr 0) by (compute; split; discriminate).
  rewrite (unsigned64_repr i0_val) by assumption.
  (* Zero divisor: both sides agree by computation. *)
  destruct (zeq i0_val 0) as [ | Hnot0]; cbn; trivial.
  f_equal. f_equal.
  unfold Int.divu, Int64.divu. cbn.
  rewrite (unsigned64_repr i_val) by assumption.
  rewrite (unsigned64_repr i0_val) by assumption.
  unfold Int64.loword.
  (* After [reflexivity], the remaining obligation is the side condition of
     [Int64.unsigned_repr]:
     0 <= i_val / i0_val <= Int64.max_unsigned. *)
  rewrite Int64.unsigned_repr.
  reflexivity.
  (* Divisor 1: the quotient is the dividend, which is below 2^32. *)
  destruct (Z.eq_dec i0_val 1).
  {subst i0_val.
   pose proof modulus_fits_64.
   rewrite Zdiv_1_r.
   lia.
  }
  (* Dividend 0: the quotient is 0, checked by computation. *)
  destruct (Z.eq_dec i_val 0).
  { subst i_val. compute.
    split;
    intro ABSURD;
    discriminate ABSURD. }
  (* Otherwise the divisor is at least 2 and the dividend is positive, so
     the quotient is non-negative and strictly smaller than the dividend. *)
  assert ((i_val / i0_val) < i_val).
  { apply Z_div_lt; lia. }
  split.
  { apply Z_div_pos; lia. }
  pose proof modulus_fits_64.
  lia.
Qed.
  
(** Unsigned 32-bit remainder on values agrees with unsigned 64-bit
    remainder of the zero-extended operands, truncated back to the low word.
    This includes the cases where the remainder is undefined. *)
Theorem modu_is_modlu: forall v1 v2 : val,
    Val.modu v1 v2 =
    match Val.modlu (Val.longofintu v1) (Val.longofintu v2) with
    | None => None
    | Some q => Some (Val.loword q)
    end.
Proof.
  intros.
  (* Operands that are not both [Vint]: both sides compute to the same
     result. *)
  destruct v1; cbn; trivial.
  destruct v2; cbn; trivial.
  (* Work with the raw integer values and their range hypotheses. *)
  destruct i as [i_val i_range].
  destruct i0 as [i0_val i0_range].
  cbn.
  (* Turn the zero-divisor tests on both sides into comparisons of the raw
     divisor with 0. *)
  unfold Int.eq, Int64.eq, Int.zero, Int64.zero.
  cbn.
  rewrite Int.unsigned_repr by (compute; split; discriminate).
  rewrite (Int64.unsigned_repr 0) by (compute; split; discriminate).
  rewrite (unsigned64_repr i0_val) by assumption.
  (* Zero divisor: both sides agree by computation. *)
  destruct (zeq i0_val 0) as [ | Hnot0]; cbn; trivial.
  f_equal. f_equal.
  unfold Int.modu, Int64.modu. cbn.
  rewrite (unsigned64_repr i_val) by assumption.
  rewrite (unsigned64_repr i0_val) by assumption.
  unfold Int64.loword.
  (* After [reflexivity], the remaining obligation is the side condition of
     [Int64.unsigned_repr]:
     0 <= i_val mod i0_val <= Int64.max_unsigned. *)
  rewrite Int64.unsigned_repr.
  reflexivity.
  (* The remainder is below the divisor, which is below 2^32. *)
  assert((i_val mod i0_val) < i0_val).
  apply Z_mod_lt.
  lia.
  split.
  { apply Z_mod_lt.
    lia. }
  pose proof modulus_fits_64.
  lia.
Qed.

(** Rewrite lemmas for [zlt _ Int.half_modulus] tests on fixed constants.
    Each outcome is decided by evaluation. NOTE(review): presumably used to
    simplify an unfolded [Int.signed] applied to constants. *)

(** 0 is below [Int.half_modulus], so the [then] branch is taken. *)
Remark if_zlt_0_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt 0 Int.half_modulus) then x else y) = x.
Proof.
  reflexivity.
Qed.

(** The unsigned value of [Int.mone] is not below [Int.half_modulus], so the
    [else] branch is taken. *)
Remark if_zlt_mone_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt (Int.unsigned Int.mone) Int.half_modulus) then x else y) = y.
Proof.
  reflexivity.
Qed.

(** The unsigned value of [Int.repr Int.min_signed] is not below
    [Int.half_modulus], so the [else] branch is taken. *)
Remark if_zlt_min_signed_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt (Int.unsigned (Int.repr Int.min_signed))
                     Int.half_modulus)
    then x
     else y) = y.
Proof.
  reflexivity.
Qed.

(** Reducing [x] modulo the 64-bit modulus before applying [Int.repr] makes
    no difference. Because 2^32 divides 2^64, both paths yield the same
    32-bit value. *)
Lemma repr_unsigned64_repr:
  forall x, Int.repr (Int64.unsigned (Int64.repr x)) = Int.repr x.
Proof.
  intros.
  (* It suffices to show that the two arguments are congruent modulo 2^32. *)
  apply Int.eqm_samerepr.
  unfold Int.eqm.
  unfold Zbits.eqmod.
  (* [x] and its 64-bit reduction differ by a multiple k64 of 2^64. *)
  pose proof (Int64.eqm_unsigned_repr x) as H64.
  unfold Int64.eqm in H64.
  unfold Zbits.eqmod in H64.
  destruct H64 as [k64 H64].
  change Int64.modulus with 18446744073709551616 in *.
  change Int.modulus with 4294967296.
  (* Since 2^64 = 2^32 * 2^32, a multiple of 2^64 is a multiple of 2^32;
     the witness accounts for the sign of the difference. *)
  exists (-4294967296 * k64).
  set (y := Int64.unsigned (Int64.repr x)) in *.
  rewrite H64.
  clear H64.
  lia.
Qed.


(** For an [int] whose unsigned value is at least [Int.half_modulus], the
    signed value is the unsigned value minus [Int.modulus]. *)
Lemma big_unsigned_signed:
  forall x,
    (Int.unsigned x >= Int.half_modulus) ->
    (Int.signed x) = (Int.unsigned x) - Int.modulus.
Proof.
  destruct x as [xval xrange].
  intro BIG.
  unfold Int.signed, Int.unsigned in *. cbn in *.
  (* [Int.signed] branches on a [zlt] test. The branch where the value is
     below the bound contradicts BIG; the other branch is the goal. *)
  destruct (zlt _ _).
  lia.
  trivial.
Qed.


(** Truncated division ([Z.quot]) of a non-negative dividend by a divisor
    of at least 1 never exceeds the dividend. *)
Lemma Z_quot_le: forall a b,
    0 <= a -> 1 <= b -> Z.quot a b <= a.
Proof.
  intros a b Ha Hb.
  destruct (Z.eq_dec b 1) as [Hb1 | Hb1].
  { (* b=1 *)
    subst.
    rewrite Z.quot_1_r.
    auto with zarith.
  }
  destruct (Z.eq_dec a 0) as [Ha0 | Ha0].
  { (* a=0 *)
    subst.
    rewrite Z.quot_0_l.
    auto with zarith.
    lia.
  }
  (* General case: b >= 2 and a > 0, so the quotient is strictly smaller
     than a. *)
  assert ((Z.quot a b) < a).
  {
    apply Z.quot_lt; lia.
  }
  auto with zarith.
Qed.


Require Import Coq.ZArith.Zquot.
(** If 0 <= a <= m and b >= 1, then 0 <= a ÷ b <= m. *)
Lemma Z_quot_pos_pos_bound: forall a b m,
    0 <= a <= m -> 1 <= b -> 0 <= Z.quot a b <= m.
Proof.
  intros.
  split.
  (* Lower bound: 0 = 0 ÷ b, and quotient is monotone in the dividend. *)
  { rewrite <- (Z.quot_0_l b) by lia.
    apply Z_quot_monotone; lia.
  }
  (* Upper bound: a ÷ b <= a <= m. *)
  apply Z.le_trans with (m := a).
  {
    apply Z_quot_le; tauto.
  }
  tauto.
Qed.
(** Mirror image of [Z_quot_pos_pos_bound] for non-positive dividends:
    if m <= a <= 0 and b >= 1, then m <= a ÷ b <= 0. *)
Lemma Z_quot_neg_pos_bound: forall a b m,
    m <= a <= 0 -> 1 <= b -> m <= Z.quot a b <= 0.
Proof.
  intros.
  (* Negate everything: -a lies between 0 and -m, and
     -(a ÷ b) = (-a) ÷ b, so the positive case applies. *)
  assert (0 <= - (a ÷ b) <= -m).
  {
    rewrite <- Z.quot_opp_l by lia.
    apply Z_quot_pos_pos_bound; lia.
  }
  lia.
Qed.

(** Dividing a value in the signed 32-bit range by a divisor of at least 1
    cannot leave the signed 32-bit range. *)
Lemma Z_quot_signed_pos_bound: forall a b,
    Int.min_signed <= a <= Int.max_signed -> 1 <= b ->
    Int.min_signed <= Z.quot a b <= Int.max_signed.
Proof.
  intros.
  (* Case split on the sign of the dividend. *)
  destruct (Z_lt_ge_dec a 0).
  {
    (* a < 0: the quotient lies between a and 0. *)
    split.
    { apply Z_quot_neg_pos_bound; lia. }
    { eapply Z.le_trans with (m := 0).
      { apply Z_quot_neg_pos_bound with (m := Int.min_signed); trivial.
        split. tauto. auto with zarith.
      }
      discriminate.
    }
  }
  (* a >= 0: the quotient lies between 0 and a. *)
  { split.
    { eapply Z.le_trans with (m := 0).
      discriminate.
      apply Z_quot_pos_pos_bound with (m := Int.max_signed); trivial.
      split. lia. tauto.
    }
    { apply Z_quot_pos_pos_bound; lia.
    }
  }
Qed.

(** Dividing a value in the signed 32-bit range by a divisor below -1 stays
    in the signed 32-bit range. The divisor -1 is excluded because
    -2147483648 ÷ -1 = 2147483648 exceeds Int.max_signed = 2147483647. *)
Lemma Z_quot_signed_neg_bound: forall a b,
    Int.min_signed <= a <= Int.max_signed -> b < -1 ->
    Int.min_signed <= Z.quot a b <= Int.max_signed.
Proof.
  change Int.min_signed with (-2147483648).
  change Int.max_signed with 2147483647.
  intros.

  (* Write a ÷ b as -(a ÷ -b) with -b >= 2, and bound a ÷ -b instead. *)
  replace b with (-(-b)) by auto with zarith.
  rewrite Z.quot_opp_r by lia.
  assert (-2147483647 <= (a ÷ - b) <= 2147483648).
  2: lia.
  
  destruct (Z_lt_ge_dec a 0).
  {
    (* a < 0: negate the dividend as well, and reason about (-a) ÷ -b,
       which lies between 0 and -a. *)
    replace a with (-(-a)) by auto with zarith.
    rewrite Z.quot_opp_l by lia.
    assert (-2147483648 <= - a ÷ - b <= 2147483647).
    2: lia.
    split.
    {
      rewrite Z.quot_opp_l by lia.
      assert (a ÷ - b <= 2147483648).
      2: lia.
      {
        apply Z.le_trans with (m := 0).
        rewrite <- (Z.quot_0_l (-b)) by lia.
        apply Z_quot_monotone; lia.
        discriminate.
      }
    }
    assert (- a ÷ - b < -a ).
    2: lia.
    apply Z_quot_lt; lia.
  }
  {
    (* a >= 0: the quotient lies between 0 and a. *)
    split.
    { apply Z.le_trans with (m := 0).
      discriminate.
      rewrite <- (Z.quot_0_l (-b)) by lia.
      apply Z_quot_monotone; lia.
    }
    { apply Z.le_trans with (m := a).
      apply Z_quot_le.
      all: lia.
    }
  }
Qed.

(** On values, subtraction equals addition of the negation. *)
Lemma sub_add_neg :
  forall x y, Val.sub x y = Val.add x (Val.neg y).
Proof.
  (* Every constructor combination other than two [Vint]s is closed by
     computation; the integer case reduces to [Int.sub_add_opp]. *)
  destruct x; destruct y; cbn; trivial.
  f_equal.
  apply Int.sub_add_opp.
Qed.

(** On 32-bit values, negating a product equals multiplying by the negated
    right operand. *)
Lemma neg_mul_distr_r :
  forall x y, Val.neg (Val.mul x y) = Val.mul x (Val.neg y).
Proof.
  destruct x; destruct y; cbn; trivial.
  f_equal.
  apply Int.neg_mul_distr_r.
Qed.


(** 64-bit counterpart of [neg_mul_distr_r]. *)
Lemma negl_mull_distr_r :
  forall x y, Val.negl (Val.mull x y) = Val.mull x (Val.negl y).
Proof.
  destruct x; destruct y; cbn; trivial.
  f_equal.
  apply Int64.neg_mul_distr_r.
Qed.

(** Shift-and-add: [addx sh v1 v2] is v2 + (v1 SHL sh) on 32-bit values. *)
Definition addx sh v1 v2 :=
  Val.add v2 (Val.shl v1 (Vint sh)).

(** 64-bit shift-and-add: v2 + (v1 SHL sh). The shift amount is a [Vint]. *)
Definition addxl sh v1 v2 :=
  Val.addl v2 (Val.shll v1 (Vint sh)).

(** Shift-and-reverse-subtract: v2 - (v1 SHL sh) on 32-bit values. The
    shifted operand is the one being subtracted. *)
Definition revsubx sh v1 v2 :=
  Val.sub v2 (Val.shl v1 (Vint sh)).

(** 64-bit shift-and-reverse-subtract: v2 - (v1 SHL sh). The shift amount
    is a [Vint]. *)
Definition revsubxl sh v1 v2 :=
  Val.subl v2 (Val.shll v1 (Vint sh)).

(** Double-precision minimum, delegated to [ExtFloat.min]. Any operand that
    is not a [Vfloat] gives [Vundef]. *)
Definition minf v1 v2 :=
  match v1, v2 with
  | Vfloat a, Vfloat b => Vfloat (ExtFloat.min a b)
  | _, _ => Vundef
  end.

(** Double-precision maximum, delegated to [ExtFloat.max]. Any operand that
    is not a [Vfloat] gives [Vundef]. *)
Definition maxf v1 v2 :=
  match v1, v2 with
  | Vfloat a, Vfloat b => Vfloat (ExtFloat.max a b)
  | _, _ => Vundef
  end.

(** Single-precision minimum, delegated to [ExtFloat32.min]. Any operand
    that is not a [Vsingle] gives [Vundef]. *)
Definition minfs v1 v2 :=
  match v1, v2 with
  | Vsingle a, Vsingle b => Vsingle (ExtFloat32.min a b)
  | _, _ => Vundef
  end.

(** Single-precision maximum, delegated to [ExtFloat32.max]. Any operand
    that is not a [Vsingle] gives [Vundef]. *)
Definition maxfs v1 v2 :=
  match v1, v2 with
  | Vsingle a, Vsingle b => Vsingle (ExtFloat32.max a b)
  | _, _ => Vundef
  end.

(** [ExtFloat32.inv] applied to a [Vsingle] (presumably the reciprocal).
    Any other value gives [Vundef]. *)
Definition invfs v1 :=
  match v1 with
  | Vsingle a => Vsingle (ExtFloat32.inv a)
  | _ => Vundef
  end.

(** Lifts a ternary operation [f] on double-precision floats to values. [f]
    is applied when all three operands are [Vfloat]; any other combination
    gives [Vundef]. *)
Definition triple_op_float f v1 v2 v3 :=
  match v1, v2, v3 with
  | Vfloat a, Vfloat b, Vfloat c => Vfloat (f a b c)
  | _, _, _ => Vundef
  end.

(** Single-precision counterpart of [triple_op_float]: [f] is applied when
    all three operands are [Vsingle]; any other combination gives
    [Vundef]. *)
Definition triple_op_single f v1 v2 v3 :=
  match v1, v2, v3 with
  | Vsingle a, Vsingle b, Vsingle c => Vsingle (f a b c)
  | _, _, _ => Vundef
  end.

(** Fused multiply-add on values: [fmaddf acc x y] evaluates
    [Float.fma x y acc]. Assuming CompCert's convention that
    [Float.fma x y z] computes x*y+z with a single rounding, this is
    acc + x*y. [fmaddfs] is the single-precision version. *)
Definition fmaddf := triple_op_float (fun acc x y => Float.fma x y acc).
Definition fmaddfs := triple_op_single (fun acc x y => Float32.fma x y acc).

(** Fused multiply-subtract on values: [fmsubf acc x y] evaluates
    [Float.fma (Float.neg x) y acc], presumably acc - x*y.
    [fmsubfs] is the single-precision version. *)
Definition fmsubf := triple_op_float (fun acc x y => Float.fma (Float.neg x) y acc).
Definition fmsubfs := triple_op_single (fun acc x y => Float32.fma (Float32.neg x) y acc).