From 6582dcc5e5a797c1f88c2e30fded2b3a9ae3eafd Mon Sep 17 00:00:00 2001 From: Vijval9 Date: Wed, 18 Feb 2026 10:14:37 +0000 Subject: [PATCH] Added locks for safer async execution --- interwhen/monitors/earlyStopping.py | 116 +++++++++++++++--------- interwhen/monitors/k_stable.py | 27 +++--- interwhen/monitors/stepVerifier.py | 133 ++++++++++++++++------------ 3 files changed, 164 insertions(+), 112 deletions(-) diff --git a/interwhen/monitors/earlyStopping.py b/interwhen/monitors/earlyStopping.py index 7831982..11f07e6 100644 --- a/interwhen/monitors/earlyStopping.py +++ b/interwhen/monitors/earlyStopping.py @@ -6,6 +6,7 @@ from .base import VerifyMonitor from interwhen.utils.EAT_helper import compute_entropy, exponential_moving_average, exponential_moving_variance from interwhen.utils.DEER_helper import stream_and_compute_geom_mean +import gc class EATMonitor(VerifyMonitor): @@ -33,12 +34,27 @@ def __init__(self, name, model_name, alpha=0.2, delta=0.0001, ) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() + # State tracking self.entropy = [] self.ema_means = [] self.ema_vars = [] self.exit_point = None + def reset(self): + """Reset monitor state for a new problem without reloading the model.""" + self.entropy = [] + self.ema_means = [] + self.ema_vars = [] + self.exit_point = None + gc.collect() + try: + torch.cuda.empty_cache() + except Exception as e: + print("Error while emptying cuda cache: ",e) + async def _verify(self, generated_text, token_index): """ Core verification logic using entropy. @@ -47,30 +63,26 @@ async def _verify(self, generated_text, token_index): # We append this tail so that we can compute entropy for next token (answer) partial_answer = (generated_text + "\n\n" + "\n\n" + 'Final answer is \\boxed{') - entropy_2 = compute_entropy( - self.hf_model, - self.tokenizer, - partial_answer, - ) - - self.entropy.append(entropy_2) - ema_average = exponential_moving_average(self.entropy, self.alpha) - ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0) - - self.ema_means.append(ema_average[-1]) - self.ema_vars.append(ema_variance[-1]) - - # Early stopping not triggered unless min_steps number of steps have been processed - if len(self.entropy) < self.min_steps: + entropy_2 = await asyncio.to_thread(compute_entropy, self.hf_model, self.tokenizer, partial_answer) + async with self.lock: + self.entropy.append(entropy_2) + ema_average = exponential_moving_average(self.entropy, self.alpha) + ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0) + + self.ema_means.append(ema_average[-1]) + self.ema_vars.append(ema_variance[-1]) + + # Early stopping not triggered unless min_steps number of steps have been processed + if len(self.entropy) < self.min_steps: + return (True, None, token_index) + + # Intervene if variance is below threshold + if ema_variance[-1] < self.delta: + self.exit_point = len(self.entropy) + # Return False to trigger early stop + return (False, generated_text, token_index) + return (True, None, token_index) - - # Intervene if variance is below threshold - if ema_variance[-1] < self.delta: - self.exit_point = len(self.entropy) - # Return False to trigger early stop - return (False, generated_text, token_index) - - return (True, None, token_index) async def verify(self, step, token_index, event, event_info): """ @@ -82,20 +94,20 @@ async def verify(self, step, token_index, event, event_info): return step, None # Early stop triggered - if not event.is_set(): - event_info["generated_text"] = step - event_info["feedback"] = feedback - event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix - event_info["entropy_history"] = self.entropy.copy() - event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix + event_info["entropy_history"] = self.entropy.copy() + event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None + event.set() async def fix(self, generated_text, event_info, fix_method=None): """ Appending the to force the thinking process to conclude. """ fixed_text = generated_text[:event_info['correction_index']] + "\n\n" - print("VISHAAAAAAAAAAAAAAAK"*100) return fixed_text def step_extractor(self, chunk, generated_text): @@ -128,6 +140,17 @@ def __init__(self, name, llm_server, delta=0.995, answer_start_token="", self.max_probe_steps = max_probe_steps self.answer_start_token = answer_start_token self.confidence = [] + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() + + def reset(self): + """Reset monitor state for a new problem.""" + self.confidence = [] + gc.collect() + try: + torch.cuda.empty_cache() + except Exception as e: + print("Error while emptying cuda cache: ",e) async def _verify(self, generated_text, token_index): """ @@ -137,14 +160,19 @@ async def _verify(self, generated_text, token_index): # We apppend this tail so that we can compute confidence for the answer partial_answer = (generated_text + "\n\n" + "\n\n" + 'Final answer is \\boxed{') - self.llm_server["payload"]["prompt"] = partial_answer - confidence = stream_and_compute_geom_mean(self.llm_server) - self.confidence.append(confidence) + + # Create copy to avoid mutating shared state + payload_copy = {**self.llm_server["payload"], "prompt": partial_answer} + server_copy = {**self.llm_server, "payload": payload_copy} + + confidence = await asyncio.to_thread(stream_and_compute_geom_mean, server_copy) - if confidence > self.delta: - return False, generated_text, token_index + async with self.lock: + self.confidence.append(confidence) + if confidence > self.delta: + return False, generated_text, token_index - return (True, None, token_index) + return (True, None, token_index) async def verify(self, step, token_index, event, event_info): """ @@ -156,18 +184,18 @@ async def verify(self, step, token_index, event, event_info): return step, None # Early stop triggered - if not event.is_set(): - event_info["generated_text"] = step - event_info["feedback"] = feedback - event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix - event_info["confidence_history"] = self.confidence.copy() - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix + event_info["confidence_history"] = self.confidence.copy() + event.set() async def fix(self, generated_text, event_info, fix_method=None): """ Appending to force the thinking process to conclude. """ - # Append answer prompt to conclude fixed_text = generated_text[:event_info['correction_index']] + "\n\n" return fixed_text diff --git a/interwhen/monitors/k_stable.py b/interwhen/monitors/k_stable.py index 49c1224..47f9603 100644 --- a/interwhen/monitors/k_stable.py +++ b/interwhen/monitors/k_stable.py @@ -1,5 +1,6 @@ import re from .base import VerifyMonitor +import asyncio NEGATION_WORDS = ["not", "isn't", "isnt", "no ", "cannot", "can't", "cant", "doesn't", "doesnt", "never"] @@ -26,6 +27,8 @@ def __init__(self, name, k, options, answer_start_token=""): self.options = options self.answer_start_token = answer_start_token self.stabilized_answer = None + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() def _contains_negation(self, text: str) -> bool: """Check if text contains negation words indicating uncertainty.""" @@ -198,11 +201,12 @@ async def verify(self, step, token_index, event, event_info): if is_valid: return step, None - if not event.is_set(): - event_info["generated_text"] = step - event_info["feedback"] = "" # Sliced text up to k-stable point - event_info["correction_index"] = correction_index - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = "" # Sliced text up to k-stable point + event_info["correction_index"] = correction_index + event.set() async def fix(self, generated_text, event_info, fix_method=None): """Return the sliced text up to the k-stable point.""" @@ -257,6 +261,8 @@ def __init__(self, name, k, expected_nums=None, answer_start_token=""): self.expected_nums = expected_nums self.answer_start_token = answer_start_token self.stabilized_equation = None + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() def _extract_numbers_from_expr(self, expr): """Extract all numbers (including decimals) from an expression.""" @@ -439,11 +445,12 @@ async def verify(self, chunk, token_index, event, event_info): if is_valid: return chunk, None - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = "" - event_info["correction_index"] = correction_index - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = "" + event_info["correction_index"] = correction_index + event.set() async def fix(self, generated_text, event_info, fix_method=None): """Return the sliced text up to the k-stable point.""" diff --git a/interwhen/monitors/stepVerifier.py b/interwhen/monitors/stepVerifier.py index ca5fc4b..a1fb992 100644 --- a/interwhen/monitors/stepVerifier.py +++ b/interwhen/monitors/stepVerifier.py @@ -11,6 +11,7 @@ SpatialMapZ3Solver, parse_directional_claims_from_text, extract_step2_claims, verify_spatialmap_step, format_spatialmap_feedback ) +import asyncio class StepVerifierGame24Monitor(VerifyMonitor): @@ -26,6 +27,8 @@ def __init__(self, name, answer_start_token, original_numbers, max_corrections=5 self.answer_start_token = answer_start_token self.original_numbers = [float(x) for x in original_numbers] self.max_corrections = max_corrections + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() def _count_feedback_blocks(self, text): """Count how many [VERIFIER FEEDBACK...] blocks are in the text.""" @@ -173,13 +176,14 @@ async def verify(self, chunk, token_index, event, event_info): num_corrections = self._count_feedback_blocks(chunk) if num_corrections >= self.max_corrections: max_feedback = "\nthe answer is \\boxed{no solution}" - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = max_feedback - event_info["correction_index"] = token_index - event_info["errors"] = ["Max corrections reached"] - event_info["failed_step"] = None - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = max_feedback + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["failed_step"] = None + event.set() return chunk, max_feedback @@ -205,14 +209,14 @@ async def verify(self, chunk, token_index, event, event_info): # Step has errors - generate feedback feedback = format_feedback(errors, step_num, current_available) - - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = feedback - event_info["correction_index"] = token_index - event_info["errors"] = errors - event_info["failed_step"] = step_num - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["failed_step"] = step_num + event.set() return chunk, feedback @@ -305,6 +309,8 @@ def __init__( self.exit_pos = exit_pos self.max_corrections = max_corrections self.question_type = question_type + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() @staticmethod def detect_question_type(prompt: str) -> str: @@ -540,13 +546,14 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): num_corrections = self._count_feedback_blocks(chunk) if num_corrections >= self.max_corrections: max_feedback = "\nthe answer is \\boxed{no solution}" - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = max_feedback - event_info["correction_index"] = token_index - event_info["errors"] = ["Max corrections reached"] - event_info["failed_step"] = None - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = max_feedback + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["failed_step"] = None + event.set() return chunk, max_feedback @@ -563,13 +570,14 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): locate_valid, locate_errors, locate_found = self._check_locate_section(chunk) if locate_found and not locate_valid: feedback = format_locate_feedback(locate_errors) - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = feedback - event_info["correction_index"] = token_index - event_info["errors"] = locate_errors - event_info["failed_step"] = 0 # LOCATE section - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = locate_errors + event_info["failed_step"] = 0 # LOCATE section + event.set() return chunk, feedback return chunk, None @@ -593,13 +601,14 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): # Step has errors - generate feedback feedback = format_maze_feedback(errors, step_num) - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = feedback - event_info["correction_index"] = token_index - event_info["errors"] = errors - event_info["failed_step"] = step_num - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["failed_step"] = step_num + event.set() return chunk, feedback @@ -620,13 +629,14 @@ async def _verify_relative_position(self, chunk: str, token_index: int, event, e locate_valid, locate_errors, locate_found = self._check_locate_section(chunk) if locate_found and not locate_valid: feedback = format_locate_feedback(locate_errors) - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = feedback - event_info["correction_index"] = token_index - event_info["errors"] = locate_errors - event_info["failed_step"] = 0 - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = locate_errors + event_info["failed_step"] = 0 + event.set() return chunk, feedback # For relative_position, we don't verify the final Yes/No answer @@ -811,6 +821,9 @@ def __init__( # Track verified claims to avoid re-checking self.verified_claims: Set[Tuple[str, str, str]] = set() + + # Instantiate Lock for safer async execution + self.lock = asyncio.Lock() @classmethod def from_prompt( @@ -866,6 +879,7 @@ def _extract_new_claims(self, chunk: str) -> List[Dict]: # Filter to only new claims (not yet verified) new_claims = [] + for claim in all_claims: claim_key = (claim['A'], claim['direction'], claim['B']) if claim_key not in self.verified_claims: @@ -890,13 +904,14 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): num_corrections = self._count_feedback_blocks(chunk) if num_corrections >= self.max_corrections: max_feedback = "\nthe answer is \\boxed{no solution}" - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = max_feedback - event_info["correction_index"] = token_index - event_info["errors"] = ["Max corrections reached"] - event_info["failed_step"] = None - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = max_feedback + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["failed_step"] = None + event.set() return chunk, max_feedback # Extract new claims to verify @@ -913,19 +928,21 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): ) # Mark as verified (whether valid or not) - self.verified_claims.add(claim_key) + async with self.lock: + self.verified_claims.add(claim_key) if not is_valid: # Contradiction found - generate feedback feedback = format_spatialmap_feedback(errors, claim) - if not event.is_set(): - event_info["generated_text"] = chunk - event_info["feedback"] = feedback - event_info["correction_index"] = token_index - event_info["errors"] = errors - event_info["failed_step"] = claim - event.set() + async with self.lock: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["failed_step"] = claim + event.set() return chunk, feedback