diff --git a/.gitignore b/.gitignore index b7faf40..d7b3763 100644 --- a/.gitignore +++ b/.gitignore @@ -198,6 +198,8 @@ cython_debug/ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files +.*/Outputs_TTS/ +Outputs_TTS/ .cursorignore .cursorindexingignore diff --git a/examples/TTSwithVerification/MULTIPROCESS_README.md b/examples/TTSwithVerification/MULTIPROCESS_README.md new file mode 100644 index 0000000..4a64612 --- /dev/null +++ b/examples/TTSwithVerification/MULTIPROCESS_README.md @@ -0,0 +1,67 @@ +# Multi-Process vLLM Setup for Best-of-K Baseline + +This directory contains scripts and code for running the best-of-K baseline with multi-process vLLM serving. + +## Setup + +### 1. Start vLLM with 4 processes (2 GPUs each) + +```bash +bash start_vllm_multiprocess.sh +``` + +This launches 4 vLLM OpenAI-compatible API servers: +- **Process 1**: GPUs 0-1, Port 8000 +- **Process 2**: GPUs 2-3, Port 8001 +- **Process 3**: GPUs 4-5, Port 8002 +- **Process 4**: GPUs 6-7, Port 8003 + +Each process uses `tensor-parallel-size 2` for distributed inference. + +### 2. Run the baseline + +In a separate terminal: + +```bash +# Test with 1 example +python bestofk_baseline.py --task game24 --num_examples 1 --k 4 --use_critic + +# Run on maze dataset +python bestofk_baseline.py --task maze --num_examples 10 --k 4 + +# Run on spatialmap dataset +python bestofk_baseline.py --task spatialmap --num_examples 5 --k 4 +``` + +Or use the test script: +```bash +bash run_multiprocess_test.sh game24 5 +``` + +## Load Balancing + +- Requests are distributed **round-robin** across the 4 vLLM instances +- Each generation request goes to the next available port (8000 → 8001 → 8002 → 8003 → 8000 ...) 
+- Critic evaluation requests use separate round-robin tracking (independent counter) +- This ensures even load distribution across all 4 GPU pairs + +## Stopping vLLM + +```bash +pkill -9 -f "vllm.entrypoints.openai.api_server" +``` + +## Configuration + +Edit `start_vllm_multiprocess.sh` to change: +- `MODEL`: Model name (default: `Qwen/QwQ-32B`) +- `MAX_TOKENS`: Maximum sequence length (default: 8192) +- `GPU_MEMORY`: GPU memory utilization (default: 0.4) +- `TENSOR_PARALLEL`: Must be ≤ 2 for this 8-GPU setup + +## Benefits + +- **Better throughput**: 4 independent processes handle requests in parallel +- **Fault tolerance**: If one process crashes, others continue +- **GPU utilization**: Balanced load across all 8 GPUs (2 GPUs per process) +- **Reduced latency**: Each process has dedicated GPU resources diff --git a/examples/TTSwithVerification/README.md b/examples/TTSwithVerification/README.md index ed7fbb4..a2c52dd 100644 --- a/examples/TTSwithVerification/README.md +++ b/examples/TTSwithVerification/README.md @@ -156,6 +156,39 @@ The Z3 solver handles diagonal directions (`Northwest`, `Northeast`, `Southwest` --- +# Best-of-K Baseline + +A simple best-of-K baseline that generates K independent reasoning traces per example and selects the best based on: +1. **Ground-truth matching** (default): Greedy selection of first correct answer among K samples +2. **Critic model evaluation** (optional): Use a separate critic LLM to evaluate correctness without access to ground truth + +This baseline demonstrates that with sufficient sampling, even simple CoT can achieve good performance. 
+
+## Usage
+
+```bash
+# Best-of-K with ground-truth evaluation
+python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4
+
+# Best-of-K with critic model evaluation
+python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 --use_critic --critic_model Qwen/Qwen3-30B-A3B-Thinking-2507 --critic_port 8001
+```
+
+### Parameters
+
+| Argument | Description | Default |
+|----------|-------------|---------|
+| `--task` | Task: `game24`, `maze`, or `spatialmap` | required |
+| `--k` | Number of samples per example | `4` |
+| `--use_critic` | Use critic model for evaluation instead of ground truth | `False` |
+| `--critic_model` | Model to use for critic evaluation | MAIN_MODEL |
+| `--critic_port` | vLLM server port for critic model (critic requests are load-balanced round-robin across ports 8000-8003) | `8000` |
+| `--num_examples`, `-n` | Number of examples to run | all examples |
+| `--main_model` | Model for generation | `Qwen/QwQ-32B` |
+| `--port` | vLLM server port for main model (generation requests are load-balanced round-robin across ports 8000-8003) | `8000` |
+
+---
+
 ## Example Scripts
 
 Each script runs a full evaluation: loading a dataset, building structured prompts, running inference with step verification, and computing accuracy/token statistics. 
