Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Deduce.lark
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ ident: IDENT -> ident
| "if" term "then" term "else" term -> conditional
| "if" term "then" term -> if_then_formula
| "%" type
| "break" term_hi -> break_before

?term_list: -> empty
| term -> single
Expand Down
46 changes: 41 additions & 5 deletions abstract_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,10 +1079,10 @@ def op_arg_str(trm, arg):
return "(" + str(arg) + ")"
return str(arg)



def do_function_call(loc, name, type_params, type_args,
params, args, body, subst, env, return_type):
from deduce_debugger import on_function, after_function
on_function(name, loc, env, args, [base_name(p) for p in params])
fast_call = False
if get_eval_all() and len(args) == 2 and isNat(args[0]) and isNat(args[1]):
op = base_name(name)
Expand Down Expand Up @@ -1164,9 +1164,28 @@ def do_function_call(loc, name, type_params, type_args,
global recursion_depth
print('<' * recursion_depth, str(ret))
recursion_depth -= 1


after_function(name, loc, env, args, ret)
return explicit_term_inst(ret)

@dataclass
class Breakpoint(Term):
point: Term

def __str__(self):
return str(self.point)

def uniquify(self, env):
self.point.uniquify(env)
from deduce_debugger import break_at_point
break_at_point(self.point.location)

def reduce(self, env):
from deduce_debugger import on_statement, after_statement
on_statement(self.point.location, env)
ret = self.point.reduce(env)
after_statement(self.point.location, env)
return ret

@dataclass
class Call(Term):
Expand Down Expand Up @@ -1219,7 +1238,7 @@ def __eq__(self, other):
#print(str(self) + ' =? ' + str(other) + ' = ' + str(result))
return result

def reduce(self, env):
def reduce(self, env):
fun = self.rator.reduce(env)
if get_eval_all():
is_assoc = False
Expand Down Expand Up @@ -1247,10 +1266,16 @@ def reduce(self, env):
ret.type_args = self.type_args

case Lambda(loc, ty, vars, body):
from deduce_debugger import on_function, after_function
name = rator_name(self.rator)
if hasattr(fun, 'env'):
on_function(name, self.location, fun.env, args, param_names=[base_name(x[0]) for x in vars])
ret = self.do_call(loc, vars, body, args, fun.env)
after_function(name, self.location, fun.env, args, ret, param_names=[base_name(x[0]) for x in vars])
else:
on_function(name, self.location, env, args, param_names=[base_name(x[0]) for x in vars])
ret = self.do_call(loc, vars, body, args, env)
after_function(name, self.location, env, args, ret, param_names=[base_name(x[0]) for x in vars])

case GenRecFun(loc, name, [], params, returns, measure, measure_ty,
body, terminates):
Expand Down Expand Up @@ -1312,6 +1337,7 @@ def do_recursive_call(self, loc, name, fun, type_params, type_args, params, args
print('call to recursive function: ' + str(fun))
print('\targs: ' + ', '.join([str(a) for a in args]))


if env.get_tracing(name):
global recursion_depth
recursion_depth += 1
Expand All @@ -1327,9 +1353,13 @@ def do_recursive_call(self, loc, name, fun, type_params, type_args, params, args
for fun_case in cases:
subst = {}
if is_match(fun_case.pattern, first_arg, subst):
return do_function_call(loc, name, type_params, type_args,
from deduce_debugger import on_function, after_function
on_function(name, fun_case.location, env, args)
ret = do_function_call(loc, name, type_params, type_args,
fun_case.parameters, rest_args,
fun_case.body, subst, env, returns)
after_function(name, fun_case.location, env, args, ret)
return ret
if is_assoc:
if get_verbose():
print('not reducing recursive call to associative ' + str(fun))
Expand Down Expand Up @@ -4286,6 +4316,12 @@ def get_def_of_type_var(self, var):
return self._def_of_type_var(self.dict, name)
case _:
raise Exception('get_def_of_type_var: unexpected ' + str(var))

def get_def_of_term_name(self, name):
if name in self.dict.keys(): # the name '=' is not in the env
return self.dict[name]
else:
return None

def get_formula_of_proof_var(self, pvar):
match pvar:
Expand Down
3 changes: 3 additions & 0 deletions deduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from deduce_debugger import set_debugging
from flags import *
from proof_checker import check_deduce, uniquify_deduce, is_modified
from abstract_syntax import init_import_directories, add_import_directory, print_theorems, get_recursive_descent, set_recursive_descent, get_uniquified_modules, add_uniquified_module, VerboseLevel
Expand Down Expand Up @@ -147,6 +148,8 @@ def deduce_directory(directory, recursive_directories, tracing_functions):
exit(0)
elif argument == '--no-check-imports':
set_check_imports(False)
elif argument == '--debug':
set_debugging()
else:
deducables.append(argument)

Expand Down
162 changes: 162 additions & 0 deletions deduce_debugger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from abstract_syntax import *
from lark.tree import Meta

breakpoints: set[str | Meta] = set()
stepping: bool = False
break_on_next: bool = False
last_input: list[str] = ['']
break_after: dict[object, list[int]] = {}

def set_debugging():
global break_on_next
break_on_next = True

def break_at_point(loc: Meta | str):
global breakpoints
breakpoints.add(loc)
print('Breakpoint added at:', loc)

def dont_break_on(place: Meta | str) -> bool:
global break_on_next
if break_on_next:
break_on_next = False
return False
return not stepping and place not in breakpoints

def increment_break_after(loc):
if loc in break_after and len(break_after[loc]) > 0:
break_after[loc][0] += 1

def dont_break_after(loc) -> bool:
if loc not in break_after or len(break_after[loc]) == 0:
return True
if break_after[loc][0] == 0:
break_after[loc] = break_after[loc][1:]
return False
break_after[loc][0] -= 1
return True

def ask_for_input(loc: str, env: Env, params={}):
# TODO: Support options such as:
# break on line number

global break_after, break_on_next, stepping
while True:
global break_on_next, last_input
user_input = input('').split(' ')
if user_input == ['']:
user_input = last_input
match user_input:
case ['break', func_name] | ['b', func_name]:
var_ty = env.get_type_of_term_var(Var(Meta(), None, env.base_to_unique(func_name), []))
match var_ty:
case FunctionType(_, _, _, _):
break_at_point(func_name)
case _:
print("Couldn't add a breakpoint for", func_name)
continue
case ['step over'] | ['so']:
last_input = user_input
if loc not in break_after:
break_after[loc] = [0]
else:
break_after[loc].insert(0, 0)
return
case ['step'] | ['s']:
stepping = True
last_input = user_input
return
case ['continue'] | ['c']:
last_input = user_input
break_on_next = False
break_after.clear()
return
case ['print', var] | ['p', var]:
if var in params:
print(params[var])
else:
value = env.get_value_of_term_var(Var(Meta(), None, env.base_to_unique(var), []))
if value is None:
print('Couldn\'t find a value for:', var)
else:
print(value)
last_input = []
case ['']:
continue
case _:
print('Unrecognized:', ' '.join(user_input))

def out_of_module(func: str, env: Env) -> bool:
current_module = env.get_current_module()
binding = env.get_def_of_term_name(func)
match binding:
case None: # Undefined in this module
return True
case Binding(_):
return current_module != binding.module
case _:
return False

def on_statement(stmt, env: Env):
loc = stmt.location
increment_break_after(loc)
global stepping
if dont_break_on(loc):
return
stepping = False
print('At statement', stmt)
ask_for_input(loc, env)

def after_statement(stmt, env: Env):
loc = stmt.location
if dont_break_after(loc):
return

global break_on_next
break_on_next = True

def on_function(func_name: str, loc: Meta, env: Env, rands, param_names = None):
global stepping
if out_of_module(func_name, env):
return

func_name = base_name(func_name)
if dont_break_on(func_name) and dont_break_on(loc):
increment_break_after(func_name)
return

stepping = False
if param_names is not None:
names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)]
params_dict = dict(zip(param_names, rands))
else:
names = [str(x) for x in rands]
params_dict = {}

print('At function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')')
ask_for_input(func_name, env, params_dict)

def after_function(func_name: str, loc: Meta, env: Env, rands, return_value, param_names = None):
if out_of_module(func_name, env):
return

func_name = base_name(func_name)
if dont_break_after(func_name):
return
global break_on_next
if param_names is not None:
names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)]
else:
names = [str(x) for x in rands]

