Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 110 additions & 117 deletions src/core/Core.ml
Original file line number Diff line number Diff line change
Expand Up @@ -60,81 +60,73 @@ module List = Stdlib.List

module Answer :
sig

(* [Answer.t] - a type that represents (untyped) answer to a query *)
type t

(* [make env t] creates the answer from the environment and term (with constrainted variables) *)
val make : Env.t -> Term.t -> t

(* [lift env a] lifts the answer into different environment, replacing all variables consistently *)
val lift : Env.t -> t -> t

(* [env a] returns an environment of the answer *)
val env : t -> Env.t

(* [unctr_term a] returns a term with unconstrained variables *)
val unctr_term : t -> Term.t

(* [ctr_term a] returns a term with constrained variables *)
val ctr_term : t -> Term.t

(* [unctr_term a] returns a term with unconstrained variables *)
val unctr_term : t -> Term.t

(* [disequality a] returns all disequality constraints on variables in term as a list of bindings *)
val disequality : t -> Subst.Binding.t list

(* [lift env a] lifts the answer into different environment, replacing all variables consistently *)
val lift : Env.t -> t -> t

(* [equal t t'] syntactic equivalence (not an alpha-equivalence) *)
val equal : t -> t -> bool

(* [hash t] hashing that is consistent with syntactic equivalence *)
val hash : t -> int
end = struct

type t = Env.t * Term.t

let make env t = (env, t)

let env (env, _) = env

let unctr_term (_, t) =
Term.map t
~fval:(fun x -> Term.repr x)
~fvar:(fun v -> Term.repr {v with Term.Var.constraints = []})

let ctr_term (_, t) = t

let unctr_term (_, t) = Term.map t ~fval:Term.repr
~fvar:(fun v -> Term.repr { v with Term.Var.constraints = [] })

let disequality (env, t) =
let rec helper acc x =
Term.fold x ~init:acc
~fval:(fun acc _ -> acc)
~fvar:(fun acc var ->
ListLabels.fold_left var.Term.Var.constraints ~init:acc
~f:(fun acc ctr_term ->
let ctr_term = Term.repr ctr_term in
let var = {var with Term.Var.constraints = []} in
let term = unctr_term @@ (env, ctr_term) in
let acc = Subst.(Binding.({var; term}))::acc in
helper acc ctr_term
)
)
let rec helper acc x = Term.fold x ~init:acc ~fval:(fun acc _ -> acc)
~fvar:begin fun acc ctr_var ->
let var = { ctr_var with Term.Var.constraints = [] } in
ListLabels.fold_left ctr_var.Term.Var.constraints ~init:acc
~f:begin fun acc ctr_term ->
let term = unctr_term (env, ctr_term) in
let acc = Subst.Binding.{ var ; term }::acc in
helper acc ctr_term
end
end
in
helper [] t

let lift env' (env, t) =
let vartbl = Term.VarTbl.create 31 in
let rec helper x =
Term.map x
~fval:(fun x -> Term.repr x)
~fvar:(fun v -> Term.repr @@
try
Term.VarTbl.find vartbl v
with Not_found ->
let new_var = Env.fresh ~scope:Term.Var.non_local_scope env' in
Term.VarTbl.add vartbl v new_var;
{new_var with Term.Var.constraints =
List.map (fun x -> helper x) v.Term.Var.constraints
|> List.sort Term.compare
}
)
let rec helper x = Term.map x ~fval:Term.repr
~fvar:begin fun v -> Term.repr @@
try Term.VarTbl.find vartbl v
with Not_found ->
let new_var = Env.fresh ~scope:Term.Var.non_local_scope env' in
Term.VarTbl.add vartbl v new_var ;
let ctr = List.map helper v.Term.Var.constraints |> List.sort Term.compare in
{ new_var with Term.Var.constraints = ctr }
end
in
(env', helper t)
env', helper t

let check_envs_exn env env' =
if Env.equal env env' then () else
Expand All @@ -144,58 +136,62 @@ module Answer :
check_envs_exn env env';
Term.equal t t'

let hash (env, t) = Term.hash t
let hash (_, t) = Term.hash t
end

module Prunes : sig

type rez = Violated | NonViolated
type ('a, 'b) reifier = ('a,'b) Reifier.t
type ('a, 'b) reifier = ('a, 'b) Reifier.t
type 'b cond = 'b -> bool
type t

val empty : t
val empty : t
val extend : t -> Term.VarTbl.key -> ('a, 'b) reifier -> 'b cond -> t

val recheck : t -> Env.t -> Subst.t -> rez
val check_last : t -> Env.t -> Subst.t -> rez
val extend : t -> Term.VarTbl.key -> ('a, 'b) reifier -> 'b cond -> t
end = struct

type rez = Violated | NonViolated
type ('a, 'b) reifier = ('a,'b) Reifier.t
type reifier_untyped = Obj.t
type ('a, 'b) reifier = ('a, 'b) Reifier.t
type 'b cond = 'b -> bool

type reifier_untyped = Obj.t
type cond_untyped = Obj.t -> bool

let make_untyped : ('a, 'b) reifier -> 'b cond -> reifier_untyped * cond_untyped =
fun a b -> Obj.magic (a,b)
fun a b -> Obj.magic (a, b)

type t = (Term.t * (reifier_untyped * cond_untyped)) list

type t = (Obj.t * (reifier_untyped * cond_untyped)) list
let empty = []

exception Fail
let extend map term rr cond =
let new_item = make_untyped rr cond in
(Term.repr term, new_item) :: map

let check_last map env subst =
try
let (term, (reifier, checker)) = List.hd map in
let reifier : (_,_) reifier = Obj.obj reifier in
let reified = reifier env (Obj.magic @@ Subst.apply env subst term) in
if not (checker reified) then raise Fail;
NonViolated
with Not_found -> NonViolated
| Fail -> Violated
exception Fail

let recheck (ps: t) env s =
try
ps |> List.iter (fun (k, (reifier, checker)) ->
let reifier : (_,_) Reifier.t = Obj.obj reifier in
ps |> List.iter begin fun (k, (reifier, checker)) ->
let reifier : (_, _) Reifier.t = Obj.obj reifier in
let reified = reifier env (Obj.magic @@ Subst.apply env s k) in
if not (checker reified) then raise Fail
);
end ;
NonViolated
with Fail -> Violated

let extend map term rr cond =
let new_item = make_untyped rr cond in
(Obj.repr term, new_item) :: map

let check_last map env subst =
try
let (term, (reifier, checker)) = List.hd map in
let reifier : (_, _) reifier = Obj.obj reifier in
let reified = reifier env (Obj.magic @@ Subst.apply env subst term) in
if not (checker reified) then raise Fail ;
NonViolated
with Not_found -> NonViolated
| Fail -> Violated
end

type prines_control =
Expand Down Expand Up @@ -244,6 +240,7 @@ module PrunesControl = struct
ans
)
end

(*
let do_skip_prunes = ref false
let prunes_checks_skipped = ref 0
Expand All @@ -253,8 +250,10 @@ let set_skip_prunes_count n =
assert (n>0);
max_prunes_skipped := n
*)

module State =
struct

type t =
{ env : Env.t
; subst : Subst.t
Expand Down Expand Up @@ -283,65 +282,59 @@ module State =

let new_scope st = {st with scope = Term.Var.new_scope ()}

let unify x y ({env; subst; ctrs; scope} as st) =
match Subst.unify ~scope env subst x y with
| None -> None
| Some (prefix, subst) ->
match Disequality.recheck env subst ctrs prefix with
| None -> None
| Some ctrs ->
let next_state = {st with subst; ctrs} in
if PrunesControl.is_exceeded ()
then begin
let () = PrunesControl.reset_cur_counter () in
match Prunes.recheck (prunes next_state) env subst with
| Prunes.Violated -> None
| NonViolated -> Some next_state
end else begin
(* print_endline "check skipped";*)
let () = PrunesControl.incr () in
Some next_state
end

let unify x y ({ env ; subst ; ctrs ; scope } as st) =
match Subst.unify ~scope env subst x y with
| None -> None
| Some (prefix, subst) ->
match Disequality.recheck env subst ctrs prefix with
| None -> None
| Some ctrs ->
let next_state = { st with subst ; ctrs } in
if PrunesControl.is_exceeded ()
then begin
let () = PrunesControl.reset_cur_counter () in
match Prunes.recheck (prunes next_state) env subst with
| Prunes.Violated -> None
| NonViolated -> Some next_state
end else begin
(* print_endline "check skipped"; *)
let () = PrunesControl.incr () in
Some next_state
end

let diseq x y ({env; subst; ctrs; scope} as st) =
let diseq x y ({ env ; subst ; ctrs ; scope } as st) =
match Disequality.add env subst ctrs x y with
| None -> None
| Some ctrs ->
match Prunes.recheck (prunes st) env subst with
| Prunes.Violated -> None
| NonViolated -> Some {st with ctrs}
match Prunes.recheck (prunes st) env subst with
| Prunes.Violated -> None
| NonViolated -> Some { st with ctrs }

(* returns always non-empty list *)
let reify x {env; subst; ctrs} =
(* always returns non-empty list *)
let reify x { env ; subst ; ctrs } =
let answ = Subst.reify env subst x in
match Disequality.reify env subst ctrs x with
| [] -> [Answer.make env answ]
| diseqs ->
ListLabels.map diseqs ~f:(fun diseq ->
let rec helper forbidden t =
Term.map t
~fval:(fun x -> Term.repr x)
~fvar:(fun v -> Term.repr @@
if List.mem v.Term.Var.index forbidden then v
else
{v with Term.Var.constraints =
Disequality.Answer.extract diseq v
|> List.filter (fun dt ->
match Env.var env dt with
| Some u -> not (List.mem u.Term.Var.index forbidden)
| None -> true
)
|> List.map (fun x -> helper (v.Term.Var.index::forbidden) x)
(* TODO: represent [Var.constraints] as [Set];
* TODO: hide all manipulations on [Var.t] inside [Var] module;
*)
|> List.sort Term.compare
}
)
in
Answer.make env (helper [] answ)
)
| [] -> (* [Answer.make env answ] *) assert false
| diseqs -> ListLabels.map diseqs ~f:begin fun diseq ->
let rec helper forbidden t = Term.map t ~fval:Term.repr
~fvar:begin fun v -> Term.repr @@
if Term.VarSet.mem v forbidden then v
else { v with Term.Var.constraints = Disequality.Answer.extract diseq v
|> List.filter begin fun dt ->
match Env.var env dt with
| Some u -> not @@ Term.VarSet.mem u forbidden
| None -> true
end
|> List.map (fun x -> helper (Term.VarSet.add v forbidden) x)
(* TODO: represent [Var.constraints] as [Set];
* TODO: hide all manipulations on [Var.t] inside [Var] module;
*)
|> List.sort Term.compare
}
end
in
Answer.make env @@ helper Term.VarSet.empty answ
end
end

let (!!!) = Obj.magic
Expand Down Expand Up @@ -797,6 +790,6 @@ let reify_in_empty reifier x =
reifier (State.env st) x

let trace_diseq : goal = fun st ->
Format.printf "%a\n%!" Disequality.pp (State.constraints st);
Format.printf "%a\n%!" Disequality.pp (State.constraints st) ;
success st

Loading