diff --git a/.gitignore b/.gitignore index 24c8d48..af16f01 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ notebooks/.ipynb_checkpoints/Prediction-based CA system-checkpoint.ipynb dist/pycona-0.2.4-py3-none-any.whl dist/pycona-0.2.4.tar.gz notebooks/.ipynb_checkpoints/Comparing different algorithms and methods-checkpoint.ipynb +testing.py diff --git a/pycona/active_algorithms/genacq.py b/pycona/active_algorithms/genacq.py index fc43212..34cb48c 100644 --- a/pycona/active_algorithms/genacq.py +++ b/pycona/active_algorithms/genacq.py @@ -68,7 +68,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.total_queries} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/active_algorithms/growacq.py b/pycona/active_algorithms/growacq.py index a3fc740..a1bdbfe 100644 --- a/pycona/active_algorithms/growacq.py +++ b/pycona/active_algorithms/growacq.py @@ -7,7 +7,7 @@ from ..answering_queries import Oracle, UserOracle from .. import Metrics from ..ca_environment import ProbaActiveCAEnv - +from ..utils import get_con_subset class GrowAcq(AlgorithmCAInteractive): """ @@ -67,6 +67,11 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos print(f"\nGrowAcq: calling inner_algorithm for {len(Y)}/{n_vars} variables") self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, X=Y, metrics=self.env.metrics) + # Add implied constraints from bias to cl + implied_constraints = get_con_subset(self.env.instance.bias, Y) + self.env.instance.cl.extend(implied_constraints) + self.env.instance.bias = [c for c in self.env.instance.bias if c not in set(implied_constraints)] # remove implied constraints from bias + if verbose >= 3: print("C_L: ", len(self.env.instance.cl)) print("B: ", len(self.env.instance.bias)) diff --git a/pycona/active_algorithms/mineacq.py b/pycona/active_algorithms/mineacq.py index a9b262f..2885058 100644 --- a/pycona/active_algorithms/mineacq.py +++ b/pycona/active_algorithms/mineacq.py @@ -66,7 +66,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.total_queries} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/active_algorithms/mquacq.py b/pycona/active_algorithms/mquacq.py index 2467f36..565e047 100644 --- a/pycona/active_algorithms/mquacq.py +++ b/pycona/active_algorithms/mquacq.py @@ -62,7 +62,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.membership_queries_count} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/active_algorithms/mquacq2.py b/pycona/active_algorithms/mquacq2.py index 2813b70..11ba0ea 100644 --- a/pycona/active_algorithms/mquacq2.py +++ b/pycona/active_algorithms/mquacq2.py @@ -69,7 +69,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.membership_queries_count} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generated_queries() diff --git a/pycona/active_algorithms/pquacq.py b/pycona/active_algorithms/pquacq.py index 6cf579d..b591085 100644 --- a/pycona/active_algorithms/pquacq.py +++ b/pycona/active_algorithms/pquacq.py @@ -62,7 +62,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.membership_queries_count} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/active_algorithms/quacq.py b/pycona/active_algorithms/quacq.py index 31352bd..b28c4d6 100644 --- a/pycona/active_algorithms/quacq.py +++ b/pycona/active_algorithms/quacq.py @@ -58,7 +58,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos if self.env.verbose >= 1: print(f"\nLearned {self.env.metrics.cl} constraints in " f"{self.env.metrics.membership_queries_count} queries.") - self.env.instance.bias = [] return self.env.instance self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/benchmarks/nqueens.py b/pycona/benchmarks/nqueens.py index cc3bfd9..976554f 100644 --- a/pycona/benchmarks/nqueens.py +++ b/pycona/benchmarks/nqueens.py @@ -3,7 +3,7 @@ from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar -def construct_nqueens_problem(n): +def construct_nqueens_problem(n=8): parameters = {"n": n} @@ -43,6 +43,4 @@ def construct_nqueens_problem(n): for c in oracle.constraints: print(c) - input("Press Enter to continue...") - return instance, oracle diff --git a/tests/test_finc.py b/tests/test_finc.py index 3ac5455..ca7e98f 100644 --- a/tests/test_finc.py +++ b/tests/test_finc.py @@ -77,5 +77,19 @@ def test_findc2_with_golomb4(self): learned_not_oracle += cp.any([~c for c in oracle.constraints]) assert not learned_not_oracle.solve() + # test growacq + alg = ca.GrowAcq(ca_env, alg) + li2 = alg.learn(instance, oracle) + + # oracle model imply learned? + oracle_not_learned = cp.Model(oracle.constraints) + oracle_not_learned += cp.any([~c for c in li2._cl]) + assert not oracle_not_learned.solve() + + # learned model imply oracle? + learned_not_oracle = cp.Model(li2._cl) + learned_not_oracle += cp.any([~c for c in oracle.constraints]) + assert not learned_not_oracle.solve() +