diff --git a/regression_ppx/dune b/regression_ppx/dune index 152f73a7..f702628f 100644 --- a/regression_ppx/dune +++ b/regression_ppx/dune @@ -148,13 +148,14 @@ (modules test013mutual) (package OCanren) (public_names -) - (flags - (:standard - ;-dsource - ; - )) (preprocess - (pps OCanren-ppx.ppx_distrib GT.ppx_all -- -new-typenames -pretty)) + (pps + OCanren-ppx.ppx_distrib + GT.ppx_all + OCanren-ppx.ppx_fresh + -- + -new-typenames + -pretty)) (libraries OCanren OCanren.tester)) (executables @@ -162,13 +163,30 @@ (modules test014diseq) (package OCanren) (public_names -) - (flags - (:standard - ;-dsource - ; - )) (preprocess - (pps OCanren-ppx.ppx_fresh OCanren-ppx.ppx_deriving_reify GT.ppx_all)) + (pps + GT.ppx_all + OCanren-ppx.ppx_fresh + OCanren-ppx.ppx_tester + OCanren-ppx.ppx_repr + OCanren-ppx.ppx_deriving_reify + -- + -pretty)) + (libraries OCanren OCanren.tester)) + +(executables + (names test015diseq) + (modules test015diseq) + (package OCanren) + (public_names -) + (preprocess + (pps + OCanren-ppx.ppx_fresh + OCanren-ppx.ppx_deriving_reify + GT.ppx_all + OCanren-ppx.ppx_repr + -- + -pretty)) (libraries OCanren OCanren.tester)) (cram @@ -306,3 +324,12 @@ %{project_root}/ppx/pp_ocanren_all.exe test014diseq.ml test014diseq.exe)) + +(cram + (package OCanren) + (applies_to test015) + (deps + (package OCanren-ppx) + %{project_root}/ppx/pp_ocanren_all.exe + test015diseq.ml + test015diseq.exe)) diff --git a/regression_ppx/test015.t b/regression_ppx/test015.t new file mode 100644 index 00000000..3f690473 --- /dev/null +++ b/regression_ppx/test015.t @@ -0,0 +1,35 @@ + $ ./test015diseq.exe + rel, 1 answer { + hd1 = _.12 + hd2 = _.14 + tl2 = _.15 + 10: { 0: [| 10 =/= _.11 |] } + + 11: { 0: [| 11 =/= boxed 0 <_.12, _.13> |] } + + 15: { 0: [| 14 =/= _.12, 15 =/= _.13 |] } + + hd2 === 1 + 15: { 0: [| 14 =/= _.12, 15 =/= _.13 |] } + + tl2 === [] + 13: { 0: [| 12 =/= int<1>, 13 =/= int<0> |] } + + 47 + 13: { 0: [| 12 =/= int<1>, 13 =/= int<0> |] } + + } + fun _ -> + fresh x + ((Std.list Fun.id [!< x; !< x]) =/= + (Std.list Fun.id [!< (!! 1); !< (!! 2)])), 1 answer { + q=_.10; + } + fun q -> + fresh (x y) (trace_index "x" x) (trace_index "y" y) ((x % y) === q) + ((x % y) =/= (Std.list Fun.id [!! 1; x])) + (y === (Std.list Fun.id [!! 2])) success, 1 answer { + x = _.11 + y = _.12 + q=[_.11; 2]; + } diff --git a/regression_ppx/test015diseq.ml b/regression_ppx/test015diseq.ml new file mode 100644 index 00000000..3461120f --- /dev/null +++ b/regression_ppx/test015diseq.ml @@ -0,0 +1,87 @@ +open OCanren +open Tester + +let debug_line line = + debug_var !!1 OCanren.reify (function _ -> + Format.printf "%d\n%!" line; + success) +;; + +let trace_index msg var = + debug_var var OCanren.reify (function + | [ Var (n, _) ] -> + Printf.printf "%s = _.%d\n" msg n; + success + | _ -> assert false) +;; + +let trace fmt = + Format.kasprintf + (fun s -> + debug_var !!1 OCanren.reify (function _ -> + Format.printf "%s\n%!" s; + success)) + fmt +;; + +let rel list1 = + let open OCanren.Std in + fresh + (list2 hd1 tl1 hd2 tl2) + (trace_index "hd1" hd1) + (trace_index "hd2" hd2) + (trace_index "tl2" tl2) + (list1 =/= list2) + trace_diseq + (list1 === hd1 % tl1) + trace_diseq + (list2 === hd2 % tl2) + trace_diseq + (trace " hd2 === 1") + (hd2 === !!1) + trace_diseq + (trace " tl2 === []") + (tl2 === nil ()) + trace_diseq + (hd1 === !!1) + (debug_line __LINE__) + trace_diseq + (tl1 === nil ()) (* crashes here *) + (debug_line __LINE__) +;; + +(* let () = [%tester run_r [%show GT.int GT.list] (Std.List.reify reify) 1 (fun q -> rel q)] *) +let () = run_r (Std.List.reify reify) ([%show: GT.int logic Std.List.logic] ()) 1 q qh (REPR rel) + +let () = + let open Std in + run_r + (Std.List.reify reify) + ([%show: GT.int logic Std.List.logic] ()) + 1 + q + qh + (REPR (fun _ -> fresh x (Std.list Fun.id [ ! + fresh + (x y) + (trace_index "x" x) + (trace_index "y" y) + (x % y === q) + (x % y =/= Std.list Fun.id [ !!1; x ]) + (* trace_diseq *) + (y === Std.list Fun.id [ !!2 ]) + (* trace_diseq *) + success)) +;; diff --git a/src/core/Core.ml b/src/core/Core.ml index a6f77735..93dbf6dd 100644 --- a/src/core/Core.ml +++ b/src/core/Core.ml @@ -795,4 +795,9 @@ module Tabling = let reify_in_empty reifier x = let st = State.empty () in - reifier (State.env st) x \ No newline at end of file + reifier (State.env st) x + +let trace_diseq : goal = fun st -> + Format.printf "%a\n%!" Disequality.pp (State.constraints st); + success st + diff --git a/src/core/Core.mli b/src/core/Core.mli index 63cf6bcf..74fc1902 100644 --- a/src/core/Core.mli +++ b/src/core/Core.mli @@ -320,4 +320,6 @@ module PrunesControl : sig end (** Runs reifier on empty state. Useful to debug execution order *) -val reify_in_empty: ('a, 'b) Reifier.t -> 'a -> 'b \ No newline at end of file +val reify_in_empty: ('a, 'b) Reifier.t -> 'a -> 'b + +val trace_diseq: goal diff --git a/src/core/Disequality.ml b/src/core/Disequality.ml index 0c989254..ae3c28e3 100644 --- a/src/core/Disequality.ml +++ b/src/core/Disequality.ml @@ -19,6 +19,11 @@ (* to avoid clash with Std.List (i.e. logic list) *) module List = Stdlib.List +let log fmt = + if false + then Format.kasprintf (Format.printf "%s\n%!") fmt + else Format.ifprintf Format.std_formatter fmt + module Answer = struct module S = Set.Make(Term) @@ -91,6 +96,9 @@ module Disjunct : (* Disjunction.t is a set of single disequalities joint by disjunction *) type t + + val pp : Format.formatter -> t -> unit + (* [make env subst x y] creates new disjunct from the disequality [x =/= y] *) val make : Env.t -> Subst.t -> 'a -> 'a -> t @@ -118,9 +126,19 @@ module Disjunct : struct type t = Term.t Term.VarMap.t - let update t = - ListLabels.fold_left ~init:t - ~f:(let open Subst.Binding in fun acc {var; term} -> + let pp ppf d = + if Term.VarMap.is_empty d then Format.fprintf ppf "" + else + Format.fprintf ppf "[| "; + Term.VarMap.iteri (fun i k v -> + if i<>0 then Format.fprintf ppf ", "; + Format.fprintf ppf " @[%d =/= %s@]" k.Term.Var.index (Term.show v) + ) d; + Format.fprintf ppf " |]" + + let update : t -> _ -> t = fun init -> + ListLabels.fold_left ~init + ~f:(fun acc {Subst.Binding.var; term} -> if Term.VarMap.mem var acc then (* in this case we have subformula of the form (x =/= t1) \/ (x =/= t2) which is always SAT *) raise Disequality_fulfilled @@ -149,12 +167,32 @@ module Disjunct : | Fulfiled -> raise Disequality_fulfilled | Violated -> raise Disequality_violated - let rec recheck env subst t = + let rec recheck env subst (t: t): t = + (* log "Disjunct.recheck: %a" pp t; *) let var, term = Term.VarMap.max_binding t in + (* log " max bind index = %d" var.Term.Var.index; *) let unchecked = Term.VarMap.remove var t in + (* log " unchecked: %a" pp unchecked; *) match refine env subst (Obj.magic var) term with - | Fulfiled -> raise Disequality_fulfilled - | Refined delta -> update unchecked delta + | Fulfiled -> + raise Disequality_fulfilled + | Refined delta -> ( + (* When leading terms are reified into something new, we still need to + do whole unification, beacuse other pairs may need walking --- + (we postponed walking, so som einformation may be lost.) + See issue #173 + *) + (* log "Refined into: %a" (Format.pp_print_list Subst.Binding.pp) delta; *) + match Subst.unify_map env subst t with + | None -> + (* not unifiable --- always distinct *) + raise Disequality_fulfilled + | Some ([], _) -> raise Disequality_violated + | Some (bnds, _subst) -> + (* TODO(Kakadu): reconstruction of map from binding list could hurt performance *) + let rez = Subst.varmap_of_bindings bnds in + (* log "Disjunct.recheck returns %a" pp rez; *) + rez) | Violated -> if Term.VarMap.is_empty unchecked then raise Disequality_violated @@ -208,6 +246,8 @@ module Conjunct : val empty : t + val pp : Format.formatter -> t -> unit + val is_empty : t -> bool val make : Env.t -> Subst.t -> 'a -> 'a -> t @@ -236,6 +276,19 @@ module Conjunct : type t = Disjunct.t M.t + let pp ppf map = + if M.is_empty map + then Format.fprintf ppf "{}" + else + let idx = ref 0 in + Format.fprintf ppf "{ "; + M.iter (fun k v -> + if !idx <> 0 then Format.fprintf ppf " ,"; + Format.fprintf ppf "@[%d: %a@]" k Disjunct.pp v; + incr idx + ) map; + Format.fprintf ppf " }" + let empty = M.empty let is_empty = M.is_empty @@ -256,11 +309,14 @@ module Conjunct : ) t Term.VarMap.empty let recheck env subst t = - M.fold (fun id disj acc -> + (* log "Conjunct.recheck. %a" pp t; *) + let rez = M.fold (fun id disj acc -> try M.add id (Disjunct.recheck env subst disj) acc with Disequality_fulfilled -> acc - ) t M.empty + ) t M.empty in + (* log "rechecked = %a" pp rez; *) + rez let merge_disjoint env subst = M.union (fun _ _ _ -> @@ -351,6 +407,11 @@ type t = Conjunct.t Term.VarMap.t let empty = Term.VarMap.empty +let pp ppf : t -> unit = + Term.VarMap.iter (fun k v -> + Format.fprintf ppf "@[%d: %a@]@," k.Term.Var.index Conjunct.pp v + ) + (* merges all conjuncts (linked to different variables) into one *) let combine env subst cstore = Term.VarMap.fold (fun _ -> Conjunct.merge_disjoint env subst) cstore Conjunct.empty @@ -370,17 +431,19 @@ let add env subst cstore x y = | Disequality_violated -> None let recheck env subst cstore bs = - let helper var cstore = + let helper var cstore : t = try let conj = Term.VarMap.find var cstore in let cstore = Term.VarMap.remove var cstore in update env subst (Conjunct.recheck env subst conj) cstore + with Not_found -> cstore in try let cstore = ListLabels.fold_left bs ~init:cstore - ~f:(let open Subst.Binding in fun cstore {var; term} -> + ~f:(fun cstore {Subst.Binding.var; term} -> let cstore = helper var cstore in + (* log "cstore = %a" pp cstore; *) match Env.var env term with | Some u -> helper u cstore | None -> cstore diff --git a/src/core/Disequality.mli b/src/core/Disequality.mli index b21ad044..41ac1692 100644 --- a/src/core/Disequality.mli +++ b/src/core/Disequality.mli @@ -51,3 +51,5 @@ module Answer : end val reify : Env.t -> Subst.t -> t -> 'a -> Answer.t list + +val pp: Format.formatter -> t -> unit diff --git a/src/core/Subst.ml b/src/core/Subst.ml index 4c324a2e..5e1962d8 100644 --- a/src/core/Subst.ml +++ b/src/core/Subst.ml @@ -47,10 +47,17 @@ module Binding = let hash {var; term} = Hashtbl.hash (Term.Var.hash var, Term.hash term) - let pp ppf { var ; term } = - Format.fprintf ppf "%a -> %a" Term.pp (Obj.repr var) Term.pp term + let pp ppf {var; term} = + Format.fprintf ppf "{ var.idx = %d; term=%s }" var.Term.Var.index (Term.show term) end +let varmap_of_bindings : Binding.t list -> Term.t Term.VarMap.t = + Stdlib.List.fold_left (fun (acc: _ Term.VarMap.t) Binding.{var;term} -> + assert (not (Term.VarMap.mem var acc)); + Term.VarMap.add var term acc + ) + Term.VarMap.empty + type t = Term.t Term.VarMap.t let empty = Term.VarMap.empty @@ -154,6 +161,11 @@ let extend ~scope env subst var term = exception Unification_failed +let log fmt = + if false + then Format.kasprintf (Format.printf "%s\n%!") fmt + else Format.ifprintf Format.std_formatter fmt + let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) env subst x y = (* The idea is to do the unification and collect the unification prefix during the process *) let extend var term (prefix, subst) = @@ -161,6 +173,7 @@ let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) env subst x y = (Binding.({var; term})::prefix, subst) in let rec helper x y acc = + (* log "unify '%s' and '%s'" (Term.show x) (Term.show y); *) let open Term in fold2 x y ~init:acc ~fvar:(fun ((_, subst) as acc) x y -> @@ -192,6 +205,15 @@ let apply env subst x = Obj.magic @@ ~fvar:(fun v -> Term.repr v) ~fval:(fun x -> Term.repr x) +let unify_map env subst map = + let vars, terms = + Term.VarMap.fold (fun v term acc -> (v :: fst acc, term :: snd acc)) map ([],[]) + in + (* log "var = %s" (Term.show (Obj.magic (apply env subst vars))); *) + (* log "terms = %s" (Term.show (Obj.magic (apply env subst terms))); *) + unify env subst (Obj.magic vars) (Obj.magic terms) + + let freevars env subst x = Env.freevars env @@ apply env subst x diff --git a/src/core/Subst.mli b/src/core/Subst.mli index 6ba66d9d..a06d8e21 100644 --- a/src/core/Subst.mli +++ b/src/core/Subst.mli @@ -32,6 +32,8 @@ module Binding : val pp: Format.formatter -> t -> unit end +val varmap_of_bindings: Binding.t list -> Term.t Term.VarMap.t + type t val pp: Format.formatter -> t -> unit @@ -67,6 +69,8 @@ val freevars : Env.t -> t -> 'a -> Term.VarSet.t *) val unify : ?subsume:bool -> ?scope:Term.Var.scope -> Env.t -> t -> 'a -> 'a -> (Binding.t list * t) option +val unify_map: Env.t -> t -> Term.t Term.VarMap.t -> (Binding.t list * t) option + val merge_disjoint : Env.t -> t -> t -> t (* [merge env s1 s2] merges two substituions *) diff --git a/src/core/Term.ml b/src/core/Term.ml index b3c42e49..daa8079a 100644 --- a/src/core/Term.ml +++ b/src/core/Term.ml @@ -90,6 +90,10 @@ module VarMap = match f (try Some (find k m) with Not_found -> None) with | Some x -> add k x m | None -> remove k m + + let iteri f m = + let i = ref 0 in + iter (fun k v -> f !i k v; incr i) m end type t = Obj.t diff --git a/src/core/Term.mli b/src/core/Term.mli index 0aabc171..a03ccf33 100644 --- a/src/core/Term.mli +++ b/src/core/Term.mli @@ -67,6 +67,9 @@ module VarMap : include Map.S with type key = Var.t val update : key -> ('a option -> 'a option) -> 'a t -> 'a t + + val iteri: (int -> key -> 'a -> unit) -> 'a t -> unit + end (* [t] type of untyped OCaml term *)