# Prompt sent to the judge model for every (question, answer, model_answer)
# triple; the three placeholders match the judge reader's input columns.
# The reply is parsed by a post-processor that only accepts the letters
# "A"/"B", so the prompt must ask for exactly that and nothing else.
# Each payload section carries an explicit begin/end label so the judge can
# tell the question, the standard answer and the candidate answer apart.
GRADER_TEMPLATE = """
    Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.

    Here are some evaluation criteria:
    1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
    2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
    3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
    4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
    5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer.
    6. If the candidate's answer is semantically incomplete at the end, please judge it as inconsistent.

    Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
    A: Means the answer is consistent with the standard answer.
    B: Means the answer is inconsistent with the standard answer.
    Just return the letters "A" or "B", with no text around it.

    Here is your task. Simply reply with either A or B. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.


    [Original Question Begin]: \n{question}\n[Original Question End]\n\n
    [Standard Answer Begin]: \n{answer}\n[Standard Answer End]\n\n
    [Candidate Answer Begin]: \n{model_answer}\n[Candidate Answer End]\n\n

    Judging the correctness of the candidate's answer, please return only the letter "A" or "B":
""".strip()
ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate +from ais_bench.benchmark.openicl.icl_prompt_template.icl_prompt_template_mm import MMPromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer +from ais_bench.benchmark.models import VLLMCustomAPIChat +from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content +from ais_bench.benchmark.datasets.g_edit import ( + GEditDataset, + GEditSCJDGDataset, + GEditPQJDGDataset, +) +from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list, LMMJudgeImageEditEvaluator + +SC_GRADER_TEMPLATE = """ +RULES: + +Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. +The objective is to evaluate how successfully the editing instruction has been executed in the second image. + +Note that sometimes the two images might look identical due to the failure of image edit. + +From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. 
# Build one dataset entry per judge metric. Both entries share the same
# reader/inference configuration and differ only in the judge dataset class,
# the judge prompt and the evaluator's ``metric`` argument.
for metric in ["SC", "PQ"]:
    gedit_judge_infer_cfg = dict(
        # Columns handed to the judge prompt; the prediction uuid is used as the reference column.
        judge_reader_cfg = dict(input_columns=["question", "model_answer", "image"], output_column="model_pred_uuid"),
        judge_model=dict(
            attr="service",
            type=VLLMCustomAPIChat,
            abbr=f"{metric}_judge", # per the original note: appended after the dataset abbr
            # Connection fields are left blank / at local defaults here.
            path="",
            model="",
            stream=True,
            request_rate=0,
            use_timestamp=False,
            retry=2,
            api_key="",
            host_ip="localhost",
            host_port=8080,
            url="",
            max_out_len=512,
            batch_size=16,
            trust_remote_code=False,
            generation_kwargs=dict(
                temperature=0.01,
                ignore_eos=False,
            ),
            # Applied to the judge reply before the eval-stage post-processor.
            pred_postprocessor=dict(type=extract_non_reasoning_content),
        ),
        # "SC" / "PQ" select the matching judge dataset class and grader prompt.
        judge_dataset_type=JDG_DATASETS_CLASS_MAP[metric],
        prompt_template=dict(
            type=MMPromptTemplate,
            template=dict(
                round=[
                    dict(role='HUMAN', prompt_mm={
                        "text": {"type": "text", "text": JDG_TEMPLATE_MAP[metric]},
                        "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}},
                    })
                ],
            ),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    gedit_eval_cfg = dict(
        evaluator=dict(type=LMMJudgeImageEditEvaluator, metric=metric),
        # Extracts the "[score1, score2]" list from the judge reply.
        pred_postprocessor=dict(type=get_lmm_point_list),
    )

    gedit_datasets.append(
        dict(
            # NOTE(review): both the SC and PQ entries use the same abbr "gedit";
            # presumably the judge abbr above disambiguates them - confirm.
            abbr=f"gedit",
            type=GEditDataset,
            path="ais_bench/datasets/GEdit-Bench",
            reader_cfg=gedit_reader_cfg,
            infer_cfg=gedit_infer_cfg,
            judge_infer_cfg=gedit_judge_infer_cfg,
            eval_cfg=gedit_eval_cfg,
        )
    )
@LOAD_DATASET.register_module()
class Aime2025JDGDataset(LLMJudgeDataset):
    """Judge-stage dataset for AIME 2025.

    Only supplies the class of the original dataset; loading and merging of
    predictions is inherited from ``LLMJudgeDataset``.
    """

    def _get_dataset_class(self):
        # Return the class itself (not an instance).
        return Aime2025Dataset
def update_task_state(self, state: Dict):
    """Forward a state/progress update to the task state manager, if any.

    Args:
        state: Mapping describing the current state (for example
            ``total_count`` / ``finish_count`` / ``progress_description``).

    The ``task_state_manager`` attribute is only assigned by
    ``_init_task_state_manager``, which ``__init__`` calls solely when a
    manager was passed in. It is therefore looked up with ``getattr`` so a
    dataset built without a manager logs the warning below instead of
    raising ``AttributeError``.
    """
    manager = getattr(self, "task_state_manager", None)
    if manager is not None:
        manager.update_task_state(state)
    else:
        self.logger.warning("Task state manager is not initialized, cannot update task state")
def _process_predictions(self, datasets_to_process: List[Dataset], predictions: list) -> Dataset:
    """Merge every prediction into its dataset row and return one Dataset.

    Each (dataset, prediction) pair is handled by ``_process_single_item`` on
    a thread pool. Results are gathered in submission order, so the rows of
    the returned dataset follow the order of ``predictions`` (the loaders in
    this change sort them by id) rather than whichever worker happened to
    finish first.

    Args:
        datasets_to_process: Original datasets the predictions refer to.
        predictions: Prediction records; each is passed unchanged to
            ``_process_single_item``.

    Returns:
        A single ``Dataset`` holding all merged rows; empty when there is
        nothing to process.
    """
    dataset_batches = []
    current_batch = []

    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [
            executor.submit(self._process_single_item, dataset, item)
            for dataset in datasets_to_process
            for item in predictions
        ]

        with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar:
            # Iterating the futures list (instead of as_completed) keeps the
            # output deterministic; the work itself still runs concurrently.
            for finished, future in enumerate(futures, start=1):
                current_batch.append(future.result())

                # Convert rows in small batches rather than in one big call.
                if len(current_batch) >= JDG_DATASET_LOAD_BATCH_SIZE:
                    dataset_batches.append(Dataset.from_list(current_batch))
                    current_batch = []

                pbar.update(1)
                self.update_task_state(
                    {
                        "total_count": len(futures),
                        "progress_description": "Processing predictions",
                        "finish_count": finished,
                    }
                )
                pbar.refresh()

    if current_batch:
        dataset_batches.append(Dataset.from_list(current_batch))

    if not dataset_batches:
        return Dataset.from_list([])
    if len(dataset_batches) == 1:
        return dataset_batches[0]

    from datasets import concatenate_datasets
    return concatenate_datasets(dataset_batches)
@LOAD_DATASET.register_module()
class GEditDataset(BaseDataset):
    """GEdit-Bench image-editing dataset loaded from a ``save_to_disk`` directory."""

    def load(self, path, use_raw=False, split_count=1, split_index=0, **kwargs):
        """Load the dataset and attach a multimodal ``content`` column.

        Args:
            path: Dataset location, resolved through ``get_data_path``.
            use_raw: Read images from ``input_image_raw`` instead of
                ``input_image``.
            split_count: Number of contiguous shards to cut the dataset into.
                Values <= 1 disable sharding; then at most ``GEDIT_COUNT``
                leading rows are used.
            split_index: Zero-based shard to return when ``split_count > 1``.

        Returns:
            A ``Dataset`` whose rows additionally carry ``content``: the
            base64 PNG of the input image followed by the edit instruction,
            each wrapped in the AIS content tags.

        Raises:
            ValueError: If ``split_index`` is outside ``[0, split_count)``.
        """
        path = get_data_path(path)
        self.update_task_state(
            {
                "status": "loading dataset",
            }
        )
        dataset = load_from_disk(path)

        if split_count > 1:
            if not 0 <= split_index < split_count:
                raise ValueError(
                    f"split_index must be in [0, {split_count}), got {split_index}"
                )
            # Contiguous shards; the first `remainder` shards get one extra row.
            base_size, remainder = divmod(len(dataset), split_count)
            start_idx = split_index * base_size + min(split_index, remainder)
            end_idx = start_idx + base_size + (1 if split_index < remainder else 0)
            dataset = dataset.select(range(start_idx, end_idx))
        else:
            # Clamp so a reduced copy of the benchmark does not index out of range.
            dataset = dataset.select(range(min(GEDIT_COUNT, len(dataset))))

        image_column = 'input_image_raw' if use_raw else 'input_image'

        def process_example_to_dataset(example):
            """Add the ``content`` field to one row and wrap it in a one-row Dataset."""
            image_url = pil_to_base64(example[image_column], "PNG")
            example['content'] = AIS_IMAGE_START + image_url + AIS_CONTENT_TAG \
                + AIS_TEXT_START + example['instruction'] + AIS_CONTENT_TAG
            data_dict = {key: [example[key]] for key in example.keys()}
            return Dataset.from_dict(data_dict)

        total = len(dataset)
        # Slots are filled by original index so row order is preserved.
        processed_datasets = [None] * total

        with ThreadPoolExecutor(max_workers=8) as executor:
            # Phase 1: submit one conversion task per row.
            with tqdm(total=total, desc=f"Convert GEdit dataset to base64, split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar:
                futures = {}
                for i, example in enumerate(dataset):
                    futures[executor.submit(process_example_to_dataset, example)] = i
                    self.update_task_state(
                        {
                            "total_count": total,
                            "progress_description": "Convert GEdit dataset to base64",
                            "finish_count": i + 1,
                        }
                    )
                    submit_pbar.update(1)

            # Phase 2: collect results as they finish.
            with tqdm(total=total, desc="Processing GEdit dataset", unit="example") as pbar:
                for finished, future in enumerate(as_completed(futures), start=1):
                    processed_datasets[futures[future]] = future.result()
                    # Report how many rows are done, not the index of the row
                    # that just finished (which arrives out of order).
                    self.update_task_state(
                        {
                            "total_count": total,
                            "progress_description": "Processing GEdit dataset",
                            "finish_count": finished,
                        }
                    )
                    pbar.update(1)

        if not processed_datasets:
            # An empty shard has nothing to concatenate; concatenate_datasets
            # rejects an empty list, so hand back the empty selection.
            return dataset
        return concatenate_datasets(processed_datasets)
@TEXT_POSTPROCESSORS.register_module("get_a_or_b")
def get_a_or_b(pred: str) -> str:
    """Extract the judge verdict letter ('A' or 'B') from a judge reply.

    The verdict is looked for in this order:

    1. the very last character of the reply (the original rule);
    2. a stand-alone ``A``/``B`` at the end, followed only by whitespace or
       punctuation (``"A."``, ``"B\\n"``, ``"Answer: A**"``);
    3. a stand-alone ``A``/``B`` at the start, directly followed by
       punctuation, a line break or the end of the reply
       (``"A\\n<explanation>"``, ``"B: inconsistent"``). A leading letter
       followed by a space and more words is ignored, because it may just be
       the English article "A".

    Args:
        pred: Judge model reply.

    Returns:
        ``'A'`` or ``'B'``; ``'B'`` (inconsistent) when no verdict is found
        or ``pred`` is not a string.
    """
    if not isinstance(pred, str):
        return 'B'

    if pred[-1:] in ('A', 'B'):
        return pred[-1]

    trailing = re.search(r'(?<![A-Za-z0-9])([AB])[^A-Za-z0-9]*\Z', pred)
    if trailing:
        return trailing.group(1)

    leading = re.match(r'[^A-Za-z0-9]*([AB])(?=[ \t]*(?:\Z|[\r\n])|[^\w\s])', pred)
    if leading:
        return leading.group(1)

    return 'B'


class LLMJudgeDataset(BaseJDGDataset):
    """Judge dataset whose predictions are plain text records in a JSONL file."""

    def _load_from_predictions(self, prediction_path: str):
        """Load prediction records from a JSONL file.

        Args:
            prediction_path (str): The path to the prediction file.

        Returns:
            list: Prediction records sorted by their ``id`` field (records
            without an ``id`` sort as 0); an empty list when the file does
            not exist.
        """
        if not os.path.exists(prediction_path):
            logger.warning(f"Prediction file does not exist: {prediction_path}")
            return []

        preds = load_jsonl(prediction_path)
        preds.sort(key=lambda x: x.get('id', 0))
        return preds


@ICL_EVALUATORS.register_module()
class LLMJudgeCorrectEvaluator(BaseEvaluator):
    """Turn judge verdicts into an accuracy score: 'A' counts as correct."""

    def __init__(self):
        super().__init__()

    def score(self, predictions, references):
        """Compute the percentage of predictions judged 'A'.

        Args:
            predictions: Verdict letters, one per sample.
            references: Reference values, reported back in ``details``.

        Returns:
            ``{'accuracy': float, 'details': list}``, or an ``{'error': ...}``
            dict when the two inputs differ in length. Empty input yields an
            accuracy of 0.0 instead of dividing by zero.
        """
        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}

        correct = 0
        details = []
        for pred, ref in zip(predictions, references):
            is_correct = bool(pred == "A")
            if is_correct:
                correct += 1
            details.append({'pred': pred, 'answer': ref, 'correct': is_correct})

        count = len(details)
        accuracy = 100 * correct / count if count else 0.0
        return {'accuracy': accuracy, 'details': details}
class LMMImgJDGDataset(BaseJDGDataset):
    """Judge dataset whose predictions are image files referenced by relative path."""

    def _load_from_predictions(self, prediction_path: str):
        """Load predictions and inline each predicted image as base64 PNG.

        Every record's ``prediction`` field is treated as an image path
        relative to the directory of ``prediction_path``. The image is
        converted to RGB, re-encoded as PNG and the field is replaced by the
        base64 string.

        Args:
            prediction_path (str): The path to the prediction file.

        Returns:
            list: Prediction records sorted by ``id``; an empty list when the
            prediction file does not exist. Records whose image cannot be
            found are returned unchanged and a warning is logged.

        Raises:
            AISBenchRuntimeError: If an existing image cannot be read or
                encoded.
        """
        if not os.path.exists(prediction_path):
            logger.warning(f"Prediction file does not exist: {prediction_path}")
            return []

        preds = load_jsonl(prediction_path)
        base_path = os.path.dirname(prediction_path)

        def process_image(index, pred_item):
            """Replace the image path in ``pred_item`` with base64 PNG data."""
            relative_path = pred_item.get('prediction', '')
            # An empty prediction would otherwise resolve to the directory itself.
            image_path = os.path.join(base_path, relative_path) if relative_path else ''
            if not image_path or not os.path.isfile(image_path):
                logger.warning(
                    f"Prediction image not found for item at index {index}: {image_path!r}"
                )
                return pred_item
            try:
                with Image.open(image_path) as img:
                    buffered = BytesIO()
                    img.convert('RGB').save(buffered, format="PNG")
                pred_item['prediction'] = base64.b64encode(buffered.getvalue()).decode('utf-8')
            except Exception as e:
                raise AISBenchRuntimeError(
                    DSET_CODES.UNKNOWN_ERROR,
                    f"Failed to process image {image_path} at index {index}: {e}",
                ) from e
            return pred_item

        processed_preds = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            results = executor.map(lambda x: process_image(x[0], x[1]), enumerate(preds))
            progress = tqdm(
                results,
                total=len(preds),
                desc="Convert prediction images to base64",
                unit="image",
            )
            # Progress is reported from the consuming thread, so finish_count
            # is a monotonically increasing number of finished items.
            for finished, item in enumerate(progress, start=1):
                processed_preds.append(item)
                self.update_task_state(
                    {
                        "total_count": len(preds),
                        "progress_description": "Convert prediction images to base64",
                        "finish_count": finished,
                    }
                )

        processed_preds.sort(key=lambda x: x.get('id', 0))
        return processed_preds
@ICL_EVALUATORS.register_module()
class LMMJudgeImageEditEvaluator(BaseEvaluator):
    """Aggregate judge point lists for one image-edit metric ("SC" or "PQ")."""

    def __init__(self, metric: str = "SC"):
        self.metric = metric
        # Names of the individual points the judge is expected to return.
        self.point_key_list = POINT_KEY_LIST_MAP[metric]
        super().__init__()

    def score(self, predictions, references):
        """Average the per-sample minimum point over well-formed predictions.

        Args:
            predictions: Judge replies as strings containing a point list.
            references: Reference values, echoed as ``org_uuid`` in details.

        Returns:
            ``{<metric>: float, 'details': list}``, or an ``{'error': ...}``
            dict for mismatched lengths or non-string predictions. Samples
            whose point list has the wrong length are reported as failed and
            do not enter the average.
        """
        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}
        if any(not isinstance(pred, str) for pred in predictions):
            return {'error': 'preds must be strings'}

        # Parse the textual "[a, b]" lists into real Python lists.
        point_lists = [json.loads(get_lmm_point_list(pred)) for pred in predictions]

        expected_len = len(self.point_key_list)
        score_sum = 0
        scored = 0
        details = []
        for points, ref in zip(point_lists, references):
            if len(points) == expected_len:
                details.append({
                    'pred': dict(zip(self.point_key_list, points)),
                    'org_uuid': ref,
                    'eval_success': True,
                })
                scored += 1
                score_sum += min(points)
            else:
                details.append({
                    'pred': points,
                    'org_uuid': ref,
                    'eval_success': False,
                    'failed reason': 'prediction format error, length of point list not equal to point key list',
                })

        average = score_sum / scored if scored > 0 else 0.0
        return {f"{self.metric}": average, 'details': details}
def load_jsonl(path: str) -> List[dict]:
    """Load JSONL file into a list of dictionaries.

    Blank lines are skipped and an empty file yields an empty list
    (``mmap`` cannot map a zero-length file, so it is never attempted there).

    Args:
        path: Path to the JSONL file

    Returns:
        List of dictionaries, each representing a non-blank line in the
        JSONL file
    """
    preds = []
    with open(path, "rb") as f:
        if os.fstat(f.fileno()).st_size == 0:
            return preds
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            for line in iter(mm.readline, b""):
                # Tolerate empty / whitespace-only lines such as a trailing newline.
                if line.strip():
                    preds.append(orjson.loads(line))
    return preds
def pil_to_base64(image, format="JPEG"):
    """Encode a PIL image as a base64 string.

    Args:
        image: The ``PIL.Image.Image`` to serialize.
        format: Image format name passed to ``Image.save`` (default "JPEG").

    Returns:
        The base64 text of the encoded image bytes.

    Raises:
        AISBenchDataContentError: If ``image`` is not a PIL image.
    """
    if isinstance(image, Image.Image):
        stream = BytesIO()
        image.save(stream, format)
        encoded = base64.b64encode(stream.getvalue())
        return encoded.decode()
    raise AISBenchDataContentError(UTILS_CODES.UNKNOWN_ERROR, "Input must be a PIL Image object")
**kwargs) -> PromptList: elif item.startswith(AIS_VIDEO_START): video = item.replace(AIS_VIDEO_START, "") video = {"video": video} - video_content = mm['video'].copy() + video_content = deepcopy(mm['video']) if isinstance(video_content['video_url'], dict): video_content['video_url']['url'] = safe_format(video_content['video_url']['url'], **video) else: @@ -181,7 +181,7 @@ def format_mm(self, **kwargs) -> PromptList: elif item.startswith(AIS_AUDIO_START): audio = item.replace(AIS_AUDIO_START, "") audio = {"audio": audio} - audio_content = mm['audio'].copy() + audio_content = deepcopy(mm['audio']) if isinstance(audio_content['audio_url'], dict): audio_content['audio_url']['url'] = safe_format(audio_content['audio_url']['url'], **audio) else: