diff --git a/src/core/Core.ml b/src/core/Core.ml index 99e6bbaa..c17e5a5e 100644 --- a/src/core/Core.ml +++ b/src/core/Core.ml @@ -60,81 +60,73 @@ module List = Stdlib.List module Answer : sig + (* [Answer.t] - a type that represents (untyped) answer to a query *) type t (* [make env t] creates the answer from the environment and term (with constrainted variables) *) val make : Env.t -> Term.t -> t - (* [lift env a] lifts the answer into different environment, replacing all variables consistently *) - val lift : Env.t -> t -> t - (* [env a] returns an environment of the answer *) val env : t -> Env.t - (* [unctr_term a] returns a term with unconstrained variables *) - val unctr_term : t -> Term.t - (* [ctr_term a] returns a term with constrained variables *) val ctr_term : t -> Term.t + (* [unctr_term a] returns a term with unconstrained variables *) + val unctr_term : t -> Term.t + (* [disequality a] returns all disequality constraints on variables in term as a list of bindings *) val disequality : t -> Subst.Binding.t list + (* [lift env a] lifts the answer into different environment, replacing all variables consistently *) + val lift : Env.t -> t -> t + (* [equal t t'] syntactic equivalence (not an alpha-equivalence) *) val equal : t -> t -> bool (* [hash t] hashing that is consistent with syntactic equivalence *) val hash : t -> int end = struct + type t = Env.t * Term.t let make env t = (env, t) let env (env, _) = env - let unctr_term (_, t) = - Term.map t - ~fval:(fun x -> Term.repr x) - ~fvar:(fun v -> Term.repr {v with Term.Var.constraints = []}) - let ctr_term (_, t) = t + let unctr_term (_, t) = Term.map t ~fval:Term.repr + ~fvar:(fun v -> Term.repr { v with Term.Var.constraints = [] }) + let disequality (env, t) = - let rec helper acc x = - Term.fold x ~init:acc - ~fval:(fun acc _ -> acc) - ~fvar:(fun acc var -> - ListLabels.fold_left var.Term.Var.constraints ~init:acc - ~f:(fun acc ctr_term -> - let ctr_term = Term.repr ctr_term in - let var = {var with Term.Var.constraints = []} in - let term = unctr_term @@ (env, ctr_term) in - let acc = Subst.(Binding.({var; term}))::acc in - helper acc ctr_term - ) - ) + let rec helper acc x = Term.fold x ~init:acc ~fval:(fun acc _ -> acc) + ~fvar:begin fun acc ctr_var -> + let var = { ctr_var with Term.Var.constraints = [] } in + ListLabels.fold_left ctr_var.Term.Var.constraints ~init:acc + ~f:begin fun acc ctr_term -> + let term = unctr_term (env, ctr_term) in + let acc = Subst.Binding.{ var ; term }::acc in + helper acc ctr_term + end + end in helper [] t let lift env' (env, t) = let vartbl = Term.VarTbl.create 31 in - let rec helper x = - Term.map x - ~fval:(fun x -> Term.repr x) - ~fvar:(fun v -> Term.repr @@ - try - Term.VarTbl.find vartbl v - with Not_found -> - let new_var = Env.fresh ~scope:Term.Var.non_local_scope env' in - Term.VarTbl.add vartbl v new_var; - {new_var with Term.Var.constraints = - List.map (fun x -> helper x) v.Term.Var.constraints - |> List.sort Term.compare - } - ) + let rec helper x = Term.map x ~fval:Term.repr + ~fvar:begin fun v -> Term.repr @@ + try Term.VarTbl.find vartbl v + with Not_found -> + let new_var = Env.fresh ~scope:Term.Var.non_local_scope env' in + Term.VarTbl.add vartbl v new_var ; + let ctr = List.map helper v.Term.Var.constraints |> List.sort Term.compare in + { new_var with Term.Var.constraints = ctr } + end in - (env', helper t) + env', helper t let check_envs_exn env env' = if Env.equal env env' then () else @@ -144,58 +136,62 @@ module Answer : check_envs_exn env env'; Term.equal t t' - let hash (env, t) = Term.hash t + let hash (_, t) = Term.hash t end module Prunes : sig + type rez = Violated | NonViolated - type ('a, 'b) reifier = ('a,'b) Reifier.t + type ('a, 'b) reifier = ('a, 'b) Reifier.t type 'b cond = 'b -> bool type t - val empty : t + val empty : t + val extend : t -> Term.VarTbl.key -> ('a, 'b) reifier -> 'b cond -> t + val recheck : t -> Env.t -> Subst.t -> rez val check_last : t -> Env.t -> Subst.t -> rez - val extend : t -> Term.VarTbl.key -> ('a, 'b) reifier -> 'b cond -> t end = struct + type rez = Violated | NonViolated - type ('a, 'b) reifier = ('a,'b) Reifier.t - type reifier_untyped = Obj.t + type ('a, 'b) reifier = ('a, 'b) Reifier.t type 'b cond = 'b -> bool + + type reifier_untyped = Obj.t type cond_untyped = Obj.t -> bool let make_untyped : ('a, 'b) reifier -> 'b cond -> reifier_untyped * cond_untyped = - fun a b -> Obj.magic (a,b) + fun a b -> Obj.magic (a, b) + + type t = (Term.t * (reifier_untyped * cond_untyped)) list - type t = (Obj.t * (reifier_untyped * cond_untyped)) list let empty = [] - exception Fail + let extend map term rr cond = + let new_item = make_untyped rr cond in + (Term.repr term, new_item) :: map - let check_last map env subst = - try - let (term, (reifier, checker)) = List.hd map in - let reifier : (_,_) reifier = Obj.obj reifier in - let reified = reifier env (Obj.magic @@ Subst.apply env subst term) in - if not (checker reified) then raise Fail; - NonViolated - with Not_found -> NonViolated - | Fail -> Violated + exception Fail let recheck (ps: t) env s = try - ps |> List.iter (fun (k, (reifier, checker)) -> - let reifier : (_,_) Reifier.t = Obj.obj reifier in + ps |> List.iter begin fun (k, (reifier, checker)) -> + let reifier : (_, _) Reifier.t = Obj.obj reifier in let reified = reifier env (Obj.magic @@ Subst.apply env s k) in if not (checker reified) then raise Fail - ); + end ; NonViolated with Fail -> Violated - let extend map term rr cond = - let new_item = make_untyped rr cond in - (Obj.repr term, new_item) :: map - + let check_last map env subst = + try + let (term, (reifier, checker)) = List.hd map in + let reifier : (_, _) reifier = Obj.obj reifier in + let reified = reifier env (Obj.magic @@ Subst.apply env subst term) in + if not (checker reified) then raise Fail ; + NonViolated + with Not_found -> NonViolated + | Fail -> Violated end type prines_control = @@ -244,6 +240,7 @@ module PrunesControl = struct ans ) end + (* let do_skip_prunes = ref false let prunes_checks_skipped = ref 0 @@ -253,8 +250,10 @@ let set_skip_prunes_count n = assert (n>0); max_prunes_skipped := n *) + module State = struct + type t = { env : Env.t ; subst : Subst.t @@ -283,65 +282,59 @@ module State = let new_scope st = {st with scope = Term.Var.new_scope ()} - let unify x y ({env; subst; ctrs; scope} as st) = - match Subst.unify ~scope env subst x y with - | None -> None - | Some (prefix, subst) -> - match Disequality.recheck env subst ctrs prefix with - | None -> None - | Some ctrs -> - let next_state = {st with subst; ctrs} in - if PrunesControl.is_exceeded () - then begin - let () = PrunesControl.reset_cur_counter () in - match Prunes.recheck (prunes next_state) env subst with - | Prunes.Violated -> None - | NonViolated -> Some next_state - end else begin -(* print_endline "check skipped";*) - let () = PrunesControl.incr () in - Some next_state - end - + let unify x y ({ env ; subst ; ctrs ; scope } as st) = + match Subst.unify ~scope env subst x y with + | None -> None + | Some (prefix, subst) -> + match Disequality.recheck env subst ctrs prefix with + | None -> None + | Some ctrs -> + let next_state = { st with subst ; ctrs } in + if PrunesControl.is_exceeded () + then begin + let () = PrunesControl.reset_cur_counter () in + match Prunes.recheck (prunes next_state) env subst with + | Prunes.Violated -> None + | NonViolated -> Some next_state + end else begin + (* print_endline "check skipped"; *) + let () = PrunesControl.incr () in + Some next_state + end - let diseq x y ({env; subst; ctrs; scope} as st) = + let diseq x y ({ env ; subst ; ctrs ; scope } as st) = match Disequality.add env subst ctrs x y with | None -> None | Some ctrs -> - match Prunes.recheck (prunes st) env subst with - | Prunes.Violated -> None - | NonViolated -> Some {st with ctrs} + match Prunes.recheck (prunes st) env subst with + | Prunes.Violated -> None + | NonViolated -> Some { st with ctrs } - (* returns always non-empty list *) - let reify x {env; subst; ctrs} = + (* always returns non-empty list *) + let reify x { env ; subst ; ctrs } = let answ = Subst.reify env subst x in match Disequality.reify env subst ctrs x with - | [] -> [Answer.make env answ] - | diseqs -> - ListLabels.map diseqs ~f:(fun diseq -> - let rec helper forbidden t = - Term.map t - ~fval:(fun x -> Term.repr x) - ~fvar:(fun v -> Term.repr @@ - if List.mem v.Term.Var.index forbidden then v - else - {v with Term.Var.constraints = - Disequality.Answer.extract diseq v - |> List.filter (fun dt -> - match Env.var env dt with - | Some u -> not (List.mem u.Term.Var.index forbidden) - | None -> true - ) - |> List.map (fun x -> helper (v.Term.Var.index::forbidden) x) - (* TODO: represent [Var.constraints] as [Set]; - * TODO: hide all manipulations on [Var.t] inside [Var] module; - *) - |> List.sort Term.compare - } - ) - in - Answer.make env (helper [] answ) - ) + | [] -> (* [Answer.make env answ] *) assert false + | diseqs -> ListLabels.map diseqs ~f:begin fun diseq -> + let rec helper forbidden t = Term.map t ~fval:Term.repr + ~fvar:begin fun v -> Term.repr @@ + if Term.VarSet.mem v forbidden then v + else { v with Term.Var.constraints = Disequality.Answer.extract diseq v + |> List.filter begin fun dt -> + match Env.var env dt with + | Some u -> not @@ Term.VarSet.mem u forbidden + | None -> true + end + |> List.map (fun x -> helper (Term.VarSet.add v forbidden) x) + (* TODO: represent [Var.constraints] as [Set]; + * TODO: hide all manipulations on [Var.t] inside [Var] module; + *) + |> List.sort Term.compare + } + end + in + Answer.make env @@ helper Term.VarSet.empty answ + end end let (!!!) = Obj.magic @@ -797,6 +790,6 @@ let reify_in_empty reifier x = reifier (State.env st) x let trace_diseq : goal = fun st -> - Format.printf "%a\n%!" Disequality.pp (State.constraints st); + Format.printf "%a\n%!" Disequality.pp (State.constraints st) ; success st diff --git a/src/core/Disequality.ml b/src/core/Disequality.ml index 3b5c481e..013404a5 100644 --- a/src/core/Disequality.ml +++ b/src/core/Disequality.ml @@ -26,6 +26,7 @@ let log fmt = module Answer = struct + module S = Set.Make(Term) (* answer is a conjunction of single disequalities, i.g. (x =/= 1 /\ y =/= 2); @@ -38,26 +39,26 @@ module Answer = let add env t var term = try let terms = S.add term @@ Term.VarMap.find var t in - Term.VarMap.add var terms @@ Term.VarMap.remove var t + Term.VarMap.add var terms t with Not_found -> Term.VarMap.add var (S.singleton term) t let mem env t var term = - try - S.mem term @@ Term.VarMap.find var t + try S.mem term @@ Term.VarMap.find var t with Not_found -> false let extract t v = - try S.elements @@ Term.VarMap.find v t with Not_found -> [] + try S.elements @@ Term.VarMap.find v t + with Not_found -> [] - let subsumed env t t' = + let subsumed env t = (* we should check that for each binding from [t'] there is * a binding in [t] that subsumes it; * Examples: * (x =/= _.0) <= (x =/= 1 /\ x =/= 2), but * (x =/= _.0) and (x =/= 1 /\ y =/= 2) are not ordered *) - Term.VarMap.for_all (fun var terms' -> + Term.VarMap.for_all @@ fun var terms' -> try let terms = Term.VarMap.find var t in S.for_all (fun term' -> @@ -66,7 +67,6 @@ module Answer = ) terms ) terms' with Not_found -> false - ) t' end exception Disequality_violated @@ -93,10 +93,10 @@ exception Disequality_fulfilled module Disjunct : sig + (* Disjunction.t is a set of single disequalities joint by disjunction *) type t - val pp : Format.formatter -> t -> unit (* [make env subst x y] creates new disjunct from the disequality [x =/= y] *) @@ -122,8 +122,8 @@ module Disjunct : val simplify : Env.t -> Subst.t -> t -> t option val reify : Env.t -> Subst.t -> t -> Subst.Binding.t list - end = - struct + end = struct + type t = Term.t Term.VarMap.t let pp ppf d = @@ -173,7 +173,7 @@ module Disjunct : (* log " max bind index = %d" var.Term.Var.index; *) let unchecked = Term.VarMap.remove var t in (* log " unchecked: %a" pp unchecked; *) - match refine env subst (Obj.magic var) term with + match refine env subst (Term.repr var) term with | Fulfiled -> raise Disequality_fulfilled | Refined delta -> ( @@ -199,10 +199,30 @@ module Disjunct : else recheck env subst unchecked + let is_relevant env subst t fv = + (* left those disjuncts that contains binding only for variables from [fv], + * otherwise it's possible to pick binding (x =/= t) from disjunct for + * variable [x] that is not in [fv], + * assign [t'] ([t =/= t']) to [x] and thus fulfill the disequality + *) + let hlp var term = + Term.VarSet.mem var fv || + (match Env.var env term with Some u -> Term.VarSet.mem u fv | None -> false) + in + Term.VarMap.for_all hlp t + + let freevars env subst t = + Term.VarMap.fold (fun _ term acc -> + Term.VarSet.union acc @@ Subst.freevars env subst term + ) t Term.VarSet.empty + + let subsumed env subst t t' = + Subst.(subsumed env (of_map t') (of_map t)) + let simplify env subst ds = try let result = Term.VarMap.fold (fun var term acc -> - match refine env subst (Obj.magic var) term with + match refine env subst (Term.repr var) term with | Fulfiled -> raise Disequality_fulfilled | Violated -> acc | Refined delta -> delta @ acc @@ -218,30 +238,11 @@ module Disjunct : let reify env subst t = Term.VarMap.fold (fun var term xs -> Subst.(Binding.({var; term})::xs)) t [] - - let is_relevant env subst t fv = - (* left those disjuncts that contains binding only for variables from [fv], - * otherwise it's possible to pick binding (x =/= t) from disjunct for - * variable [x] that is not in [fv], - * assign [t'] ([t =/= t']) to [x] and thus fulfill the disequality - *) - Term.VarMap.for_all (fun var term -> - (Term.VarSet.mem var fv) || - (match Env.var env term with Some u -> Term.VarSet.mem u fv | None -> false) - ) t - - let freevars env subst t = - Term.VarMap.fold (fun _ term acc -> - Term.VarSet.union acc @@ Subst.freevars env subst term - ) t Term.VarSet.empty - - let subsumed env subst t t' = - Subst.(subsumed env (of_map t') (of_map t)) - end module Conjunct : sig + type t val empty : t @@ -270,6 +271,7 @@ module Conjunct : val reify : Env.t -> Subst.t -> t -> 'a -> Answer.t list end = struct + let next_id = ref 0 module M = Map.Make(struct type t = int let compare = (-) end) @@ -436,18 +438,17 @@ let recheck env subst cstore bs = let conj = Term.VarMap.find var cstore in let cstore = Term.VarMap.remove var cstore in update env subst (Conjunct.recheck env subst conj) cstore - with Not_found -> cstore in try let cstore = ListLabels.fold_left bs ~init:cstore - ~f:(fun cstore {Subst.Binding.var; term} -> + ~f:begin fun cstore Subst.Binding.{ var ; term } -> let cstore = helper var cstore in (* log "cstore = %a" pp cstore; *) match Env.var env term with | Some u -> helper u cstore | None -> cstore - ) + end in Some cstore with Disequality_violated -> None @@ -455,5 +456,4 @@ let recheck env subst cstore bs = let project env subst cstore fv = Conjunct.(split @@ project env subst (combine env subst cstore) fv) -let reify env subst cstore x = - Conjunct.reify env subst (combine env subst cstore) x +let reify env subst cstore = Conjunct.reify env subst (combine env subst cstore) diff --git a/src/core/Disequality.mli b/src/core/Disequality.mli index 41ac1692..20898091 100644 --- a/src/core/Disequality.mli +++ b/src/core/Disequality.mli @@ -19,7 +19,7 @@ type t (* [empty] empty disequality constraint store *) -val empty : t +val empty : t (* [add env subst diseq x y] adds constraint [x =/= y] into disequality constraint store *) val add : Env.t -> Subst.t -> t -> 'a -> 'a -> t option @@ -41,6 +41,7 @@ val merge_disjoint : Env.t -> Subst.t -> t -> t -> t module Answer : sig + (* [Answer.t] result of reification of disequality constraints *) type t @@ -52,4 +53,4 @@ module Answer : val reify : Env.t -> Subst.t -> t -> 'a -> Answer.t list -val pp: Format.formatter -> t -> unit +val pp : Format.formatter -> t -> unit diff --git a/src/core/Env.ml b/src/core/Env.ml index a3b54d1a..6405faeb 100644 --- a/src/core/Env.ml +++ b/src/core/Env.ml @@ -41,11 +41,9 @@ let check_exn env v = let var env x = match Term.var x with - | (Some v) as res -> check_exn env v; res + | (Some v) as res -> check_exn env v ; res | None -> None -let is_var env x = (var env x) <> None - let freevars env x = Term.fold (Term.repr x) ~init:Term.VarSet.empty ~fvar:(fun acc v -> Term.VarSet.add v acc) @@ -57,7 +55,7 @@ let is_open env x = try Term.iter (Term.repr x) ~fvar:(fun _ -> raise Open_Term) - ~fval:(fun _ -> ()); + ~fval:(fun _ -> ()) ; false with Open_Term -> true diff --git a/src/core/Env.mli b/src/core/Env.mli index b672b278..fe849a02 100644 --- a/src/core/Env.mli +++ b/src/core/Env.mli @@ -26,9 +26,9 @@ type t (** Creates a new environment. Environment has an internal state: a counter of the variable index introduced last. That's why {!empty} has extra unit argument. *) -val empty : unit -> t +val empty : unit -> t -val create : anchor:Term.Var.env -> t +val create : anchor:Term.Var.env -> t (** Creating a fresh variable takes an extra argument [scope] to decide if storing a substituted value inside a variable is OKay. @@ -36,21 +36,20 @@ val create : anchor:Term.Var.env -> t The idea was recommended by {{: https://github.com/michaelballantyne/faster-minikanren#set-var-val }faster-miniKanren} implementation. *) -val fresh : scope:Term.Var.scope -> t -> 'a +val fresh : scope:Term.Var.scope -> t -> 'a -val check : t -> Term.Var.t -> bool +val check : t -> Term.Var.t -> bool -val check_exn : t -> Term.Var.t -> unit +val check_exn : t -> Term.Var.t -> unit -val is_var : t -> 'a -> bool +(* See [Term.var] *) +val var : t -> 'a -> Term.Var.t option -val var : t -> 'a -> Term.Var.t option +val freevars : t -> 'a -> Term.VarSet.t -val freevars : t -> 'a -> Term.VarSet.t +val is_open : t -> 'a -> bool -val is_open : t -> 'a -> bool - -val equal : t -> t -> bool +val equal : t -> t -> bool (** Essentially, a reader monad over Env.t. Useful for reification. *) module Monad : sig diff --git a/src/core/Logic.ml b/src/core/Logic.ml index fa3ec3e5..dab28426 100644 --- a/src/core/Logic.ml +++ b/src/core/Logic.ml @@ -26,7 +26,8 @@ module List = Stdlib.List @type 'a logic = | Var of GT.int * 'a logic GT.list -| Value of 'a with show, gmap, html, eq, compare, foldl, foldr, fmt +| Value of 'a +with show, gmap, html, eq, compare, foldl, foldr, fmt let logic = {logic with plugins = @@ -67,8 +68,8 @@ exception Not_a_value let to_logic x = Value x let from_logic = function -| Value x -> x -| Var (n, _) -> raise Not_a_value +| Var _ -> raise Not_a_value +| Value x -> x type 'a ilogic @@ -95,8 +96,8 @@ module Reifier = struct let prj onvar env t = match reify env t with - | Value x -> x | Var (v, _) -> onvar v + | Value x -> x let apply r (env, a) = r env a @@ -108,14 +109,10 @@ module Reifier = struct let rec fix f = fun env eta -> f (fix f) env eta - let rework : - 'a 'b. - fv:('a Env.m -> 'b Env.m) - -> ('a logic Env.m -> 'b logic Env.m) - -> 'a logic Env.m - -> 'b logic Env.m - = - fun ~fv fdeq x -> + let rework : 'a 'b. fv:('a Env.m -> 'b Env.m) + -> ('a logic Env.m -> 'b logic Env.m) + -> 'a logic Env.m -> 'b logic Env.m + = fun ~fv fdeq x -> let open Env.Monad in let open Env.Monad.Syntax in let* x = x in @@ -126,7 +123,6 @@ module Reifier = struct | Value t -> let+ inner = fv (return t) in Value inner - ;; let rec zed f x = f (zed f) x end diff --git a/src/core/Logic.mli b/src/core/Logic.mli index 60f4b66c..0530b3ef 100644 --- a/src/core/Logic.mli +++ b/src/core/Logic.mli @@ -22,7 +22,8 @@ (** A type of a logic value *) @type 'a logic = | Var of GT.int * 'a logic GT.list -| Value of 'a with show, gmap, html, eq, compare, foldl, foldr, fmt +| Value of 'a +with show, gmap, html, eq, compare, foldl, foldr, fmt (** [to_logic x] makes a logic value from a regular one *) val to_logic : 'a -> 'a logic diff --git a/src/core/Subst.ml b/src/core/Subst.ml index 5ee7f941..b7a6da7c 100644 --- a/src/core/Subst.ml +++ b/src/core/Subst.ml @@ -16,6 +16,7 @@ * See the GNU Library General Public License version 2 for more details * (enclosed in the file COPYING). *) + IFDEF STATS THEN type stat = {mutable walk_count : int} @@ -24,20 +25,18 @@ let stat = {walk_count = 0} let walk_counter () = stat.walk_count let walk_incr () = stat.walk_count <- stat.walk_count + 1 END + (* to avoid clash with Std.List (i.e. logic list) *) module List = Stdlib.List module Binding = struct + type t = { var : Term.Var.t ; term : Term.t } - let is_relevant env vs {var; term} = - (Term.VarSet.mem var vs) || - (match Env.var env term with Some v -> Term.VarSet.mem v vs | None -> false) - let equal {var=v; term=t} {var=u; term=p} = (Term.Var.equal v u) || (Term.equal t p) @@ -48,11 +47,11 @@ module Binding = let hash {var; term} = Hashtbl.hash (Term.Var.hash var, Term.hash term) let pp ppf {var; term} = - Format.fprintf ppf "{ var.idx = %d; term=%s }" var.Term.Var.index (Term.show term) + Format.fprintf ppf "{ var.idx = %d; term=%a }" var.Term.Var.index Term.pp term end let varmap_of_bindings : Binding.t list -> Term.t Term.VarMap.t = - Stdlib.List.fold_left (fun (acc: _ Term.VarMap.t) Binding.{var;term} -> + Stdlib.List.fold_left (fun (acc: _ Term.VarMap.t) Binding.{var;term} -> assert (not (Term.VarMap.mem var acc)); Term.VarMap.add var term acc ) @@ -62,44 +61,41 @@ type t = Term.t Term.VarMap.t let empty = Term.VarMap.empty -let pp ppf (s: t) = - Format.fprintf ppf "{subst| "; - Term.VarMap.iter (fun var term -> Format.fprintf ppf "%a |- %a; " Term.pp (Obj.repr var) Term.pp term) s; - Format.fprintf ppf "|subst}" - - -let of_list = - ListLabels.fold_left ~init:empty ~f:(let open Binding in fun subst {var; term} -> - if not @@ Term.VarMap.mem var subst then - Term.VarMap.add var term subst - else - invalid_arg "OCanren fatal (Subst.of_list): invalid substituion" - ) - let of_map m = m -let split s = Term.VarMap.fold (fun var term xs -> Binding.({var; term})::xs) s [] +let split s = Term.VarMap.fold (fun var term xs -> Binding.{ var ; term }::xs) s [] + +let pp ppf s = + let open Format in + fprintf ppf "{subst| " ; + Term.VarMap.iter (fun x t -> fprintf ppf "%a |- %a; " Term.pp (Term.repr x) Term.pp t) s ; + fprintf ppf "|subst}" type lterm = Var of Term.Var.t | Value of Term.t -let walk env subst x = +let walk env subst = + (* walk var *) - let rec walkv env subst v = + let rec walkv v = let () = IFDEF STATS THEN walk_incr () ELSE () END in - Env.check_exn env v; + Env.check_exn env v ; + match v.Term.Var.subst with - | Some term -> walkt env subst (Obj.magic term) + | Some term -> walkt term | None -> - try walkt env subst (Term.VarMap.find v subst) + try walkt (Term.VarMap.find v subst) with Not_found -> Var v + (* walk term *) - and walkt env subst t = + and walkt t = let () = IFDEF STATS THEN walk_incr () ELSE () END in + match Env.var env t with - | Some v -> walkv env subst v + | Some v -> walkv v | None -> Value t in - walkv env subst x + + walkv (* same as [Term.map] but performs [walk] on the road *) let map ~fvar ~fval env subst x = @@ -121,27 +117,14 @@ let iter ~fvar ~fval env subst x = in Term.iter x ~fval ~fvar:deepfvar -(* same as [Term.fold] but performs [walk] on the road *) -let fold ~fvar ~fval ~init env subst x = - let rec deepfvar acc v = - Env.check_exn env v; - match walk env subst v with - | Var v -> fvar acc v - | Value x -> Term.fold x ~fval ~fvar:deepfvar ~init:acc - in - Term.fold x ~init ~fval ~fvar:deepfvar - exception Occurs_check -let rec occurs env subst var term = - iter env subst term - ~fvar:(fun v -> if Term.Var.equal v var then raise Occurs_check) - ~fval:(fun x -> ()) +let rec occurs env subst var term = iter env subst term ~fval:(fun _ -> ()) + ~fvar:(fun v -> if Term.Var.equal v var then raise Occurs_check) -let extend ~scope env subst var term = - (* if occurs env subst var term then raise Occurs_check *) - if Runconf.do_occurs_check () then occurs env subst var term; - (* assert (VarEnv.var env var <> VarEnv.var env term); *) +(* [var] must be free in [subst], [term] must not be the same variable *) +let extend ~scope env subst var term = + if Runconf.do_occurs_check () then occurs env subst var term ; (* It is safe to modify variables destructively if the case of scopes match. * There are two cases: @@ -150,97 +133,68 @@ let extend ~scope env subst var term = * 2) If we do unification after a fresh, then in case of failure it doesn't matter if * the variable is be distructively substituted: we will not look on it in future. *) - if (scope = var.Term.Var.scope) && (scope <> Term.Var.non_local_scope) - then begin - var.subst <- Some (Obj.repr term); + if scope = var.Term.Var.scope && scope <> Term.Var.non_local_scope then begin + var.subst <- Some term ; subst - end - else - Term.VarMap.add var (Term.repr term) subst + end else + Term.VarMap.add var term subst exception Unification_failed -let log fmt = - if false - then Format.kasprintf (Format.printf "%s\n%!") fmt - else Format.ifprintf Format.std_formatter fmt - let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) env subst x y = (* The idea is to do the unification and collect the unification prefix during the process *) let extend var term (prefix, subst) = let subst = extend ~scope env subst var term in - (Binding.({var; term})::prefix, subst) + Binding.{ var ; term }::prefix, subst in - let rec helper x y acc = - (* log "unify '%s' and '%s'" (Term.show x) (Term.show y); *) - let open Term in - fold2 x y ~init:acc - ~fvar:(fun ((_, subst) as acc) x y -> - match walk env subst x, walk env subst y with - | Var x, Var y -> - if Var.equal x y then acc else extend x (Term.repr y) acc - | Var x, Value y -> extend x y acc - | Value x, Var y -> extend y x acc - | Value x, Value y -> helper x y acc - ) - ~fval:(fun acc x y -> - if x = y then acc else raise Unification_failed - ) - ~fk:(fun ((_, subst) as acc) l v y -> - if subsume && (l = Term.R) - then raise Unification_failed - else match walk env subst v with - | Var v -> extend v y acc - | Value x -> helper x y acc - ) + + let rec helper x y acc = Term.fold2 x y ~init:acc + ~fvar:begin fun ((_, subst) as acc) x y -> + match walk env subst x, walk env subst y with + | Var x, Var y -> + if Term.Var.equal x y then acc + else extend x (Term.repr y) acc + | Var x, Value y -> extend x y acc + | Value x, Var y -> extend y x acc + | Value x, Value y -> helper x y acc + end + ~fval:begin fun acc x y -> + if x = y then acc + else raise Unification_failed + end + ~fk:begin fun ((_, subst) as acc) l v y -> + if subsume && l = Term.R + then raise Unification_failed + else match walk env subst v with + | Var v -> extend v y acc + | Value x -> helper x y acc + end in + try let x, y = Term.(repr x, repr y) in Some (helper x y ([], subst)) with Term.Different_shape _ | Unification_failed | Occurs_check -> None -let apply env subst x = Obj.magic @@ - map env subst (Term.repr x) - ~fvar:(fun v -> Term.repr v) - ~fval:(fun x -> Term.repr x) - let unify_map env subst map = - let vars, terms = - Term.VarMap.fold (fun v term acc -> (v :: fst acc, term :: snd acc)) map ([],[]) - in - (* log "var = %s" (Term.show (Obj.magic (apply env subst vars))); *) - (* log "terms = %s" (Term.show (Obj.magic (apply env subst terms))); *) - unify env subst (Obj.magic vars) (Obj.magic terms) - - -let freevars env subst x = - Env.freevars env @@ apply env subst x + let vars, terms = Term.VarMap.fold (fun v t (vs, ts) -> Term.repr v :: vs, t::ts) map ([], []) in + unify env subst vars terms -let is_bound = Term.VarMap.mem +let merge_disjoint env = Term.VarMap.union @@ fun _ _ -> + invalid_arg "OCanren fatal (Subst.merge_disjoint): substitutions intersect" -let merge env subst1 subst2 = Term.VarMap.fold (fun var term -> function - | Some s -> begin - match unify env s (Obj.magic var) term with - | Some (_, s') -> Some s' - | None -> None - end - | None -> None -) subst1 (Some subst2) +let subsumed env subst = Term.VarMap.for_all @@ fun var term -> + match unify env subst var term with + | Some ([], _) -> true + | _ -> false -let merge_disjoint env = - Term.VarMap.union (fun _ _ -> - invalid_arg "OCanren fatal (Subst.merge_disjoint): substitutions intersect" - ) +let apply env subst x = Obj.magic @@ map env subst (Term.repr x) ~fvar:Term.repr ~fval:Term.repr -let subsumed env subst = - Term.VarMap.for_all (fun var term -> - match unify env subst (Obj.magic var) term with - | Some ([], _) -> true - | _ -> false - ) +let freevars env subst x = Env.freevars env @@ apply env subst x module Answer = struct + type t = Term.t let subsumed env x y = @@ -249,7 +203,4 @@ module Answer = | None -> false end -let reify env subst x = - map env subst (Term.repr x) - ~fvar:(fun v -> Term.repr v) - ~fval:(fun x -> Term.repr x) +let reify = apply diff --git a/src/core/Subst.mli b/src/core/Subst.mli index 62d2257c..43fbc651 100644 --- a/src/core/Subst.mli +++ b/src/core/Subst.mli @@ -19,42 +19,30 @@ module Binding : sig + type t = { var : Term.Var.t ; term : Term.t } - val is_relevant : Env.t -> Term.VarSet.t -> t -> bool - val equal : t -> t -> bool val compare : t -> t -> int val hash : t -> int + val pp: Format.formatter -> t -> unit end -val varmap_of_bindings: Binding.t list -> Term.t Term.VarMap.t +val varmap_of_bindings : Binding.t list -> Term.t Term.VarMap.t type t -val pp: Format.formatter -> t -> unit - val empty : t -val of_list : Binding.t list -> t -val of_map : Term.t Term.VarMap.t -> t +val of_map : Term.t Term.VarMap.t -> t val split : t -> Binding.t list -(* [apply env subst x] - applies [subst] to term [x], - * i.e. replaces every variable to relevant binding in [subst]; - *) -val apply : Env.t -> t -> 'a -> 'a - -(* [is_bound x subst] - checks whether [x] is bound by [subst] *) -val is_bound : Term.Var.t -> t -> bool - -(* [freevars env subst x] - returns all free-variables of term [x] *) -val freevars : Env.t -> t -> 'a -> Term.VarSet.t +val pp : Format.formatter -> t -> unit (* [unify ~subsume ~scope env subst x y] performs unification of two terms [x] and [y] in [subst]. * Unification is a process of finding substituion [s] s.t. [s(x) = s(y)]. @@ -69,20 +57,26 @@ val freevars : Env.t -> t -> 'a -> Term.VarSet.t *) val unify : ?subsume:bool -> ?scope:Term.Var.scope -> Env.t -> t -> 'a -> 'a -> (Binding.t list * t) option -val unify_map: Env.t -> t -> Term.t Term.VarMap.t -> (Binding.t list * t) option +val unify_map : Env.t -> t -> Term.t Term.VarMap.t -> (Binding.t list * t) option val merge_disjoint : Env.t -> t -> t -> t -(* [merge env s1 s2] merges two substituions *) -val merge : Env.t -> t -> t -> t option - (* [subsumed env s1 s2] checks that [s1] is subsumed by [s2] (i.e. [s2] is more general than [s1]). * Subsumption relation forms a partial order on the set of substitutions. *) val subsumed : Env.t -> t -> t -> bool +(* [apply env subst x] - applies [subst] to term [x], + * i.e. replaces every variable to relevant binding in [subst]; + *) +val apply : Env.t -> t -> 'a -> 'a + +(* [freevars env subst x] - returns all free-variables of term [x] *) +val freevars : Env.t -> t -> 'a -> Term.VarSet.t + module Answer : sig + type t = Term.t (* [subsumed env x y] checks that [x] is subsumed by [y] (i.e. [y] is more general than [x]) *) diff --git a/src/core/Term.ml b/src/core/Term.ml index daa8079a..2d8ed6d6 100644 --- a/src/core/Term.ml +++ b/src/core/Term.ml @@ -22,11 +22,14 @@ open Printf (* to avoid clash with Std.List (i.e. logic list) *) module List = Stdlib.List +type t = Obj.t + module Var = struct + type term = t type env = int type scope = int - type anchor = int list + type anchor = int ref let tabling_env = -1 @@ -38,15 +41,15 @@ module Var = let scope = ref 0 in fun () -> (incr scope; !scope) - let global_anchor = [-8] + let global_anchor = ref (-8) type t = { anchor : anchor; env : env; index : int; - mutable subst : Obj.t option; + mutable subst : term option; scope : scope; - constraints : Obj.t list + constraints : term list } let make ~env ~scope index = { @@ -63,11 +66,19 @@ module Var = let scope = 0 in make ~env ~scope 0 - let valid_anchor anchor = - anchor == global_anchor + let var_tag, var_size = + let dummy = Obj.repr dummy in + Obj.tag dummy, Obj.size dummy + + let is_valid_anchor anchor = anchor == global_anchor - let reify r {index; constraints} = - (index, List.map (fun x -> r @@ Obj.obj x) constraints) + let has_var_structure tx sx x = + if tx = var_tag && sx = var_size then + let anchor = (Obj.obj x).anchor in + (Obj.is_block @@ Obj.repr anchor) && is_valid_anchor anchor + else false + + let reify r { index ; constraints } = index, List.map (fun x -> r @@ Obj.obj x) constraints let equal x y = (x.index = y.index) && (x.env = y.env) @@ -77,6 +88,14 @@ module Var = let hash x = Hashtbl.hash (x.env, x.index) + let describe ppf { index } = Format.fprintf ppf "_.%d" index + + let pp ppt ppf x = match x.constraints with + | [] -> describe ppf x + | cs -> + let open Format in + let ppcs = pp_print_list ~pp_sep:(fun ppf () -> fprintf ppf "; ") ppt in + fprintf ppf "%a{=/= %a}" describe x ppcs cs end module VarSet = Set.Make(Var) @@ -86,68 +105,88 @@ module VarMap = struct include Map.Make(Var) - let update k f m = - match f (try Some (find k m) with Not_found -> None) with - | Some x -> add k x m - | None -> remove k m - let iteri f m = let i = ref 0 in iter (fun k v -> f !i k v; incr i) m end -type t = Obj.t type value = Obj.t let repr = Obj.repr - -let var_tag, var_size = - let dummy = Obj.repr Var.dummy in - Obj.tag dummy, Obj.size dummy - -let has_var_structure tx sx x = - if tx = var_tag && sx = var_size - then ( - let anchor = (Obj.obj x : Var.t).Var.anchor in - (Obj.is_block @@ Obj.repr anchor) && Var.valid_anchor anchor) - else false -;; +let obj = Obj.obj let is_box t = - if (t <= Obj.last_non_constant_constructor_tag) && - (t >= Obj.first_non_constant_constructor_tag) - then true - else false - -let is_var x = - let x = Obj.repr x in - let tx = Obj.tag x in - is_box tx && has_var_structure tx (Obj.size x) x -;; + t <= Obj.last_non_constant_constructor_tag && + t >= Obj.first_non_constant_constructor_tag let is_int = (=) Obj.int_tag let is_str = (=) Obj.string_tag let is_dbl = (=) Obj.double_tag -let is_valid_tag t = (is_int t) || (is_str t) || (is_dbl t) +let is_val t = is_int t || is_str t || is_dbl t -let is_valid_tag_exn t = - if is_valid_tag t then () else failwith (sprintf "OCanren fatal: invalid value tag (%d)" t) +let check_val t = + if not @@ is_val t then invalid_arg @@ sprintf "OCanren fatal: invalid value tag (%d)" t let var x = let x = Obj.repr x in let tx = Obj.tag x in if is_box tx then let sx = Obj.size x in - if has_var_structure tx sx x then Some (Obj.magic x) else None + if Var.has_var_structure tx sx x then Some (obj x) else None else None +let pp = + let open Format in + let rec helper ppf x = + let tx = Obj.tag x in + if is_box tx + then ( + let sx = Obj.size x in + if Var.has_var_structure tx sx x + then ( + let v = obj x in + match v.Var.constraints with + | [] -> Var.describe ppf v + | cs -> + fprintf + ppf + "%a{=/= %a}" + Var.describe + v + (pp_print_list ~pp_sep:(fun ppf () -> fprintf ppf "; ") helper) + cs) + else ( + let rec inner i : unit = + if i < sx + then ( + if i > 0 then fprintf ppf ", "; + helper ppf (Obj.field x i); + inner (i + 1)) + in + fprintf ppf "boxed %d <" tx; + inner 0; + fprintf ppf ">")) + else ( + check_val tx; + if tx = Obj.int_tag + then fprintf ppf "int<%d>" @@ obj x + else if tx = Obj.string_tag + then fprintf ppf "string<%s>" @@ obj x + else if tx = Obj.double_tag + then fprintf ppf "double<%e>" @@ obj x + else failwith "Dynamic pretty printing of some special tags is not supported") + in + helper + +let show x = Format.asprintf "%a" pp x + let rec map ~fvar ~fval x = let tx = Obj.tag x in if (is_box tx) then let sx = Obj.size x in - if has_var_structure tx sx x then - fvar @@ Obj.magic x + if Var.has_var_structure tx sx x then + fvar @@ obj x else let y = Obj.dup x in for i = 0 to sx - 1 do @@ -155,7 +194,7 @@ let rec map ~fvar ~fval x = done; y else begin - is_valid_tag_exn tx; + check_val tx; fval x end @@ -163,50 +202,23 @@ let rec iter ~fvar ~fval x = let tx = Obj.tag x in if (is_box tx) then let sx = Obj.size x in - if has_var_structure tx sx x then - fvar @@ Obj.magic x + if Var.has_var_structure tx sx x then + fvar @@ obj x else for i = 0 to sx - 1 do iter ~fvar ~fval (Obj.field x i) done; else begin - is_valid_tag_exn tx; + check_val tx; fval x end -let rec show x = - let tx = Obj.tag x in - if (is_box tx) then - let sx = Obj.size x in - if has_var_structure tx sx x then - let v = Obj.magic x in - match v.Var.constraints with - | [] -> Printf.sprintf "_.%d" v.Var.index - | cs -> Printf.sprintf "_.%d{=/= %s}" v.Var.index (String.concat "; " @@ List.map show cs) - else - let rec inner i = - if i < sx then - (show @@ Obj.field x i)::(inner (i+1)) - else [] - in - Printf.sprintf "boxed %d <%s>" tx (String.concat ", " @@ inner 0) - else begin - is_valid_tag_exn tx; - if tx = Obj.int_tag then - Printf.sprintf "int<%d>" @@ Obj.magic x - else if tx = Obj.string_tag then - Printf.sprintf "string<%s>" @@ Obj.magic x - else if tx = Obj.double_tag then - Printf.sprintf "double<%e>" @@ Obj.magic x - else assert false - end - let rec fold ~fvar ~fval ~init x = let tx = Obj.tag x in if (is_box tx) then let sx = Obj.size x in - if has_var_structure tx sx x then - fvar init @@ Obj.magic x + if Var.has_var_structure tx sx x then + fvar init @@ obj x else let rec inner i acc = if i < sx then @@ -216,7 +228,7 @@ let rec fold ~fvar ~fval ~init x = in inner 0 init else begin - is_valid_tag_exn tx; + check_val tx; fval init x end @@ -229,10 +241,10 @@ let rec fold2 ~fvar ~fval ~fk ~init x y = match is_box tx, is_box ty with | true, true -> begin let sx, sy = Obj.size x, Obj.size y in - match has_var_structure tx sx x, has_var_structure ty sy y with - | true, true -> fvar init (Obj.magic x) (Obj.magic y) - | true, false -> fk init L (Obj.magic x) y - | false, true -> fk init R (Obj.magic y) x + match Var.has_var_structure tx sx x, Var.has_var_structure ty sy y with + | true, true -> fvar init (obj x) (obj y) + | true, false -> fk init L (obj x) y + | false, true -> fk init R (obj y) x | false, false -> if (tx = ty) && (sx = sy) then let fx, fy = Obj.field x, Obj.field y in @@ -246,16 +258,16 @@ let rec fold2 ~fvar ~fval ~fk ~init x y = else raise (Different_shape (tx, ty)) end | true, false -> - is_valid_tag_exn ty; + check_val ty; let sx = Obj.size x in - if has_var_structure tx sx x then fk init L (Obj.magic x) y else raise (Different_shape (tx, ty)) + if Var.has_var_structure tx sx x then fk init L (obj x) y else raise (Different_shape (tx, ty)) | false, true -> - is_valid_tag_exn tx; + check_val tx; let sy = Obj.size y in - if has_var_structure ty sy y then fk init R (Obj.magic y) x else raise (Different_shape (tx, ty)) + if Var.has_var_structure ty sy y then fk init R (obj y) x else raise (Different_shape (tx, ty)) | false, false -> - is_valid_tag_exn tx; - is_valid_tag_exn ty; + check_val tx; + check_val ty; if tx = ty then fval init x y else raise (Different_shape (tx, ty)) @@ -292,51 +304,3 @@ let rec compare x y = let rec hash x = fold x ~init:1 ~fvar:(fun acc v -> Hashtbl.hash (Var.hash v, List.fold_left (fun acc x -> Hashtbl.hash (acc, hash x)) acc v.Var.constraints)) ~fval:(fun acc x -> Hashtbl.hash (acc, Hashtbl.hash x)) - -let describe_var ppf Var.{ index } = Format.fprintf ppf "_.%d" index - -let pp = - let open Format in - let rec helper ppf x = - let tx = Obj.tag x in - if is_box tx - then ( - let sx = Obj.size x in - if has_var_structure tx sx x - then ( - let v = Obj.magic x in - match v.Var.constraints with - | [] -> describe_var ppf v - | cs -> - fprintf - ppf - "%a{=/= %a}" - describe_var - v - (pp_print_list ~pp_sep:(fun ppf () -> fprintf ppf "; ") helper) - cs) - else ( - let rec inner i : unit = - if i < sx - then ( - if i > 0 then fprintf ppf ", "; - helper ppf (Obj.field x i); - inner (i + 1)) - in - fprintf ppf "boxed %d <" tx; - inner 0; - fprintf ppf ">")) - else ( - is_valid_tag_exn tx; - if tx = Obj.int_tag - then fprintf ppf "int<%d>" @@ Obj.magic x - else if tx = Obj.string_tag - then fprintf ppf "string<%s>" @@ Obj.magic x - else if tx = Obj.double_tag - then fprintf ppf "double<%e>" @@ Obj.magic x - else failwith "Dynamic pretty printing of some special tags is not supported") - in - fun ppf x -> helper ppf (Obj.repr x) -;; - -let show x = Format.asprintf "%a" pp x diff --git a/src/core/Term.mli b/src/core/Term.mli index a03ccf33..f70c1b6f 100644 --- a/src/core/Term.mli +++ b/src/core/Term.mli @@ -19,43 +19,40 @@ (* [Term] encapsulates unsafe operations on untyped OCaml's values extended with the logic variables *) +(* [t] type of untyped OCaml term *) +type t = Obj.t + (* [Var] logic variables and operations on them *) module Var : sig + type term = t type env = int - type scope - type anchor type t = { anchor : anchor; env : env; index : int; - mutable subst : Obj.t option; + mutable subst : term option; scope : scope; - constraints : Obj.t list + constraints : term list } val tabling_env : env val non_local_scope : scope - val new_scope : unit -> scope - val valid_anchor : anchor -> bool - - val dummy : t - val make : env:env -> scope:scope -> int -> t val reify : ('a -> 'b) -> t -> int * 'b list val equal : t -> t -> bool - val compare : t -> t -> int - val hash : t -> int + + val describe : Format.formatter -> t -> unit end module VarSet : Set.S with type elt = Var.t @@ -64,21 +61,19 @@ module VarTbl : Hashtbl.S with type key = Var.t module VarMap : sig - include Map.S with type key = Var.t - val update : key -> ('a option -> 'a option) -> 'a t -> 'a t + include Map.S with type key = Var.t val iteri: (int -> key -> 'a -> unit) -> 'a t -> unit - end -(* [t] type of untyped OCaml term *) -type t = Obj.t - type value val repr : 'a -> t +val show : t -> string +val pp : Format.formatter -> t -> unit + (* [var x] if [x] is logic variable returns it, otherwise returns [None] *) val var : 'a -> Var.t option @@ -116,7 +111,3 @@ val fold2 : val equal : t -> t -> bool val compare : t -> t -> int val hash : t -> int - -val show : t -> string -val describe_var : Format.formatter -> Var.t -> unit -val pp : Format.formatter -> 'a -> unit