diff --git a/Deduce.lark b/Deduce.lark index cf286d1..8f9bca0 100644 --- a/Deduce.lark +++ b/Deduce.lark @@ -114,6 +114,7 @@ ident: IDENT -> ident | "if" term "then" term "else" term -> conditional | "if" term "then" term -> if_then_formula | "%" type + | "break" term_hi -> break_before ?term_list: -> empty | term -> single diff --git a/abstract_syntax.py b/abstract_syntax.py index 92adafd..3b86972 100644 --- a/abstract_syntax.py +++ b/abstract_syntax.py @@ -1079,10 +1079,10 @@ def op_arg_str(trm, arg): return "(" + str(arg) + ")" return str(arg) - - def do_function_call(loc, name, type_params, type_args, params, args, body, subst, env, return_type): + from deduce_debugger import on_function, after_function + on_function(name, loc, env, args, [base_name(p) for p in params]) fast_call = False if get_eval_all() and len(args) == 2 and isNat(args[0]) and isNat(args[1]): op = base_name(name) @@ -1164,9 +1164,28 @@ def do_function_call(loc, name, type_params, type_args, global recursion_depth print('<' * recursion_depth, str(ret)) recursion_depth -= 1 - + + after_function(name, loc, env, args, ret) return explicit_term_inst(ret) +@dataclass +class Breakpoint(Term): + point: Term + + def __str__(self): + return str(self.point) + + def uniquify(self, env): + self.point.uniquify(env) + from deduce_debugger import break_at_point + break_at_point(self.point.location) + + def reduce(self, env): + from deduce_debugger import on_statement, after_statement + on_statement(self.point.location, env) + ret = self.point.reduce(env) + after_statement(self.point.location, env) + return ret @dataclass class Call(Term): @@ -1219,7 +1238,7 @@ def __eq__(self, other): #print(str(self) + ' =? ' + str(other) + ' = ' + str(result)) return result - def reduce(self, env): + def reduce(self, env): fun = self.rator.reduce(env) if get_eval_all(): is_assoc = False @@ -1247,10 +1266,16 @@ def reduce(self, env): ret.type_args = self.type_args case Lambda(loc, ty, vars, body): + from deduce_debugger import on_function, after_function + name = rator_name(self.rator) if hasattr(fun, 'env'): + on_function(name, self.location, fun.env, args, param_names=[base_name(x[0]) for x in vars]) ret = self.do_call(loc, vars, body, args, fun.env) + after_function(name, self.location, fun.env, args, ret, param_names=[base_name(x[0]) for x in vars]) else: + on_function(name, self.location, env, args, param_names=[base_name(x[0]) for x in vars]) ret = self.do_call(loc, vars, body, args, env) + after_function(name, self.location, env, args, ret, param_names=[base_name(x[0]) for x in vars]) case GenRecFun(loc, name, [], params, returns, measure, measure_ty, body, terminates): @@ -1312,6 +1337,7 @@ def do_recursive_call(self, loc, name, fun, type_params, type_args, params, args print('call to recursive function: ' + str(fun)) print('\targs: ' + ', '.join([str(a) for a in args])) + if env.get_tracing(name): global recursion_depth recursion_depth += 1 @@ -1327,9 +1353,13 @@ def do_recursive_call(self, loc, name, fun, type_params, type_args, params, args for fun_case in cases: subst = {} if is_match(fun_case.pattern, first_arg, subst): - return do_function_call(loc, name, type_params, type_args, + from deduce_debugger import on_function, after_function + on_function(name, fun_case.location, env, args) + ret = do_function_call(loc, name, type_params, type_args, fun_case.parameters, rest_args, fun_case.body, subst, env, returns) + after_function(name, fun_case.location, env, args, ret) + return ret if is_assoc: if get_verbose(): print('not reducing recursive call to associative ' + str(fun)) @@ -4286,6 +4316,12 @@ def get_def_of_type_var(self, var): return self._def_of_type_var(self.dict, name) case _: raise Exception('get_def_of_type_var: unexpected ' + str(var)) + + def get_def_of_term_name(self, name): + if name in self.dict.keys(): # the name '=' is not in the env + return self.dict[name] + else: + return None def get_formula_of_proof_var(self, pvar): match pvar: diff --git a/deduce.py b/deduce.py index e05f706..d98d432 100644 --- a/deduce.py +++ b/deduce.py @@ -1,3 +1,4 @@ +from deduce_debugger import set_debugging from flags import * from proof_checker import check_deduce, uniquify_deduce, is_modified from abstract_syntax import init_import_directories, add_import_directory, print_theorems, get_recursive_descent, set_recursive_descent, get_uniquified_modules, add_uniquified_module, VerboseLevel @@ -147,6 +148,8 @@ def deduce_directory(directory, recursive_directories, tracing_functions): exit(0) elif argument == '--no-check-imports': set_check_imports(False) + elif argument == '--debug': + set_debugging() else: deducables.append(argument) diff --git a/deduce_debugger.py b/deduce_debugger.py new file mode 100644 index 0000000..3b8fc7b --- /dev/null +++ b/deduce_debugger.py @@ -0,0 +1,162 @@ +from abstract_syntax import * +from lark.tree import Meta + +breakpoints: set[str | Meta] = set() +stepping: bool = False +break_on_next: bool = False +last_input: list[str] = [''] +break_after: dict[object, list[int]] = {} + +def set_debugging(): + global break_on_next + break_on_next = True + +def break_at_point(loc: Meta | str): + global breakpoints + breakpoints.add(loc) + print('Breakpoint added at:', loc) + +def dont_break_on(place: Meta | str) -> bool: + global break_on_next + if break_on_next: + break_on_next = False + return False + return not stepping and place not in breakpoints + +def increment_break_after(loc): + if loc in break_after and len(break_after[loc]) > 0: + break_after[loc][0] += 1 + +def dont_break_after(loc) -> bool: + if loc not in break_after or len(break_after[loc]) == 0: + return True + if break_after[loc][0] == 0: + break_after[loc] = break_after[loc][1:] + return False + break_after[loc][0] -= 1 + return True + +def ask_for_input(loc: str, env: Env, params={}): + # TODO: Support options such as: + # break on line number + + global break_after, break_on_next, stepping + while True: + global break_on_next, last_input + user_input = input('').split(' ') + if user_input == ['']: + user_input = last_input + match user_input: + case ['break', func_name] | ['b', func_name]: + var_ty = env.get_type_of_term_var(Var(Meta(), None, env.base_to_unique(func_name), [])) + match var_ty: + case FunctionType(_, _, _, _): + break_at_point(func_name) + case _: + print("Couldn't add a breakpoint for", func_name) + continue + case ['step over'] | ['so']: + last_input = user_input + if loc not in break_after: + break_after[loc] = [0] + else: + break_after[loc].insert(0, 0) + return + case ['step'] | ['s']: + stepping = True + last_input = user_input + return + case ['continue'] | ['c']: + last_input = user_input + break_on_next = False + break_after.clear() + return + case ['print', var] | ['p', var]: + if var in params: + print(params[var]) + else: + value = env.get_value_of_term_var(Var(Meta(), None, env.base_to_unique(var), [])) + if value is None: + print('Couldn\'t find a value for:', var) + else: + print(value) + last_input = [] + case ['']: + continue + case _: + print('Unrecognized:', ' '.join(user_input)) + +def out_of_module(func: str, env: Env) -> bool: + current_module = env.get_current_module() + binding = env.get_def_of_term_name(func) + match binding: + case None: # Undefined in this module + return True + case Binding(_): + return current_module != binding.module + case _: + return False + +def on_statement(stmt, env: Env): + loc = stmt.location + increment_break_after(loc) + global stepping + if dont_break_on(loc): + return + stepping = False + print('At statement', stmt) + ask_for_input(loc, env) + +def after_statement(stmt, env: Env): + loc = stmt.location + if dont_break_after(loc): + return + + global break_on_next + break_on_next = True + +def on_function(func_name: str, loc: Meta, env: Env, rands, param_names = None): + global stepping + if out_of_module(func_name, env): + return + + func_name = base_name(func_name) + if dont_break_on(func_name) and dont_break_on(loc): + increment_break_after(func_name) + return + + stepping = False + if param_names is not None: + names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)] + params_dict = dict(zip(param_names, rands)) + else: + names = [str(x) for x in rands] + params_dict = {} + + print('At function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')') + ask_for_input(func_name, env, params_dict) + +def after_function(func_name: str, loc: Meta, env: Env, rands, return_value, param_names = None): + if out_of_module(func_name, env): + return + + func_name = base_name(func_name) + if dont_break_after(func_name): + return + global break_on_next + if param_names is not None: + names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)] + else: + names = [str(x) for x in rands] + + # print('Breaking on next cuz function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')') + break_on_next = True + # return + if param_names is not None: + names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)] + else: + names = [str(x) for x in rands] + + print('After function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')') + print('<<', return_value) + ask_for_input(func_name, env) diff --git a/proof_checker.py b/proof_checker.py index 90f25c0..4286449 100644 --- a/proof_checker.py +++ b/proof_checker.py @@ -23,8 +23,10 @@ # reduce some formulas and terms automatically. from abstract_syntax import * +from deduce_debugger import ask_for_input from error import error, incomplete_error, warning, error_header, IncompleteProof, match_failed, MatchFailed from flags import get_verbose, set_verbose, print_verbose, VerboseLevel +from deduce_debugger import on_statement, after_statement imported_modules = set() checked_modules = set() @@ -37,7 +39,7 @@ def generate_name(name): new_id = name_id name_id += 1 return ls[0] + '.' + str(new_id) - + def check_implies(loc, frm1, frm2): if get_verbose(): print('check_implies? ' + str(frm1) + ' => ' + str(frm2)) @@ -2782,6 +2784,7 @@ def find_rec_calls(name, term, env): def check_proofs(stmt, env: Env): + on_statement(stmt, env) if get_verbose(): print('\n\ncheck_proofs(' + str(stmt) + ')') match stmt: @@ -2903,6 +2906,8 @@ def check_proofs(stmt, env: Env): case _: error(stmt.location, "check_proofs: unrecognized statement:\n" + str(stmt)) + + after_statement(stmt, env) def check_deduce(ast, module_name, modified, tracing_functions): env = Env() @@ -2939,7 +2944,7 @@ def check_deduce(ast, module_name, modified, tracing_functions): if get_verbose(): for s in ast3: print(s) - + if get_verbose(): print('--------- Proof Checking ------------------------') if module_name not in checked_modules: diff --git a/rec_desc_parser.py b/rec_desc_parser.py index 71acbf7..5935b9c 100644 --- a/rec_desc_parser.py +++ b/rec_desc_parser.py @@ -5,6 +5,7 @@ from lark import Lark, Token, logger, exceptions, tree from error import * from edit_distance import closest_keyword, edit_distance +from deduce_debugger import break_at_point filename = '???' @@ -391,6 +392,12 @@ def parse_term_hi(): elif token.type == 'DEFINE': return parse_define_term() + elif token.type == 'BREAK': + advance() + break_point = parse_term_hi() + break_at_point(break_point.location) + return break_point + else: try: name = parse_identifier()