Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 72 additions & 44 deletions interwhen/monitors/earlyStopping.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the latency difference after adding the lock across all the monitors?

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .base import VerifyMonitor
from interwhen.utils.EAT_helper import compute_entropy, exponential_moving_average, exponential_moving_variance
from interwhen.utils.DEER_helper import stream_and_compute_geom_mean
import gc


class EATMonitor(VerifyMonitor):
Expand Down Expand Up @@ -33,12 +34,27 @@ def __init__(self, name, model_name, alpha=0.2, delta=0.0001,
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

# Instantiate Lock for safer async execution
self.lock = asyncio.Lock()

# State tracking
self.entropy = []
self.ema_means = []
self.ema_vars = []
self.exit_point = None

def reset(self):
    """Clear per-problem state so this monitor can be reused without reloading the model."""
    # Wipe every statistic accumulated for the previous problem.
    self.entropy, self.ema_means, self.ema_vars = [], [], []
    self.exit_point = None
    # Free Python-level garbage first, then release cached CUDA allocations (best effort).
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception as err:
        print("Error while emptying cuda cache: ", err)
Comment on lines +46 to +56
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not being used any more, right? If not, we can remove this


async def _verify(self, generated_text, token_index):
"""
Core verification logic using entropy.
Expand All @@ -47,30 +63,26 @@ async def _verify(self, generated_text, token_index):

# We append this tail so that we can compute entropy for next token (answer)
partial_answer = (generated_text + "\n\n</think>" + "\n\n" + 'Final answer is \\boxed{')
entropy_2 = compute_entropy(
self.hf_model,
self.tokenizer,
partial_answer,
)

self.entropy.append(entropy_2)
ema_average = exponential_moving_average(self.entropy, self.alpha)
ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0)

self.ema_means.append(ema_average[-1])
self.ema_vars.append(ema_variance[-1])

# Early stopping not triggered unless min_steps number of steps have been processed
if len(self.entropy) < self.min_steps:
entropy_2 = await asyncio.to_thread(compute_entropy, self.hf_model, self.tokenizer, partial_answer)
async with self.lock:
self.entropy.append(entropy_2)
ema_average = exponential_moving_average(self.entropy, self.alpha)
ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0)

self.ema_means.append(ema_average[-1])
self.ema_vars.append(ema_variance[-1])

# Early stopping not triggered unless min_steps number of steps have been processed
if len(self.entropy) < self.min_steps:
return (True, None, token_index)

# Intervene if variance is below threshold
if ema_variance[-1] < self.delta:
self.exit_point = len(self.entropy)
# Return False to trigger early stop
return (False, generated_text, token_index)

return (True, None, token_index)

# Intervene if variance is below threshold
if ema_variance[-1] < self.delta:
self.exit_point = len(self.entropy)
# Return False to trigger early stop
return (False, generated_text, token_index)

return (True, None, token_index)

async def verify(self, step, token_index, event, event_info):
"""
Expand All @@ -82,20 +94,20 @@ async def verify(self, step, token_index, event, event_info):
return step, None

# Early stop triggered
if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = feedback
event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix
event_info["entropy_history"] = self.entropy.copy()
event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None
event.set()
async with self.lock:
if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = feedback
event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix
event_info["entropy_history"] = self.entropy.copy()
event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None
event.set()

async def fix(self, generated_text, event_info, fix_method=None):
    """Force the thinking phase to conclude.

    Truncates the generated text at the recorded early-stop point
    (``event_info['correction_index']``) and appends a closing
    ``</think>`` tag so generation proceeds to the final answer.

    Args:
        generated_text: Full text generated so far.
        event_info: Dict holding ``'correction_index'``, the slice point
            recorded when early stopping triggered.
        fix_method: Unused; kept for interface compatibility.

    Returns:
        The truncated text with ``"\n\n</think>"`` appended.
    """
    # Bug fix: removed leftover debug spam `print("VISHAAAAAAAAAAAAAAAK"*100)`.
    fixed_text = generated_text[:event_info['correction_index']] + "\n\n</think>"
    return fixed_text

def step_extractor(self, chunk, generated_text):
Expand Down Expand Up @@ -128,6 +140,17 @@ def __init__(self, name, llm_server, delta=0.995, answer_start_token="</think>",
self.max_probe_steps = max_probe_steps
self.answer_start_token = answer_start_token
self.confidence = []
# Instantiate Lock for safer async execution
self.lock = asyncio.Lock()

def reset(self):
    """Reset monitor state for a new problem."""
    # Forget the confidence trace accumulated for the previous problem.
    self.confidence = []
    # Reclaim host-side garbage, then release cached CUDA memory (best effort).
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception as err:
        print("Error while emptying cuda cache: ", err)
Comment on lines +146 to +153
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here too


async def _verify(self, generated_text, token_index):
"""
Expand All @@ -137,14 +160,19 @@ async def _verify(self, generated_text, token_index):

# We append this tail so that we can compute confidence for the answer
partial_answer = (generated_text + "\n\n</think>" + "\n\n" + 'Final answer is \\boxed{')
self.llm_server["payload"]["prompt"] = partial_answer
confidence = stream_and_compute_geom_mean(self.llm_server)
self.confidence.append(confidence)

# Create copy to avoid mutating shared state
payload_copy = {**self.llm_server["payload"], "prompt": partial_answer}
server_copy = {**self.llm_server, "payload": payload_copy}

confidence = await asyncio.to_thread(stream_and_compute_geom_mean, server_copy)

if confidence > self.delta:
return False, generated_text, token_index
async with self.lock:
self.confidence.append(confidence)
if confidence > self.delta:
return False, generated_text, token_index

return (True, None, token_index)
return (True, None, token_index)

async def verify(self, step, token_index, event, event_info):
"""
Expand All @@ -156,18 +184,18 @@ async def verify(self, step, token_index, event, event_info):
return step, None

# Early stop triggered
if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = feedback
event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix
event_info["confidence_history"] = self.confidence.copy()
event.set()
async with self.lock:
if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = feedback
event_info["correction_index"] = len(step) # so we know where to slice the gen text during fix
event_info["confidence_history"] = self.confidence.copy()
event.set()

async def fix(self, generated_text, event_info, fix_method=None):
    """Appending </think> to force the thinking process to conclude."""
    # Slice off everything after the recorded early-stop point, then
    # close the thinking block so the model moves on to the answer.
    cut = event_info['correction_index']
    return generated_text[:cut] + "\n\n</think>"

Expand Down
27 changes: 17 additions & 10 deletions interwhen/monitors/k_stable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from .base import VerifyMonitor
import asyncio

NEGATION_WORDS = ["not", "isn't", "isnt", "no ", "cannot", "can't", "cant",
"doesn't", "doesnt", "never"]
Expand All @@ -26,6 +27,8 @@ def __init__(self, name, k, options, answer_start_token="</think>"):
self.options = options
self.answer_start_token = answer_start_token
self.stabilized_answer = None
# Instantiate Lock for safer async execution
self.lock = asyncio.Lock()

def _contains_negation(self, text: str) -> bool:
"""Check if text contains negation words indicating uncertainty."""
Expand Down Expand Up @@ -198,11 +201,12 @@ async def verify(self, step, token_index, event, event_info):
if is_valid:
return step, None

if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = "</think>" # Sliced text up to k-stable point
event_info["correction_index"] = correction_index
event.set()
async with self.lock:
if not event.is_set():
event_info["generated_text"] = step
event_info["feedback"] = "</think>" # Sliced text up to k-stable point
event_info["correction_index"] = correction_index
event.set()

async def fix(self, generated_text, event_info, fix_method=None):
"""Return the sliced text up to the k-stable point."""
Expand Down Expand Up @@ -257,6 +261,8 @@ def __init__(self, name, k, expected_nums=None, answer_start_token="</think>"):
self.expected_nums = expected_nums
self.answer_start_token = answer_start_token
self.stabilized_equation = None
# Instantiate Lock for safer async execution
self.lock = asyncio.Lock()

def _extract_numbers_from_expr(self, expr):
"""Extract all numbers (including decimals) from an expression."""
Expand Down Expand Up @@ -439,11 +445,12 @@ async def verify(self, chunk, token_index, event, event_info):
if is_valid:
return chunk, None

if not event.is_set():
event_info["generated_text"] = chunk
event_info["feedback"] = "</think>"
event_info["correction_index"] = correction_index
event.set()
async with self.lock:
if not event.is_set():
event_info["generated_text"] = chunk
event_info["feedback"] = "</think>"
event_info["correction_index"] = correction_index
event.set()

async def fix(self, generated_text, event_info, fix_method=None):
"""Return the sliced text up to the k-stable point."""
Expand Down
Loading