# print('Breaking on next cuz function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')')
break_on_next = True
# return
if param_names is not None:
names = [str(p) + ':' + str(a) for p,a in zip(param_names, rands)]
else:
names = [str(x) for x in rands]

print('After function ' + str(loc.line) + ':' + str(loc.column), func_name + '(' + ', '.join(names) + ')')
print('<<', return_value)
ask_for_input(func_name, env)
9 changes: 7 additions & 2 deletions proof_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
# reduce some formulas and terms automatically.

from abstract_syntax import *
from deduce_debugger import ask_for_input
from error import error, incomplete_error, warning, error_header, IncompleteProof, match_failed, MatchFailed
from flags import get_verbose, set_verbose, print_verbose, VerboseLevel
from deduce_debugger import on_statement, after_statement

imported_modules = set()
checked_modules = set()
Expand All @@ -37,7 +39,7 @@ def generate_name(name):
new_id = name_id
name_id += 1
return ls[0] + '.' + str(new_id)

def check_implies(loc, frm1, frm2):
if get_verbose():
print('check_implies? ' + str(frm1) + ' => ' + str(frm2))
Expand Down Expand Up @@ -2782,6 +2784,7 @@ def find_rec_calls(name, term, env):


def check_proofs(stmt, env: Env):
on_statement(stmt, env)
if get_verbose():
print('\n\ncheck_proofs(' + str(stmt) + ')')
match stmt:
Expand Down Expand Up @@ -2903,6 +2906,8 @@ def check_proofs(stmt, env: Env):

case _:
error(stmt.location, "check_proofs: unrecognized statement:\n" + str(stmt))

after_statement(stmt, env)

def check_deduce(ast, module_name, modified, tracing_functions):
env = Env()
Expand Down Expand Up @@ -2939,7 +2944,7 @@ def check_deduce(ast, module_name, modified, tracing_functions):
if get_verbose():
for s in ast3:
print(s)

if get_verbose():
print('--------- Proof Checking ------------------------')
if module_name not in checked_modules:
Expand Down
7 changes: 7 additions & 0 deletions rec_desc_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lark import Lark, Token, logger, exceptions, tree
from error import *
from edit_distance import closest_keyword, edit_distance
from deduce_debugger import break_at_point

filename = '???'

Expand Down Expand Up @@ -391,6 +392,12 @@ def parse_term_hi():
elif token.type == 'DEFINE':
return parse_define_term()

elif token.type == 'BREAK':
advance()
break_point = parse_term_hi()
break_at_point(break_point.location)
return break_point

else:
try:
name = parse_identifier()
Expand Down