@@ -169,6 +202,14 @@ python ./examples/TTSwithVerification/maze_stepverifier.py -n 1 # SpatialMap with step verification python ./examples/TTSwithVerification/spatialmap_stepverifier.py -n 1 + +# Best-of-K baseline (standard CoT, no monitors) +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task maze -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task spatialmap -n 1 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 --use_critic ``` ### Common arguments diff --git a/examples/TTSwithVerification/bestofk_baseline.py b/examples/TTSwithVerification/bestofk_baseline.py new file mode 100644 index 0000000..9b52206 --- /dev/null +++ b/examples/TTSwithVerification/bestofk_baseline.py @@ -0,0 +1,846 @@ +import argparse +import asyncio +import json +import logging +import os +import re +import sys +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from interwhen import stream_completion + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/QwQ-32B" +# Multi-process vLLM configuration +VLLM_PORTS = [8000, 8001, 8002, 8003] # 4 instances with tensor-parallel-size 2 each +REQUEST_COUNTER = {"main": 0, "critic": 0} # Track request count for round-robin load balancing + + +logger = logging.getLogger(__name__) + + +@contextmanager +def suppress_output(): + """Context manager to suppress stdout and stderr.""" + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = devnull + sys.stderr = devnull + try: + yield + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + +@dataclass 
+class SampleResult: + output: str + correct: bool + extracted: Optional[str] + message: str + tokens: int + critic_correct: Optional[bool] = None + critic_feedback: Optional[str] = None + + +def get_model_short_name(model_name: str) -> str: + short_name = model_name.split("/")[-1] + return short_name.replace(" ", "_").replace(":", "-") + + +def get_next_port(server_type: str = "main") -> int: + """Get next vLLM port in round-robin fashion.""" + global REQUEST_COUNTER + port = VLLM_PORTS[REQUEST_COUNTER[server_type] % len(VLLM_PORTS)] + REQUEST_COUNTER[server_type] += 1 + return port + + +def get_output_dirs(task: str, main_model: str, use_critic: bool, critic_early_stop: bool, base_dir: str = "../../b-pchanda/Outputs_TTS/BestOfKResults"): + model_short_name = get_model_short_name(main_model) + critic_status = "on" if use_critic else "off" + earlystop_status = "on" if critic_early_stop else "off" + output_base = os.path.join(base_dir, task, model_short_name, f"critic_{critic_status}", f"earlystop_{earlystop_status}") + dirs = { + "base": output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + "critic": os.path.join(output_base, "Critic_output") if use_critic else None, + } + for dir_path in dirs.values(): + if dir_path: + os.makedirs(dir_path, exist_ok=True) + return dirs + + +def init_llm_server(model_name, max_tokens=32768, port=8000, temperature=0.6, seed=42): + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": model_name, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample": True, + "temperature": temperature, + "stream": False, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed": seed, + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + + +def count_tokens(text: str, tokenizer) -> int: + """Count tokens in text, with fallback to character count.""" + try: + if not text or len(text.strip()) 
== 0: + return 0 + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + except Exception as e: + logger.warning(f"Tokenization failed: {e}, using character count estimate") + # Rough estimate: ~4 characters per token + return max(1, len(text) // 4) + + +def save_outputs(idx: int, outputs: List[SampleResult], best_idx: int, output_dir: str): + os.makedirs(output_dir, exist_ok=True) + filepath = os.path.join(output_dir, f"output_{idx}.txt") + with open(filepath, "w", encoding="utf-8") as f: + f.write(f"BEST_INDEX={best_idx}\n") + for i, result in enumerate(outputs): + f.write("\n" + "=" * 80 + "\n") + f.write(f"SAMPLE {i}\n") + f.write(f"CORRECT={result.correct}\n") + f.write(f"CRITIC_CORRECT={result.critic_correct}\n") + f.write(f"EXTRACTED={result.extracted}\n") + f.write(f"TOKENS={result.tokens}\n") + f.write(f"MESSAGE={result.message}\n") + if result.critic_feedback: + f.write(f"CRITIC_FEEDBACK={result.critic_feedback}\n") + f.write("\n") + f.write(result.output) + f.write("\n") + # logger.info(f"Saved outputs to {filepath}") + + +# --------------------- Game24 helpers --------------------- + +def build_game24_prompt(nums): + a, b, c, d = nums + boxed = r"\\boxed{}" + base_prompt = f""" +You are solving the Game of 24. + +You are given four numbers: {a}, {b}, {c}, {d} + +Your job is to produce a valid arithmetic expression using: +- ALL four numbers exactly once +- ONLY +, -, *, / +- The expression must evaluate to exactly 24. + +Please reason step by step, and put your final answer containing only the expression within {boxed}. 
+""".strip() + return base_prompt + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + 
return False, expr, "Expression does not evaluate to 24" + + +# --------------------- Maze/SpatialMap helpers --------------------- + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + # Try multiple boxed patterns + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + # Last resort: look for any standalone A, B, C, or D + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, 
flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution '{sol}' not found in options or ground truth" + + +def build_full_prompt(task, example, nums=None): + if task == "game24": + prompt = build_game24_prompt(nums) + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + if task == "maze": + system_prompt, user_prompt = build_maze_prompt(example) + else: + system_prompt, user_prompt = build_spatialmap_prompt(example) + return ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if 
args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return range(start, end) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + return np.linspace(0, dataset_len - 1, args.num_examples, dtype=int) + # Default: use full range + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return range(start, end) + + +def batch_generate_samples(prompt, llm_server, k, seed, quiet=True): + """Generate k samples using vLLM batch processing via API across multiple instances.""" + payload_template = llm_server["payload"].copy() + headers = llm_server["headers"] + + # Create k requests with different seeds + batch_payloads = [] + for i in range(k): + payload = payload_template.copy() + payload["prompt"] = prompt + payload["seed"] = seed + i + batch_payloads.append(payload) + + # Send requests to vLLM instances in parallel (true concurrency) + async def _fetch_one(session, sem, idx, url, payload): + async with sem: + try: + async with session.post(url, json=payload, headers=headers, timeout=300) as resp: + text = await resp.text() + if resp.status >= 400: + logger.warning(f"HTTP error for seed {seed + idx} on {url}: {resp.status} - {text[:200]}") + return idx, "" + try: + result = json.loads(text) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON for seed {seed + idx} on {url}") + return idx, "" + except Exception as e: + logger.warning(f"Batch generation failed for seed {seed + idx} on {url}: {e}") + return idx, "" + + if "choices" in result and len(result["choices"]) > 0: + choice = result["choices"][0] + if isinstance(choice, dict): + output_text = choice.get("text") or choice.get("message", {}).get("content", "") + else: + output_text = str(choice) + if output_text and 
len(output_text.strip()) > 0: + return idx, output_text + logger.warning(f"Empty output for seed {seed + idx} on {url}") + return idx, "" + + logger.warning(f"No choices in response for seed {seed + idx} on {url}: {result.keys() if isinstance(result, dict) else type(result)}") + return idx, "" + + async def _run_parallel(): + sem = asyncio.Semaphore(len(VLLM_PORTS)) + async with aiohttp.ClientSession() as session: + tasks = [] + for i, payload in enumerate(batch_payloads): + port = get_next_port(server_type="main") + url = f"http://localhost:{port}/v1/completions" + tasks.append(asyncio.create_task(_fetch_one(session, sem, i, url, payload))) + results = await asyncio.gather(*tasks) + return results + + if quiet: + with suppress_output(): + results = asyncio.run(_run_parallel()) + else: + results = asyncio.run(_run_parallel()) + + outputs = [""] * k + for idx, output_text in results: + outputs[idx] = output_text + if output_text and not quiet: + print(f"[Generated sample {idx}] {len(output_text)} chars, {len(output_text.split())} words") + + return outputs + + +# --------------------- Critic model helpers --------------------- + +def build_game24_critic_prompt(nums, reasoning_output): + """Build critic prompt to evaluate Game of 24 solution and provide reasoning.""" + return f"""You are a math verifier. Evaluate the following Game of 24 solution. + +Numbers: {nums} +Target: 24 + +Student's reasoning and answer: +{reasoning_output} + +Verify: +1. Does it use ALL four numbers exactly once? +2. Does each step follow correct arithmetic? +3. Does the final expression evaluate to exactly 24? + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why. +If INCORRECT, explain what went wrong and how to fix it. 
+""" + + +def build_mcq_critic_prompt(task, task_description, reasoning_output): + """Build critic prompt to evaluate MCQ solution and provide reasoning.""" + task_name = "Maze" if task == "maze" else "Spatial Reasoning" + return f"""You are an expert {task_name} verifier. Evaluate the following solution. + +Task: +{task_description} + +Student's reasoning and answer: +{reasoning_output} + +Verify the correctness of the step-by-step reasoning and final answer. + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why. +If INCORRECT, explain what went wrong and suggest the correct approach. +""" + + +def batch_evaluate_with_critic(outputs_df, task, example, critic_llm_server, tokenizer, nums=None, quiet=True): + """Batch evaluate outputs using vLLM API across multiple instances. Outputs_df should have columns: 'output', 'seed_idx'""" + payload_template = critic_llm_server["payload"].copy() + headers = critic_llm_server["headers"] + + async def _fetch_one(session, sem, idx, url, payload): + async with sem: + try: + async with session.post(url, json=payload, headers=headers, timeout=300) as resp: + text = await resp.text() + if resp.status >= 400: + logger.warning(f"HTTP error for critic sample {idx} on {url}: {resp.status} - {text[:200]}") + return idx, "", False + try: + result = json.loads(text) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON for critic sample {idx} on {url}") + return idx, "", False + except Exception as e: + logger.warning(f"Critic evaluation failed for sample {idx} on {url}: {e}") + return idx, "", False + + if "choices" in result and len(result["choices"]) > 0: + choice = result["choices"][0] + critic_output = choice.get("text") or choice.get("message", {}).get("content", "") + else: + critic_output = "" + + is_correct = "CORRECT" in critic_output.upper() + reasoning = "" + if "REASONING:" in critic_output: + reasoning = 
critic_output.split("REASONING:", 1)[1].strip() + elif "VERDICT:" not in critic_output: + reasoning = critic_output + + return idx, reasoning, is_correct + + async def _run_parallel(): + sem = asyncio.Semaphore(len(VLLM_PORTS)) + async with aiohttp.ClientSession() as session: + tasks = [] + for idx, row in outputs_df.iterrows(): + output_text = row["output"] + if task == "game24": + critic_prompt = build_game24_critic_prompt(nums, output_text) + else: + if task == "maze": + _, task_desc = build_maze_prompt(example) + else: + _, task_desc = build_spatialmap_prompt(example) + critic_prompt = build_mcq_critic_prompt(task, task_desc, output_text) + + critic_system = "You are a strict academic verifier." + full_prompt = f"<|im_start|>system\n{critic_system}<|im_end|>\n<|im_start|>user\n{critic_prompt}<|im_end|>\n<|im_start|>assistant\n" + + payload = payload_template.copy() + payload["prompt"] = full_prompt + payload["seed"] = row.get("critic_seed", idx) + + port = get_next_port(server_type="critic") + url = f"http://localhost:{port}/v1/completions" + tasks.append(asyncio.create_task(_fetch_one(session, sem, idx, url, payload))) + + return await asyncio.gather(*tasks) + + if quiet: + with suppress_output(): + results = asyncio.run(_run_parallel()) + else: + results = asyncio.run(_run_parallel()) + + rows = [] + for sample_idx, reasoning, is_correct in results: + rows.append({ + "sample_idx": sample_idx, + "critic_correct": is_correct, + "critic_feedback": reasoning, + }) + + return pd.DataFrame(rows) + + +def run_k_samples_with_critic( + prompt, + llm_server, + critic_llm_server, + k, + seed, + task, + example, + tokenizer, + eval_fn, + nums=None, + early_stop=False, + quiet=True, +): + """Run k samples with critic evaluation using vLLM batching.""" + # Generate k samples + outputs = batch_generate_samples(prompt, llm_server, k, seed, quiet=quiet) + + # Create dataframe with outputs + df_samples = pd.DataFrame({ + "sample_idx": range(k), + "output": outputs, + "seed": 
[seed + i for i in range(k)], + }) + + # If early stop mode, stop at first critic-correct + if early_stop: + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + + # Evaluate with critic + df_critic = batch_evaluate_with_critic( + pd.DataFrame([{"output": output, "seed_idx": idx}]), + task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + critic_correct = df_critic.iloc[0]["critic_correct"] if len(df_critic) > 0 else False + critic_feedback = df_critic.iloc[0]["critic_feedback"] if len(df_critic) > 0 else "" + + # Evaluate with ground truth + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + if critic_correct: + break + + return sample_results + else: + # Batch critic evaluation + df_critic = batch_evaluate_with_critic( + df_samples, task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + + # Merge critic results + df_samples = df_samples.merge(df_critic, left_index=True, right_on="sample_idx", how="left") + + # Process all results + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + critic_correct = row.get("critic_correct", False) + critic_feedback = row.get("critic_feedback", "") + + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + return sample_results + + + + + +if __name__ == "__main__": + 
parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets") + parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap"], + help="Task to run") + parser.add_argument("--k", type=int, default=4, help="Number of samples per example") + parser.add_argument("--num_examples", "-n", type=int, default=None, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run") + parser.add_argument("--xrange", type=str, default=None, + help="Range of indices to run (format: 'start-end')") + parser.add_argument("--start", type=int, default=None, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index") + parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--use_critic", action="store_true", help="Use critic model for evaluation instead of ground truth") + parser.add_argument("--critic_model", type=str, default=MAIN_MODEL, help="Critic model to use for evaluation") + parser.add_argument("--critic_port", type=int, default=8000, help="vLLM server port for critic model (default: same as main model port)") + parser.add_argument("--critic_early_stop", action="store_true", help="Stop sampling after first critic-correct trace") + parser.add_argument("--critic_feedback_baseline", action="store_true", help="Use critic feedback as a separate baseline for post-hoc correction") + parser.add_argument("--seed", type=int, default=42, help="Base random seed") + parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation") + parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug 
logging") + args = parser.parse_args() + + log_level = logging.DEBUG if args.debug else logging.ERROR + logging.basicConfig(level=log_level, format="%(message)s") + + quiet_mode = not args.debug + + if quiet_mode: + with suppress_output(): + dataset = load_dataset_for_task(args.task) + else: + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + + llm_server = init_llm_server( + args.main_model, + max_tokens=args.max_tokens, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + critic_llm_server = None + if args.use_critic: + critic_llm_server = init_llm_server( + args.critic_model, + max_tokens=512, + port=args.critic_port, + temperature=0.2, + seed=args.seed, + ) + # logger.info(f"Using critic model: {args.critic_model} on port {args.critic_port}") + + # logger.info(f"Loading tokenizer for {args.main_model}...") + if quiet_mode: + with suppress_output(): + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + # logger.info("Tokenizer loaded successfully.") + + output_dirs = get_output_dirs(args.task, args.main_model, args.use_critic, args.critic_early_stop) + + total_examples = 0 + total_correct = 0 + total_correct_samples = 0 + total_samples = 0 + critic_correct_samples = 0 + critic_total_samples = 0 + total_tokens = 0 + total_tokens_all_samples = 0 + results = [] + + for idx in tqdm(indices, desc="Processing examples", unit="example"): + example = dataset[int(idx)] + if args.task == "game24": + nums = example["numbers"] + prompt = build_full_prompt(args.task, example, nums=nums) + eval_fn = lambda output: evaluate_game24_answer(output, nums) + options = None + else: + prompt = build_full_prompt(args.task, example) + gt = str(example.get("ground_truth", "")).strip() + if gt == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + if 
args.task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + eval_fn = lambda output: evaluate_mcq_answer(output, options, gt) + + #logger.info(f"---- Example {idx} ----") + + quiet_mode = not args.debug + + if args.use_critic: + sample_results = run_k_samples_with_critic( + prompt, llm_server, critic_llm_server, args.k, args.seed, + args.task, example, tokenizer, eval_fn, nums=(nums if args.task == "game24" else None), + early_stop=args.critic_early_stop, quiet=quiet_mode + ) + else: + outputs = batch_generate_samples(prompt, llm_server, args.k, args.seed, quiet=quiet_mode) + sample_results = [] + for output in outputs: + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=message, + tokens=token_count, + critic_correct=None, + )) + + if args.use_critic: + best_idx = next((i for i, r in enumerate(sample_results) if r.critic_correct), 0) + else: + best_idx = next((i for i, r in enumerate(sample_results) if r.correct), 0) + best_result = sample_results[best_idx] + any_correct = any(r.correct for r in sample_results) + correct_samples = sum(1 for r in sample_results if r.correct) + critic_correct_samples_example = sum(1 for r in sample_results if r.critic_correct) + + save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"]) + + total_examples += 1 + if any_correct: + total_correct += 1 + total_correct_samples += correct_samples + total_samples += len(sample_results) + critic_correct_samples += critic_correct_samples_example + critic_total_samples += len(sample_results) + total_tokens += best_result.tokens + total_tokens_all_samples += sum(r.tokens for r in sample_results) + + results.append({ + "idx": int(idx), + "best_idx": best_idx, + "any_correct": any_correct, + 
"best_correct": best_result.correct, + "best_critic_correct": best_result.critic_correct, + "best_extracted": best_result.extracted, + "best_message": best_result.message, + "best_critic_feedback": best_result.critic_feedback, + "best_tokens": best_result.tokens, + "all_tokens": [r.tokens for r in sample_results], + "all_correct": [r.correct for r in sample_results], + "all_critic_correct": [r.critic_correct for r in sample_results], + "all_critic_feedback": [r.critic_feedback for r in sample_results], + "options": options, + }) + + #logger.info(f"Best sample: {best_idx} | Correct in K: {any_correct}") + #logger.info(f"Best message: {best_result.message}") + + accuracy = total_correct / total_examples if total_examples else 0 + avg_best_tokens = total_tokens / total_examples if total_examples else 0 + avg_all_tokens = total_tokens_all_samples / total_examples if total_examples else 0 + + summary = { + "task": args.task, + "model": args.main_model, + "k": args.k, + "use_critic": args.use_critic, + "total_examples": total_examples, + "correct": total_correct, + "correct_samples": total_correct_samples, + "total_samples": total_samples, + "critic_correct_samples": critic_correct_samples, + "critic_total_samples": critic_total_samples, + "critic_accuracy": (critic_correct_samples / critic_total_samples) if critic_total_samples else 0, + "accuracy": accuracy, + "avg_best_tokens": avg_best_tokens, + "avg_all_tokens": avg_all_tokens, + "total_tokens_best": total_tokens, + "total_tokens_all_samples": total_tokens_all_samples, + "results": results, + } + + if args.use_critic: + summary["critic_model"] = args.critic_model + summary["critic_port"] = args.critic_port + summary["critic_early_stop"] = args.critic_early_stop + summary["critic_feedback_baseline"] = args.critic_feedback_baseline + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + # logger.info(f"Saved summary 
to {summary_path}") diff --git a/examples/TTSwithVerification/cleanup_vllm.sh b/examples/TTSwithVerification/cleanup_vllm.sh new file mode 100755 index 0000000..f552ec1 --- /dev/null +++ b/examples/TTSwithVerification/cleanup_vllm.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Cleanup script to kill all vLLM processes and Python instances + +echo "Stopping all vLLM processes..." +pkill -9 -f "vllm.entrypoints.openai.api_server" + +echo "Stopping Python processes..." +pkill -9 -f "bestofk_baseline.py" + +sleep 2 + +echo "Verifying all processes stopped..." +if pgrep -f "vllm.entrypoints.openai.api_server" > /dev/null; then + echo "WARNING: Some vLLM processes still running" +else + echo "✓ All vLLM processes stopped" +fi + +if pgrep -f "bestofk_baseline.py" > /dev/null; then + echo "WARNING: Some Python processes still running" +else + echo "✓ All Python processes stopped" +fi + +echo "Cleanup complete" diff --git a/examples/TTSwithVerification/run_experiments.py b/examples/TTSwithVerification/run_experiments.py new file mode 100644 index 0000000..a52952b --- /dev/null +++ b/examples/TTSwithVerification/run_experiments.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Job scheduler for running bestofk_baseline.py experiments sequentially. 
+""" +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path + +# Base command +BASE_CMD = "python /data/b-pchanda/interwhen/examples/TTSwithVerification/bestofk_baseline.py" + +# Define your experiment configurations +EXPERIMENTS = [ + # Maze experiments + { + "name": "maze_k4_critic", + "args": "--task maze --k 4 --use_critic" + }, + { + "name": "maze_k4_critic_earlystop", + "args": "--task maze --k 4 --use_critic --critic_early_stop" + }, + { + "name": "maze_k4_no_critic", + "args": "--task maze --k 4" + }, + + # Game24 experiments + { + "name": "game24_k4_critic", + "args": "--task game24 --k 4 --use_critic" + }, + { + "name": "game24_k4_critic_earlystop", + "args": "--task game24 --k 4 --use_critic --critic_early_stop" + }, + { + "name": "game24_k4_no_critic", + "args": "--task game24 --k 4" + }, + + # Spatialmap experiments + { + "name": "spatialmap_k4_critic", + "args": "--task spatialmap --k 4 --use_critic" + }, + { + "name": "spatialmap_k4_critic_earlystop", + "args": "--task spatialmap --k 4 --use_critic --critic_early_stop" + }, + { + "name": "spatialmap_k4_no_critic", + "args": "--task spatialmap --k 4" + }, +] + + +def run_experiment(exp_config, exp_num, total_exps): + """Run a single experiment.""" + name = exp_config["name"] + args = exp_config["args"] + + print("\n" + "="*80) + print(f"Experiment [{exp_num}/{total_exps}]: {name}") + print(f"Command: {BASE_CMD} {args}") + print("="*80) + + start_time = time.time() + start_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + try: + # Run the command + result = subprocess.run( + f"{BASE_CMD} {args}", + shell=True, + capture_output=False, # Show output in real-time + text=True + ) + + elapsed = time.time() - start_time + end_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + if result.returncode == 0: + print(f"\n✓ Experiment '{name}' completed successfully") + print(f" Started: {start_ts}") + print(f" Finished: {end_ts}") + print(f" Duration: 
{elapsed:.1f}s ({elapsed/60:.1f} min)") + return True, elapsed + else: + print(f"\n✗ Experiment '{name}' failed with exit code {result.returncode}") + print(f" Duration: {elapsed:.1f}s") + return False, elapsed + + except KeyboardInterrupt: + print(f"\n\n⚠ Experiment '{name}' interrupted by user") + raise + except Exception as e: + elapsed = time.time() - start_time + print(f"\n✗ Experiment '{name}' failed with exception: {e}") + print(f" Duration: {elapsed:.1f}s") + return False, elapsed + + +def main(): + """Run all experiments sequentially.""" + print("="*80) + print("JOB SCHEDULER - Running experiments sequentially") + print("="*80) + print(f"Total experiments: {len(EXPERIMENTS)}") + print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + results = [] + total_start = time.time() + + try: + for i, exp in enumerate(EXPERIMENTS, 1): + success, duration = run_experiment(exp, i, len(EXPERIMENTS)) + results.append({ + "name": exp["name"], + "success": success, + "duration": duration + }) + + # Brief pause between experiments + if i < len(EXPERIMENTS): + print("\nWaiting 5 seconds before next experiment...") + time.sleep(5) + + except KeyboardInterrupt: + print("\n\n⚠ Job scheduler interrupted by user") + + finally: + # Print summary + total_elapsed = time.time() - total_start + print("\n" + "="*80) + print("SUMMARY") + print("="*80) + + successful = sum(1 for r in results if r["success"]) + failed = len(results) - successful + + print(f"\nCompleted: {len(results)}/{len(EXPERIMENTS)} experiments") + print(f"Successful: {successful}") + print(f"Failed: {failed}") + print(f"\nTotal time: {total_elapsed:.1f}s ({total_elapsed/60:.1f} min, {total_elapsed/3600:.2f} hrs)") + + print("\nDetailed results:") + for i, r in enumerate(results, 1): + status = "✓" if r["success"] else "✗" + print(f" {i}. 
{status} {r['name']:40s} - {r['duration']:.1f}s ({r['duration']/60:.1f} min)") + + if failed > 0: + print("\nFailed experiments:") + for i, r in enumerate(results, 1): + if not r["success"]: + print(f" {i}. {r['name']}") + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/TTSwithVerification/start_vllm_multiprocess.sh b/examples/TTSwithVerification/start_vllm_multiprocess.sh new file mode 100755 index 0000000..080b58b --- /dev/null +++ b/examples/TTSwithVerification/start_vllm_multiprocess.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +# Start 4 vLLM processes with explicit GPU assignment +# Process 1: GPUs 0-1, Port 8000 +# Process 2: GPUs 2-3, Port 8001 +# Process 3: GPUs 4-5, Port 8002 +# Process 4: GPUs 6-7, Port 8003 + +MODEL="Qwen/QwQ-32B" +GPU_MEMORY=0.4 +TENSOR_PARALLEL=2 + +echo "Killing any existing vLLM processes..." +pkill -9 -f "vllm.entrypoints.openai.api_server" +sleep 2 + +echo "Starting 4 vLLM processes..." + +# Process 1 - GPUs 0,1 +( + export CUDA_VISIBLE_DEVICES=0,1 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8000 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8000.log 2>&1 +) & +PID1=$! +echo "Started Process 1 (GPUs 0-1, Port 8000) - PID: $PID1" + +sleep 5 + +# Process 2 - GPUs 2,3 +( + export CUDA_VISIBLE_DEVICES=2,3 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8001 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8001.log 2>&1 +) & +PID2=$! 
+echo "Started Process 2 (GPUs 2-3, Port 8001) - PID: $PID2" + +sleep 5 + +# Process 3 - GPUs 4,5 +( + export CUDA_VISIBLE_DEVICES=4,5 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8002 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8002.log 2>&1 +) & +PID3=$! +echo "Started Process 3 (GPUs 4-5, Port 8002) - PID: $PID3" + +sleep 5 + +# Process 4 - GPUs 6,7 +( + export CUDA_VISIBLE_DEVICES=6,7 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8003 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8003.log 2>&1 +) & +PID4=$! +echo "Started Process 4 (GPUs 6-7, Port 8003) - PID: $PID4" + +echo "" +echo "All 4 vLLM processes started successfully." +echo "Process PIDs: $PID1 $PID2 $PID3 $PID4" +echo "" +echo "Log files:" +echo " /tmp/vllm_8000.log - Process 1" +echo " /tmp/vllm_8001.log - Process 2" +echo " /tmp/vllm_8002.log - Process 3" +echo " /tmp/vllm_8003.log - Process 4" +echo "" +echo "To stop all processes, run:" +echo " pkill -9 -f 'vllm.entrypoints.openai.api_server'" +echo "" +echo "Waiting for processes to initialize (this may take 60-120 seconds)..." +echo "" + +# Wait for all processes +wait $PID1 $PID2 $PID3 $PID4