From ac8bef8128a00fdb39c34d856fa733d05d7f2b83 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Mon, 2 Feb 2026 14:16:47 +0800 Subject: [PATCH 01/59] judge llm --- ais_bench/benchmark/cli/workers.py | 139 +++++++++++++++++- .../aime2025/aime2025_gen_0_shot_llmjudge.py | 117 +++++++++++++++ ais_bench/benchmark/datasets/aime2025.py | 24 ++- ais_bench/benchmark/datasets/base.py | 19 +++ .../benchmark/datasets/utils/datasets.py | 6 +- .../benchmark/datasets/utils/llm_judge.py | 36 +++++ ais_bench/benchmark/utils/file/file.py | 31 +++- 7 files changed, 358 insertions(+), 14 deletions(-) create mode 100644 ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py create mode 100644 ais_bench/benchmark/datasets/utils/llm_judge.py diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index ca997164..f75cce7e 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,3 +1,4 @@ +import os import os.path as osp import copy from abc import ABC, abstractmethod @@ -8,12 +9,15 @@ from ais_bench.benchmark.registry import PARTITIONERS, RUNNERS, build_from_cfg from ais_bench.benchmark.utils.config.run import get_config_type from ais_bench.benchmark.utils.logging.logger import AISLogger +from ais_bench.benchmark.utils.logging.exceptions import PredictionInvalidException +from ais_bench.benchmark.utils.logging.error_codes import TMAN_CODES from ais_bench.benchmark.partitioners import NaivePartitioner from ais_bench.benchmark.runners import LocalRunner from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need +from ais_bench.benchmark.utils.file.file import load_jsonl, dump_jsonl logger = AISLogger() @@ -75,7 +79,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Merging datasets with the same model and inferencer...") tasks = self._merge_datasets(tasks) - runner = RUNNERS.build(cfg.infer.runner) + runner = RUNNERS.build(cfg.judge_infer.runner) runner(tasks) logger.info("Inference tasks completed.") @@ -108,6 +112,117 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task.attack = cfg.attack +class JudgeInfer(BaseWorker): + def update_cfg(self, cfg: ConfigDict) -> None: + def get_task_type() -> str: + if cfg["models"][0]["attr"] == "service": + return get_config_type(OpenICLApiInferTask) + else: + return get_config_type(OpenICLInferTask) + + new_cfg = dict( + judge_infer=dict( + partitioner=dict(type=get_config_type(NaivePartitioner)), + runner=dict( + max_num_workers=self.args.max_num_workers, + max_workers_per_gpu=self.args.max_workers_per_gpu, + debug=self.args.debug, + task=dict(type=get_task_type()), + type=get_config_type(LocalRunner), + ), + ), + ) + + cfg.merge_from_dict(new_cfg) + if cfg.cli_args.debug: + cfg.judge_infer.runner.debug = True + cfg.judge_infer.partitioner["out_dir"] = osp.join(cfg["work_dir"], "predictions/") + return cfg + + def do_work(self, cfg: ConfigDict): + partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner) + logger.info("Starting inference tasks...") + tasks = partitioner(cfg) + + # delete the tasks without judge_infer_cfg + new_tasks = [] + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + new_tasks.append(task) + tasks = new_tasks + if len(tasks) == 0: + return + + # update tasks cfg before run + self._update_tasks_cfg(tasks, cfg) + + if ( + cfg.get("cli_args", {}).get("merge_ds", False) + or cfg.get("cli_args", {}).get("mode") == "perf" # performance mode will enable merge datasets by default + ): + logger.info("Merging datasets with the same model and inferencer...") + tasks = self._merge_datasets(tasks) + + runner = RUNNERS.build(cfg.judge_infer.runner) + runner(tasks) + self._result_post_process(tasks, cfg) + logger.info("Inference tasks completed.") + + def _merge_datasets(self, tasks): + # merge datasets with the same model, dataset type and inferencer + task_groups = defaultdict(list) + for task in tasks: + key = ( + task["models"][0]["abbr"] # same model + + "_" + + str(task['datasets'][0][0]['type']) # same dataset type + + "_" + + str(task["datasets"][0][0]["infer_cfg"]["inferencer"]) # same inferencer with the same args + ) + task_groups[key].append(task) + new_tasks = [] + for key, task_group in task_groups.items(): + new_task = copy.deepcopy(task_group[0]) + if len(task_group) > 1: + for t in task_group[1:]: + new_task["datasets"][0].extend(t["datasets"][0]) + new_tasks.append(new_task) + return new_tasks + + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): + # update parameters to correct sub cfg + if hasattr(cfg, "attack"): + for task in tasks: + cfg.attack.dataset = task.datasets[0][0].abbr + task.attack = cfg.attack + + # update judge cfgs to model cfgs and data + for task in tasks: + task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + if not osp.exists(task["datasets"][0][0]["predictions_path"]): + raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + model_abbr = task["models"][0]["abbr"] + task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") + task["models"][0]["abbr"] = model_abbr + task["datasets"][0][0]["type"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_dataset_type") + task["datasets"][0][0]["reader_cfg"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_reader_cfg") + task["datasets"][0][0]["infer_cfg"] = task["datasets"][0][0].pop("judge_infer_cfg") + + def _result_post_process(self, tasks, cfg: ConfigDict): + # Reconstruct the judge infer predictions to normal predictions format + for task in tasks: + model_org_prediction_path = task["datasets"][0][0]["predictions_path"] + model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)} + judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + judge_preds: list = load_jsonl(judge_org_prediction_path) + for i, pred in enumerate(judge_preds): + uuid = pred["gold"] + judge_preds[i]["id"] = model_preds[uuid]["id"] + os.remove(judge_org_prediction_path) + dump_jsonl(judge_preds, judge_org_prediction_path) + + class Eval(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: new_cfg = dict( @@ -138,7 +253,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Starting evaluation tasks...") tasks = partitioner(cfg) - # update tasks cfg before run + # Update tasks cfg before run self._update_tasks_cfg(tasks, cfg) runner = RUNNERS.build(cfg.eval.runner) @@ -151,9 +266,11 @@ def do_work(self, cfg: ConfigDict): logger.info("Evaluation tasks completed.") def _update_tasks_cfg(self, tasks, cfg: ConfigDict): - # update parameters to correct sub cfg - pass - + # Replace default model config to judge model config + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + task["datasets"][0][0].pop("judge_infer_cfg") class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: @@ -171,6 +288,7 @@ def update_cfg(self, cfg: ConfigDict) -> None: def do_work(self, cfg: ConfigDict) -> int: logger.info("Summarizing evaluation results...") summarizer_cfg = cfg.get("summarizer", {}) + cfg = self._cfg_pre_process(cfg) # For subjective summarizer if summarizer_cfg.get("function", None): @@ -203,6 +321,13 @@ def do_work(self, cfg: ConfigDict) -> int: summarizer = build_from_cfg(summarizer_cfg) summarizer.summarize(time_str=self.args.cfg_time_str) + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + for i, dataset in enumerate(cfg.datasets): + if dataset.get("judge_infer_cfg"): + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + cfg.datasets[i].pop("judge_infer_cfg") + return cfg + class PerfViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: @@ -233,9 +358,9 @@ def do_work(self, cfg: ConfigDict) -> int: WORK_FLOW = dict( - all=[Infer, Eval, AccViz], + all=[Infer, JudgeInfer, Eval, AccViz], infer=[Infer], - eval=[Eval, AccViz], + eval=[JudgeInfer, Eval, AccViz], viz=[AccViz], perf=[Infer, PerfViz], perf_viz=[PerfViz], diff --git a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py new file mode 100644 index 00000000..c87fb232 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py @@ -0,0 +1,117 @@ +from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.models import VLLMCustomAPIChat +from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content +from ais_bench.benchmark.datasets import ( + Aime2025Dataset, + Aime2025JDGDataset, +) +from ais_bench.benchmark.datasets.utils.llm_judge import get_a_or_b, LLMJudgeCorrectEvaluator + + +aime2025_reader_cfg = dict(input_columns=["question"], output_column="answer") + + +aime2025_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role="HUMAN", + prompt="{question}\nRemember to put your final answer within \\boxed{}.", + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +GRADER_TEMPLATE = """ + Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly. + + Here are some evaluation criteria: + 1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct. + 2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question. + 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. + 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. + 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + + Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: + A: CORRECT + B: INCORRECT + Just return the letters "A" or "B", with no text around it. + + Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. + + + : \n{question}\n\n\n + : \n{answer}\n\n\n + : \n{model_answer}\n\n\n + + Judging the correctness of candidates' answers: +""".strip() + +aime2025_judge_infer_cfg = dict( + judge_reader_cfg = dict(input_columns=["question", "answer", "model_answer"], output_column="model_pred_uuid"), + judge_model=dict( + attr="service", + type=VLLMCustomAPIChat, + additional_abbr="judge", # Be added after dataset abbr + path="", + model="", + stream=True, + request_rate=0, + use_timestamp=False, + retry=2, + api_key="", + host_ip="localhost", + host_port=8080, + url="", + max_out_len=512, + batch_size=1, + trust_remote_code=False, + generation_kwargs=dict( + temperature=0.01, + ignore_eos=False, + ), + pred_postprocessor=dict(type=extract_non_reasoning_content), + ), + judge_dataset_type=Aime2025JDGDataset, + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role='SYSTEM', + fallback_role='HUMAN', + prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.", + ) + ], + round=[ + dict(role='HUMAN', prompt=GRADER_TEMPLATE), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +aime2025_eval_cfg = dict( + evaluator=dict(type=LLMJudgeCorrectEvaluator), + pred_postprocessor=dict(type=get_a_or_b), +) + +aime2025_datasets = [ + dict( + abbr="aime2025", + type=Aime2025Dataset, + path="ais_bench/datasets/aime2025/aime2025.jsonl", + reader_cfg=aime2025_reader_cfg, + infer_cfg=aime2025_infer_cfg, + judge_infer_cfg=aime2025_judge_infer_cfg, + eval_cfg=aime2025_eval_cfg, + ) +] diff --git a/ais_bench/benchmark/datasets/aime2025.py b/ais_bench/benchmark/datasets/aime2025.py index 6e67b07d..548e28b2 100644 --- a/ais_bench/benchmark/datasets/aime2025.py +++ b/ais_bench/benchmark/datasets/aime2025.py @@ -1,16 +1,15 @@ -import json +import json from datasets import Dataset from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.datasets.utils.datasets import get_data_path -from .base import BaseDataset +from .base import BaseDataset, BaseJDGDatasetMethod @LOAD_DATASET.register_module() class Aime2025Dataset(BaseDataset): - @staticmethod def load(path, **kwargs): path = get_data_path(path) @@ -20,3 +19,22 @@ def load(path, **kwargs): line = json.loads(line.strip()) dataset.append(line) return Dataset.from_list(dataset) + +class Aime2025JDGDataset(Aime2025Dataset): + def load(self, path, predictions_path, **kwargs): + + dataset_content = Aime2025Dataset.load(path, **kwargs) + + # 加载被测模型的推理结果(排序后) + predictions: list = BaseJDGDatasetMethod.load_from_predictions(predictions_path) + + # 为数据集添加 model_answer 列 + dataset_list = [] + + for item in predictions: + item_dict = dataset_content[int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + + return Dataset.from_list(dataset_list) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index de062a5d..f3deb8f3 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,3 +1,5 @@ +import os + from abc import abstractmethod from typing import List, Dict, Optional, Union @@ -8,6 +10,7 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError +from ais_bench.benchmark.utils.file.file import load_jsonl disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -108,3 +111,19 @@ def test(self): @abstractmethod def load(**kwargs) -> Union[Dataset, DatasetDict]: pass + +class BaseJDGDatasetMethod: + @staticmethod + def load_from_predictions(prediction_path: str) -> Dict: + """Load predictions from a directory and merge them with the dataset. + + Args: + prediction_dir (str): The directory containing prediction files. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id',0)) + return preds diff --git a/ais_bench/benchmark/datasets/utils/datasets.py b/ais_bench/benchmark/datasets/utils/datasets.py index 9f473d6a..6bb31701 100644 --- a/ais_bench/benchmark/datasets/utils/datasets.py +++ b/ais_bench/benchmark/datasets/utils/datasets.py @@ -69,7 +69,7 @@ def get_sample_data(data_list: list, sample_mode: str = "default", request_count data_list (list): Data list. sample_mode (str): Sample mode. request_count (int): Request count. - + Raises: ValueError: If sample mode is not supported. ValueError: If request count is negative. @@ -101,7 +101,7 @@ def get_sample_data(data_list: list, sample_mode: str = "default", request_count return shuffle_data else: raise ValueError(f"Sample mode: {sample_mode} is not supported!") - + def get_meta_json(dataset_path, meta_path): ori_meta_path = meta_path if not meta_path: @@ -389,7 +389,7 @@ def _to_float(text: str): return relative_change <= max_relative_change else: return prediction.lower() == target.lower() - + def anls_compute(groundtruth, prediction): gt_answer = ' '.join(groundtruth.strip().lower().split()) det_answer = ' '.join(prediction.strip().lower().split()) diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py new file mode 100644 index 00000000..a17e72b7 --- /dev/null +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -0,0 +1,36 @@ +import re + +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, + TEXT_POSTPROCESSORS) +from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +logger = AISLogger() + +@TEXT_POSTPROCESSORS.register_module("get_a_or_b") +def get_a_or_b(pred: str) -> str: + """从模型回复中提取A或B""" + match = re.search(r'[AB]', pred) + return match.group(0) if match else 'B' + + +@ICL_EVALUATORS.register_module() +class LLMJudgeCorrectEvaluator(BaseEvaluator): + + def __init__(self): + super().__init__() + + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} + correct = 0 + count = 0 + details = [] + for i, j in zip(predictions, references): + detail = {'pred': i, 'answer': j, 'correct': False} + count += 1 + if i == "A": + correct += 1 + detail['correct'] = True + details.append(detail) + result = {'accuracy': 100 * correct / count, 'details': details} + return result \ No newline at end of file diff --git a/ais_bench/benchmark/utils/file/file.py b/ais_bench/benchmark/utils/file/file.py index d6bfde67..47f048dc 100644 --- a/ais_bench/benchmark/utils/file/file.py +++ b/ais_bench/benchmark/utils/file/file.py @@ -1,6 +1,8 @@ from typing import List, Tuple, Union import os import json +import mmap +import orjson import fnmatch import tabulate @@ -226,4 +228,31 @@ def check_mm_custom(path): return False if line["type"] not in ["image", "video", "audio"]: return False - return True \ No newline at end of file + return True + +def load_jsonl(path: str) -> List[dict]: + """Load JSONL file into a list of dictionaries. + + Args: + path: Path to the JSONL file + + Returns: + List of dictionaries, each representing a line in the JSONL file + """ + preds = [] + with open(path, "rb") as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + for line in iter(mm.readline, b""): + preds.append(orjson.loads(line)) + return preds + +def dump_jsonl(data: List[dict], path: str): + """Dump a list of dictionaries to a JSONL file. + + Args: + data: List of dictionaries to be dumped + path: Path to the output JSONL file + """ + with open(path, 'wb') as f: + for item in data: + f.write(orjson.dumps(item) + b'\n') \ No newline at end of file From 16a9848d2e55a34ae445ac26dc1d92581d4c1744 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Thu, 5 Feb 2026 09:42:36 +0800 Subject: [PATCH 02/59] reconstruct the judgedatasets --- ais_bench/benchmark/cli/workers.py | 27 ++++++- .../aime2025/aime2025_gen_0_shot_llmjudge.py | 9 ++- ais_bench/benchmark/datasets/aime2025.py | 25 ++---- ais_bench/benchmark/datasets/base.py | 77 ++++++++++++++----- .../benchmark/datasets/utils/llm_judge.py | 22 +++++- 5 files changed, 113 insertions(+), 47 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index f75cce7e..77a5d59d 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,6 +1,7 @@ import os import os.path as osp import copy +import shutil from abc import ABC, abstractmethod from collections import defaultdict @@ -201,7 +202,7 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') if not osp.exists(task["datasets"][0][0]["predictions_path"]): raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' model_abbr = task["models"][0]["abbr"] task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") task["models"][0]["abbr"] = model_abbr @@ -263,15 +264,35 @@ def do_work(self, cfg: ConfigDict): runner(task_part) else: runner(tasks) + self._result_post_process(tasks, cfg) logger.info("Evaluation tasks completed.") def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # Replace default model config to judge model config + self.judge_result_paths = {} for task in tasks: if task["datasets"][0][0].get("judge_infer_cfg"): - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + new_dataset_abbr = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' + org_dataset_abbr = task["datasets"][0][0]["abbr"] + self.judge_result_paths[new_dataset_abbr] = org_dataset_abbr + task["datasets"][0][0]["abbr"] = new_dataset_abbr task["datasets"][0][0].pop("judge_infer_cfg") + def _result_post_process(self, tasks, cfg: ConfigDict): + # Copy judge infer result to normal name + + for task in tasks: + if task["datasets"][0][0]["abbr"] in self.judge_result_paths.keys(): + cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.judge_result_paths[task["datasets"][0][0]["abbr"]]}.jsonl') + if os.path.exists(final_org_results_path): + os.remove(final_org_results_path) + + if os.path.exists(cur_results_path): + # 基于cur_results_path的文件复制一份final_org_results_path + shutil.copy(cur_results_path, final_org_results_path) + + class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: summarizer_cfg = cfg.get("summarizer", {}) @@ -324,7 +345,7 @@ def do_work(self, cfg: ConfigDict) -> int: def _cfg_pre_process(self, cfg: ConfigDict) -> None: for i, dataset in enumerate(cfg.datasets): if dataset.get("judge_infer_cfg"): - cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' cfg.datasets[i].pop("judge_infer_cfg") return cfg diff --git a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py index c87fb232..7ece227c 100644 --- a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py @@ -38,10 +38,11 @@ 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + 6. If the candidate's answer is semantically incomplete at the end, please judge it as inconsistent. Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: - A: CORRECT - B: INCORRECT + A: Means the answer is consistent with the standard answer. + B: Means the answer is inconsistent with the standard answer. Just return the letters "A" or "B", with no text around it. Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. @@ -51,7 +52,7 @@ : \n{answer}\n\n\n : \n{model_answer}\n\n\n - Judging the correctness of candidates' answers: + Judging the correctness of candidates' answers, please return the the letters "A" or "B" first before your thinking: """.strip() aime2025_judge_infer_cfg = dict( @@ -59,7 +60,7 @@ judge_model=dict( attr="service", type=VLLMCustomAPIChat, - additional_abbr="judge", # Be added after dataset abbr + abbr="judge", # Be added after dataset abbr path="", model="", stream=True, diff --git a/ais_bench/benchmark/datasets/aime2025.py b/ais_bench/benchmark/datasets/aime2025.py index 548e28b2..b6b13a1c 100644 --- a/ais_bench/benchmark/datasets/aime2025.py +++ b/ais_bench/benchmark/datasets/aime2025.py @@ -4,9 +4,9 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset -from .base import BaseDataset, BaseJDGDatasetMethod - +from ais_bench.benchmark.datasets.base import BaseDataset @LOAD_DATASET.register_module() class Aime2025Dataset(BaseDataset): @@ -20,21 +20,8 @@ def load(path, **kwargs): dataset.append(line) return Dataset.from_list(dataset) -class Aime2025JDGDataset(Aime2025Dataset): - def load(self, path, predictions_path, **kwargs): - - dataset_content = Aime2025Dataset.load(path, **kwargs) - - # 加载被测模型的推理结果(排序后) - predictions: list = BaseJDGDatasetMethod.load_from_predictions(predictions_path) - # 为数据集添加 model_answer 列 - dataset_list = [] - - for item in predictions: - item_dict = dataset_content[int(item["id"])] - item_dict["model_answer"] = item["prediction"] - item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold - dataset_list.append(item_dict) - - return Dataset.from_list(dataset_list) +@LOAD_DATASET.register_module() +class Aime2025JDGDataset(LLMJudgeDataset): + def _get_dataset_class(self): + return Aime2025Dataset diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index f3deb8f3..243a52b0 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,7 +1,5 @@ -import os - from abc import abstractmethod -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Type from datasets import Dataset, DatasetDict from datasets.utils.logging import disable_progress_bar @@ -10,7 +8,6 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError -from ais_bench.benchmark.utils.file.file import load_jsonl disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -109,21 +106,61 @@ def test(self): return self.reader.dataset['test'] @abstractmethod - def load(**kwargs) -> Union[Dataset, DatasetDict]: + def load(self, **kwargs) -> Union[Dataset, DatasetDict]: pass -class BaseJDGDatasetMethod: - @staticmethod - def load_from_predictions(prediction_path: str) -> Dict: - """Load predictions from a directory and merge them with the dataset. - - Args: - prediction_dir (str): The directory containing prediction files. - - Returns: - Dataset: The merged dataset with predictions. - """ - if os.path.exists(prediction_path): - preds = load_jsonl(prediction_path) - preds.sort(key=lambda x: x.get('id',0)) - return preds + +class BaseJDGDataset(BaseDataset): + def __init__(self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + self.dataset_instance = self._init_org_datasets_instance(reader_cfg, k, n, **kwargs) + super().__init__(reader_cfg, k, n, **kwargs) + + def load(self, predictions_path: str, **kwargs): + + dataset_content = self.dataset_instance.dataset["test"] + + # 加载被测模型的推理结果(排序后) + predictions: list = self._load_from_predictions(predictions_path) + + # 为数据集添加 model_answer 列 + if isinstance(dataset_content, Dataset): + dataset_list = [] + for item in predictions: + item_dict = dataset_content[int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + elif isinstance(dataset_content, DatasetDict): + dataset_list = [] + for key in dataset_content: + for item in predictions: + item_dict = dataset_content[key][int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + else: + raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") + + return Dataset.from_list(dataset_list) + + @abstractmethod + def _load_from_predictions(self, prediction_path: str) -> Dict: + pass + + @abstractmethod + def _get_dataset_class(self): + return BaseDataset + + def _init_org_datasets_instance( + self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + dataset_class = self._get_dataset_class() + return dataset_class(reader_cfg, k, n, **kwargs) + diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index a17e72b7..8b6b18f0 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,18 +1,38 @@ import re +import os from ais_bench.benchmark.utils.logging import AISLogger from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS) from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +from ais_bench.benchmark.datasets.base import BaseJDGDataset +from ais_bench.benchmark.utils.file.file import load_jsonl logger = AISLogger() + @TEXT_POSTPROCESSORS.register_module("get_a_or_b") def get_a_or_b(pred: str) -> str: """从模型回复中提取A或B""" - match = re.search(r'[AB]', pred) + match = re.search(r'[AB]', pred[-1:]) return match.group(0) if match else 'B' +class LLMJudgeDataset(BaseJDGDataset): + def _load_from_predictions(self, prediction_path: str): + """Load predictions from a directory and merge them with the dataset. + + Args: + prediction_path (str): The path to the prediction file. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id',0)) + return preds + + @ICL_EVALUATORS.register_module() class LLMJudgeCorrectEvaluator(BaseEvaluator): From df6bc4d3fba220353e31e291e31e2a2619205584 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 11 Feb 2026 16:37:32 +0800 Subject: [PATCH 03/59] reconstruct judgedataset --- ais_bench/benchmark/cli/workers.py | 29 ++++++- .../aime2025/aime2025_gen_0_shot_llmjudge.py | 9 ++- ais_bench/benchmark/datasets/aime2025.py | 25 ++---- ais_bench/benchmark/datasets/base.py | 77 ++++++++++++++----- .../benchmark/datasets/utils/llm_judge.py | 22 +++++- 5 files changed, 114 insertions(+), 48 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index f75cce7e..689762a0 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,6 +1,7 @@ import os import os.path as osp import copy +import shutil from abc import ABC, abstractmethod from collections import defaultdict @@ -79,7 +80,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Merging datasets with the same model and inferencer...") tasks = self._merge_datasets(tasks) - runner = RUNNERS.build(cfg.judge_infer.runner) + runner = RUNNERS.build(cfg.infer.runner) runner(tasks) logger.info("Inference tasks completed.") @@ -201,7 +202,7 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') if not osp.exists(task["datasets"][0][0]["predictions_path"]): raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' model_abbr = task["models"][0]["abbr"] task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") task["models"][0]["abbr"] = model_abbr @@ -263,15 +264,35 @@ def do_work(self, cfg: ConfigDict): runner(task_part) else: runner(tasks) + self._result_post_process(tasks, cfg) logger.info("Evaluation tasks completed.") def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # Replace default model config to judge model config + self.judge_result_paths = {} for task in tasks: if task["datasets"][0][0].get("judge_infer_cfg"): - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + new_dataset_abbr = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' + org_dataset_abbr = task["datasets"][0][0]["abbr"] + self.judge_result_paths[new_dataset_abbr] = org_dataset_abbr + task["datasets"][0][0]["abbr"] = new_dataset_abbr task["datasets"][0][0].pop("judge_infer_cfg") + def _result_post_process(self, tasks, cfg: ConfigDict): + # Copy judge infer result to normal name + + for task in tasks: + if task["datasets"][0][0]["abbr"] in self.judge_result_paths.keys(): + cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.judge_result_paths[task["datasets"][0][0]["abbr"]]}.jsonl') + if os.path.exists(final_org_results_path): + os.remove(final_org_results_path) + + if os.path.exists(cur_results_path): + # 基于cur_results_path的文件复制一份final_org_results_path + shutil.copy(cur_results_path, final_org_results_path) + + class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: summarizer_cfg = cfg.get("summarizer", {}) @@ -324,7 +345,7 @@ def do_work(self, cfg: ConfigDict) -> int: def _cfg_pre_process(self, cfg: ConfigDict) -> None: for i, dataset in enumerate(cfg.datasets): if dataset.get("judge_infer_cfg"): - cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' cfg.datasets[i].pop("judge_infer_cfg") return cfg diff --git a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py index c87fb232..7ece227c 100644 --- a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py @@ -38,10 +38,11 @@ 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + 6. If the candidate's answer is semantically incomplete at the end, please judge it as inconsistent. Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: - A: CORRECT - B: INCORRECT + A: Means the answer is consistent with the standard answer. + B: Means the answer is inconsistent with the standard answer. Just return the letters "A" or "B", with no text around it. Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. @@ -51,7 +52,7 @@ : \n{answer}\n\n\n : \n{model_answer}\n\n\n - Judging the correctness of candidates' answers: + Judging the correctness of candidates' answers, please return the the letters "A" or "B" first before your thinking: """.strip() aime2025_judge_infer_cfg = dict( @@ -59,7 +60,7 @@ judge_model=dict( attr="service", type=VLLMCustomAPIChat, - additional_abbr="judge", # Be added after dataset abbr + abbr="judge", # Be added after dataset abbr path="", model="", stream=True, diff --git a/ais_bench/benchmark/datasets/aime2025.py b/ais_bench/benchmark/datasets/aime2025.py index 548e28b2..b6b13a1c 100644 --- a/ais_bench/benchmark/datasets/aime2025.py +++ b/ais_bench/benchmark/datasets/aime2025.py @@ -4,9 +4,9 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset -from .base import BaseDataset, BaseJDGDatasetMethod - +from ais_bench.benchmark.datasets.base import BaseDataset @LOAD_DATASET.register_module() class Aime2025Dataset(BaseDataset): @@ -20,21 +20,8 @@ def load(path, **kwargs): dataset.append(line) return Dataset.from_list(dataset) -class Aime2025JDGDataset(Aime2025Dataset): - def load(self, path, predictions_path, **kwargs): - - dataset_content = Aime2025Dataset.load(path, **kwargs) - - # 加载被测模型的推理结果(排序后) - predictions: list = BaseJDGDatasetMethod.load_from_predictions(predictions_path) - # 为数据集添加 model_answer 列 - dataset_list = [] - - for item in predictions: - item_dict = dataset_content[int(item["id"])] - item_dict["model_answer"] = item["prediction"] - item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold - dataset_list.append(item_dict) - - return Dataset.from_list(dataset_list) +@LOAD_DATASET.register_module() +class Aime2025JDGDataset(LLMJudgeDataset): + def _get_dataset_class(self): + return Aime2025Dataset diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index f3deb8f3..243a52b0 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,7 +1,5 @@ -import os - from abc import abstractmethod -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Type from datasets import Dataset, DatasetDict from datasets.utils.logging import disable_progress_bar @@ -10,7 +8,6 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError -from ais_bench.benchmark.utils.file.file import load_jsonl disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -109,21 +106,61 @@ def test(self): return self.reader.dataset['test'] @abstractmethod - def load(**kwargs) -> Union[Dataset, DatasetDict]: + def load(self, **kwargs) -> Union[Dataset, DatasetDict]: pass -class BaseJDGDatasetMethod: - @staticmethod - def load_from_predictions(prediction_path: str) -> Dict: - """Load predictions from a directory and merge them with the dataset. - - Args: - prediction_dir (str): The directory containing prediction files. - - Returns: - Dataset: The merged dataset with predictions. - """ - if os.path.exists(prediction_path): - preds = load_jsonl(prediction_path) - preds.sort(key=lambda x: x.get('id',0)) - return preds + +class BaseJDGDataset(BaseDataset): + def __init__(self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + self.dataset_instance = self._init_org_datasets_instance(reader_cfg, k, n, **kwargs) + super().__init__(reader_cfg, k, n, **kwargs) + + def load(self, predictions_path: str, **kwargs): + + dataset_content = self.dataset_instance.dataset["test"] + + # 加载被测模型的推理结果(排序后) + predictions: list = self._load_from_predictions(predictions_path) + + # 为数据集添加 model_answer 列 + if isinstance(dataset_content, Dataset): + dataset_list = [] + for item in predictions: + item_dict = dataset_content[int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + elif isinstance(dataset_content, DatasetDict): + dataset_list = [] + for key in dataset_content: + for item in predictions: + item_dict = dataset_content[key][int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + else: + raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") + + return Dataset.from_list(dataset_list) + + @abstractmethod + def _load_from_predictions(self, prediction_path: str) -> Dict: + pass + + @abstractmethod + def _get_dataset_class(self): + return BaseDataset + + def _init_org_datasets_instance( + self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + dataset_class = self._get_dataset_class() + return dataset_class(reader_cfg, k, n, **kwargs) + diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index a17e72b7..8b6b18f0 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,18 +1,38 @@ import re +import os from ais_bench.benchmark.utils.logging import AISLogger from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS) from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +from ais_bench.benchmark.datasets.base import BaseJDGDataset +from ais_bench.benchmark.utils.file.file import load_jsonl logger = AISLogger() + @TEXT_POSTPROCESSORS.register_module("get_a_or_b") def get_a_or_b(pred: str) -> str: """从模型回复中提取A或B""" - match = re.search(r'[AB]', pred) + match = re.search(r'[AB]', pred[-1:]) return match.group(0) if match else 'B' +class LLMJudgeDataset(BaseJDGDataset): + def _load_from_predictions(self, prediction_path: str): + """Load predictions from a directory and merge them with the dataset. + + Args: + prediction_path (str): The path to the prediction file. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id',0)) + return preds + + @ICL_EVALUATORS.register_module() class LLMJudgeCorrectEvaluator(BaseEvaluator): From 312bb1d7cb3c5d384f9cc00aedb3d5ef143bff6b Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 13 Feb 2026 15:12:26 +0800 Subject: [PATCH 04/59] suppport gedit infer --- ais_bench/benchmark/cli/workers.py | 2 +- .../configs/datasets/gedit/gedit_gen.py | 44 +++ .../models/lmm_models/qwen_image_edit.py | 18 + ais_bench/benchmark/datasets/g_edit.py | 94 +++++ ais_bench/benchmark/models/__init__.py | 3 +- .../benchmark/models/local_models/__init__.py | 0 .../benchmark/models/local_models/base.py | 22 +- .../local_models/qwen_image_edit_mindie_sd.py | 335 ++++++++++++++++++ ais_bench/benchmark/models/output.py | 64 +++- .../icl_inferencer/icl_lmm_gen_inferencer.py | 75 ++++ .../icl_inferencer/output_handler/__init__.py | 0 .../output_handler/base_handler.py | 16 +- .../output_handler/bfcl_v3_output_handler.py | 10 +- .../gen_inferencer_output_handler.py | 2 + .../lmm_gen_inferencer_output_handler.py | 72 ++++ .../ppl_inferencer_output_handler.py | 20 +- .../icl_prompt_template_mm.py | 3 +- ais_bench/benchmark/utils/image_process.py | 14 + .../multi_device_run_qwen_image_edit.py | 31 ++ 19 files changed, 803 insertions(+), 22 deletions(-) create mode 100644 ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py create mode 100644 ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py create mode 100644 ais_bench/benchmark/datasets/g_edit.py create mode 100644 ais_bench/benchmark/models/local_models/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py create mode 100644 ais_bench/benchmark/utils/image_process.py create mode 100644 ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 77a5d59d..689762a0 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -80,7 +80,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Merging datasets with the same model and inferencer...") tasks = self._merge_datasets(tasks) - runner = RUNNERS.build(cfg.judge_infer.runner) + runner = RUNNERS.build(cfg.infer.runner) runner(tasks) logger.info("Inference tasks completed.") diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py new file mode 100644 index 00000000..57509dee --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py @@ -0,0 +1,44 @@ +from ais_bench.benchmark.openicl.icl_prompt_template.icl_prompt_template_mm import MMPromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer +from ais_bench.benchmark.datasets.g_edit import GEditDataset, GEditEvaluator + + +gedit_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='task_type' +) + + +gedit_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role="HUMAN", prompt_mm={ + "text": {"type": "text", "text": "{question}"}, + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=LMMGenInferencer) +) + +gedit_eval_cfg = dict( + evaluator=dict(type=GEditEvaluator) +) + +gedit_datasets = [ + dict( + abbr='gedit', + type=GEditDataset, + path='ais_bench/datasets/GEdit-Bench', # 数据集路径,使用相对路径时相对于源码根路径,支持绝对路径 + split_count=1, + split_index=0, + reader_cfg=gedit_reader_cfg, + infer_cfg=gedit_infer_cfg, + eval_cfg=gedit_eval_cfg + ) +] diff --git a/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py b/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py new file mode 100644 index 00000000..c3e92dc4 --- /dev/null +++ b/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py @@ -0,0 +1,18 @@ +from ais_bench.benchmark.models.local_models.qwen_image_edit_mindie_sd import QwenImageEditModel + +models = [ + dict( + attr="local", # local or service + type=QwenImageEditModel, # transformers >= 4.33.0 用这个,prompt 是构造成对话格式 + abbr='qwen-image-edit', + path='/home/yanhe/models/Qwen-Image-Edit-2509/', # path to model dir, current value is just a example + device_kwargs=dict( + ), + infer_kwargs=dict( # 模型参数参考 huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoModel.from_pretrained + num_inference_steps=50, + num_images_per_prompt=1, + ), + run_cfg = dict(num_gpus=1, num_procs=1), # 多卡/多机多卡 参数,使用torchrun拉起任务 + batch_size=1, # 每次推理的batch size + ) +] \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py new file mode 100644 index 00000000..9d9224b6 --- /dev/null +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -0,0 +1,94 @@ +import json +from datasets import Dataset, load_from_disk, concatenate_datasets +from concurrent.futures import ThreadPoolExecutor, as_completed + +from ais_bench.benchmark.registry import LOAD_DATASET +from ais_bench.benchmark.openicl import BaseEvaluator +from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset +from ais_bench.benchmark.utils.image_process import pil_to_base64 +from PIL import Image +from tqdm import tqdm + +from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + +GEDIT_COUNT = 10 + +class GEditEvaluator(BaseEvaluator): + def score(self, predictions, references): + details = [] + for i, pred in enumerate(predictions): + details.append({ + 'pred': pred, + 'ref': references[i], + }) + result = {'accuracy': 100 * len(predictions) / len(references), 'details': details} + return result + +@LOAD_DATASET.register_module() +class GEditDataset(BaseDataset): + @staticmethod + def load(path, use_raw=False, split_count=1, split_index=0, **kwargs): + path = get_data_path(path) + dataset = load_from_disk(path) + + # 数据集切分:分成 split_count 份,取第 split_index 份 + if split_count > 1: + total_len = len(dataset) + base_size = total_len // split_count # 每份基础大小 + remainder = total_len % split_count # 余数 + + # 计算当前 split_index 的起始和结束位置 + # 前 remainder 份每份多一个元素 + if split_index < remainder: + start_idx = split_index * (base_size + 1) + end_idx = start_idx + (base_size + 1) + else: + start_idx = remainder * (base_size + 1) + (split_index - remainder) * base_size + end_idx = start_idx + base_size + + dataset = dataset.select(range(start_idx, end_idx)) + else: + dataset = dataset.select(range(GEDIT_COUNT)) + + if use_raw: + image_column = 'input_image_raw' + else: + image_column = 'input_image' + + def process_example_to_dataset(example): + """处理单条数据并转换为 Dataset""" + image_url = pil_to_base64(example[image_column], "PNG") + example['content'] = AIS_IMAGE_START + image_url + AIS_CONTENT_TAG \ + + AIS_TEXT_START + example['instruction'] + AIS_CONTENT_TAG + # 使用 from_dict 替代 from_list 以提高性能 + data_dict = {key: [example[key]] for key in example.keys()} + return Dataset.from_dict(data_dict) + + max_workers = 4 # Adjust based on system resources + processed_datasets = [None] * len(dataset) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务 + with tqdm(total=len(dataset), desc=f"Submitting tasks split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar: + futures = {} + for i, example in enumerate(dataset): + future = executor.submit(process_example_to_dataset, example) + futures[future] = i + submit_pbar.update(1) + + # 收集处理完成的 Dataset + with tqdm(total=len(dataset), desc="Processing GEdit dataset", unit="example") as pbar: + for future in as_completed(futures): + idx = futures[future] + processed_datasets[idx] = future.result() + pbar.update(1) + + # 合并所有 Dataset + return concatenate_datasets(processed_datasets) + +@LOAD_DATASET.register_module() +class GEditJDGDataset(LLMJudgeDataset): + def _get_dataset_class(self): + return GEditDataset \ No newline at end of file diff --git a/ais_bench/benchmark/models/__init__.py b/ais_bench/benchmark/models/__init__.py index 5908d946..12230bf1 100644 --- a/ais_bench/benchmark/models/__init__.py +++ b/ais_bench/benchmark/models/__init__.py @@ -14,4 +14,5 @@ from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401 -from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel \ No newline at end of file +from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel +from ais_bench.benchmark.models.local_models.qwen_image_edit_mindie_sd import QwenImageEditModel \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/__init__.py b/ais_bench/benchmark/models/local_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/base.py b/ais_bench/benchmark/models/local_models/base.py index b2766aab..74d43283 100644 --- a/ais_bench/benchmark/models/local_models/base.py +++ b/ais_bench/benchmark/models/local_models/base.py @@ -57,7 +57,7 @@ def __init__(self, self.is_synthetic = False @abstractmethod - def _generate(self, input, max_out_len: int) -> List[str]: + def generate(self, inputs, max_out_len: int) -> List[str]: """Generate result given a input. Args: @@ -133,17 +133,6 @@ def parse_template(self, prompt_template: PromptType, mode: str) -> str: """ return self.template_parser.parse_template(prompt_template, mode) - def generate_from_template(self, templates: List[PromptType], **kwargs): - """Generate completion from a list of templates. - - Args: - templates (List[PromptType]): A list of templates. - max_out_len (int): The maximum length of the output. - """ - inputs = self.parse_template(templates, mode='gen') - max_out_lens = kwargs.get("max_out_lens", [None] * len(templates)) - return self.generate(inputs, max_out_lens, **kwargs) - def get_token_len_from_template( self, templates: Union[PromptType, List[PromptType]], @@ -204,6 +193,15 @@ def sync_inputs(self, inputs: str) -> str: def to(self, device): self.model.to(device) +class BaseLMModel(BaseModel): + """Base class for language models. + """ + def generate(self, inputs, outputs, **kwargs) -> List[str]: + raise AISBenchNotImplementedError( + MODEL_CODES.UNKNOWN_ERROR, + f'{self.__class__.__name__} does not supported' + ' to be called in base classes') + class LMTemplateParser: """Intermidate prompt template parser, specifically for language models. diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py new file mode 100644 index 00000000..b4115533 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -0,0 +1,335 @@ +# flake8: noqa +# yapf: disable +import os +import time +from typing import Dict, List, Optional, Union + +import torch +import torch_npu +import base64 +import io +from PIL import Image + +from ais_bench.benchmark.models.local_models.base import BaseLMModel +from ais_bench.benchmark.registry import MODELS +from ais_bench.benchmark.utils.prompt import PromptList +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES +from ais_bench.benchmark.models.local_models.huggingface_above_v4_33 import (_convert_chat_messages, + _get_meta_template, + ) + +# 解决 diffuser 0.35.1 torch2.1 报错 +def custom_op( + name, + fn=None, + /, + *, + mutates_args, + device_types=None, + schema=None, + tags=None, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +def register_fake( + op, + fn=None, + /, + *, + lib=None, + _stacklevel: int = 1, + allow_override: bool = False, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +if hasattr(torch, 'library'): + torch.library.custom_op = custom_op + torch.library.register_fake = register_fake + +# 导入 qwen_image_edit 相关模块 +try: + from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel + from ais_bench.benchmark.models.local_models.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline + from mindiesd import CacheConfig, CacheAgent +except ImportError as e: + raise ImportError(f"请确保 qwenimage_edit 模块在 Python 路径中: {e}") + +PromptType = Union[PromptList, str] + +# 模型推理相关配置常量 +DEFAULT_MODEL_PATH = "/home/yanhe/models/Qwen-Image-Edit-2509/" +DEFAULT_TORCH_DTYPE = "bfloat16" +DEFAULT_DEVICE = "npu" +DEFAULT_DEVICE_ID = 0 +DEFAULT_NUM_INFERENCE_STEPS = 1 # 40 +DEFAULT_TRUE_CFG_SCALE = 4.0 +DEFAULT_GUIDANCE_SCALE = 1.0 +DEFAULT_SEED = 0 +DEFAULT_NUM_IMAGES_PER_PROMPT = 1 +DEFAULT_QUANT_DESC_PATH = None + +# 缓存配置开关 +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) + + +@MODELS.register_module() +class QwenImageEditModel(BaseLMModel): + """Model wrapper for Qwen-Image-Edit-2509 models. + + Args: + path (str): The path to the model. + model_kwargs (dict): Additional model arguments. + sample_kwargs (dict): Additional sampling arguments. + vision_kwargs (dict): Additional vision arguments. + meta_template (Optional[Dict]): The model's meta prompt template. + """ + + def __init__(self, + path: str = DEFAULT_MODEL_PATH, + device_kwargs: dict = dict(), + infer_kwargs: dict = dict(), + meta_template: Optional[Dict] = None, + **other_kwargs): + self.logger = AISLogger() + self.path = path + self.max_out_len = other_kwargs.get('max_out_len', None) + self.template_parser = _get_meta_template(meta_template) + + # 设备配置 + self.device = device_kwargs.get('device', DEFAULT_DEVICE) + #self.device_id = device_kwargs.get('device_id', DEFAULT_DEVICE_ID) + # 在这里声明环境变量 + self.logger.debug(f"device id from kwargs: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}") + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = f"{device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}" + self.device_id = DEFAULT_DEVICE_ID + self.device_str = f"{self.device}:{DEFAULT_DEVICE_ID}" + self.logger.debug(f"device_str: {self.device_str}; device_id: {self.device_id}") + self.logger.debug(f"ASCEND_RT_VISIBLE_DEVICES: {os.getenv('ASCEND_RT_VISIBLE_DEVICES')}") + + # 模型配置 + self.torch_dtype = other_kwargs.get('torch_dtype', DEFAULT_TORCH_DTYPE) + self.torch_dtype = torch.bfloat16 if self.torch_dtype == "bfloat16" else torch.float32 + + # 推理配置 + self.num_inference_steps = infer_kwargs.get('num_inference_steps', DEFAULT_NUM_INFERENCE_STEPS) + self.true_cfg_scale = infer_kwargs.get('true_cfg_scale', DEFAULT_TRUE_CFG_SCALE) + self.guidance_scale = infer_kwargs.get('guidance_scale', DEFAULT_GUIDANCE_SCALE) + self.seed = infer_kwargs.get('seed', DEFAULT_SEED) + self.num_images_per_prompt = infer_kwargs.get('num_images_per_prompt', DEFAULT_NUM_IMAGES_PER_PROMPT) + self.quant_desc_path = infer_kwargs.get('quant_desc_path', DEFAULT_QUANT_DESC_PATH) + + # 加载模型 + self._load_model() + + # 缓存配置 + if COND_CACHE or UNCOND_CACHE: + # 保守cache + cache_config = CacheConfig( + method="dit_block_cache", + blocks_count=60, + steps_count=self.num_inference_steps, + step_start=10, + step_interval=3, + step_end=35, + block_start=10, + block_end=50 + ) + self.pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None + self.pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None + self.logger.info("启用缓存配置") + + def _load_model(self): + """加载模型""" + self.logger.info(f"从 {self.path} 加载模型...") + + # 设置设备 + if self.device == "npu": + torch.npu.set_device(self.device_id) + + # 加载 transformer + transformer = QwenImageTransformer2DModel.from_pretrained( + os.path.join(self.path, 'transformer'), + torch_dtype=self.torch_dtype, + device_map=None, # 禁用自动设备映射 + low_cpu_mem_usage=True # 启用CPU低内存模式 + ) + + # 量化配置 + if self.quant_desc_path: + from mindiesd import quantize + self.logger.info("Quantizing Transformer (单独量化核心组件)...") + quantize( + model=transformer, + quant_des_path=self.quant_desc_path, + use_nz=True, + ) + if self.device == "npu": + torch.npu.empty_cache() # 清理NPU显存缓存 + + # 加载 pipeline + self.pipeline = QwenImageEditPlusPipeline.from_pretrained( + self.path, + transformer=transformer, + torch_dtype=self.torch_dtype, + device_map=None, + low_cpu_mem_usage=True + ) + + # VAE优化配置(避免显存溢出) + self.pipeline.vae.use_slicing = True + self.pipeline.vae.use_tiling = True + + # 移动模型到目标设备 + self.pipeline.to(self.device_str) + self.pipeline.set_progress_bar_config(disable=None) # 显示进度条 + + def _get_meta_template(self, meta_template): + """获取元模板""" + class DummyTemplateParser: + def parse_template(self, prompt_template, mode): + return prompt_template + return DummyTemplateParser() + + def _generate(self, input) -> List[Image]: + """Generate result given a input. + + Args: + input (PromptType): A string or PromptDict. + The PromptDict should be organized in AISBench' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + # 处理输入格式 + images = [] + prompts = [] + neg_prompts = [] + print(f"in _generate") + #self.logger.info(f"输入: {input}") + if isinstance(input, str): + prompts.append(input) + neg_prompts.append("") + elif isinstance(input, list): + # 处理包含图像的输入 + for item in input[0]["prompt"]: + if item["type"] == "image_url": + base64_url = item["image_url"]["url"].split(",")[1] + img = Image.open(io.BytesIO(base64.b64decode(base64_url))).convert("RGB") + images.append(img) + elif item["type"] == "text": + prompts.append(item["text"]) + neg_prompts.append("") + else: + prompts.append("") + neg_prompts.append("") + + # 如果没有图像输入,使用默认图像 + if not images: + raise ValueError("QwenImageEditModel requires image input") + + # 执行推理 + results = [] + for prompt, neg_prompt in zip(prompts, neg_prompts): + # 准备输入参数 + print("in _generate loop") + inputs = { + "image": images, + "prompt": prompt, + "negative_prompt": neg_prompt, + "generator": torch.Generator(device=self.device_str).manual_seed(self.seed), + "true_cfg_scale": self.true_cfg_scale, + "guidance_scale": self.guidance_scale, + "num_inference_steps": self.num_inference_steps, + "num_images_per_prompt": self.num_images_per_prompt, + } + + # 执行推理并计时 + if self.device == "npu": + torch.npu.synchronize() # 昇腾设备同步 + start_time = time.time() + + with torch.inference_mode(): + output = self.pipeline(**inputs) + + if self.device == "npu": + torch.npu.synchronize() + end_time = time.time() + infer_time = end_time - start_time + self.logger.info(f"推理完成,耗时: {infer_time:.2f}秒") + + return output + + def encode(self, prompt: str) -> torch.Tensor: + """Encode prompt to tokens. Not necessary for most cases. + + Args: + prompt (str): Input string. + + Returns: + torch.Tensor: Encoded tokens. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `encode` method.') + + def decode(self, tokens: torch.Tensor) -> str: + """Decode tokens to text. Not necessary for most cases. + + Args: + tokens (torch.Tensor): Input tokens. + + Returns: + str: Decoded text. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `decode` method.') + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + # 对于图像编辑模型,token长度计算可能不同,这里返回一个默认值 + return len(prompt.split()) + + def generate(self, inputs, outputs, **kwargs): + """Generate completion from inputs. + + Args: + inputs: Inputs for generation. + max_out_lens: Maximum output lengths. + **kwargs: Additional keyword arguments. + + Returns: + List[str]: Generated completions. + """ + #self.logger.info(f"model {inputs=}") + if not isinstance(inputs, list): + inputs = [inputs] + + for i, input in enumerate(inputs): + result = self._generate(input) + # result is QwenImagePipelineOutput with 'images' attribute + if hasattr(result, 'images') and result.images: + outputs[i].success = True + outputs[i].content = result.images # 将图像列表赋值给 content + else: + outputs[i].success = False + outputs[i].content = [""] diff --git a/ais_bench/benchmark/models/output.py b/ais_bench/benchmark/models/output.py index f0676ae3..935466f5 100644 --- a/ais_bench/benchmark/models/output.py +++ b/ais_bench/benchmark/models/output.py @@ -1,5 +1,9 @@ +import os import time from abc import abstractmethod +from typing import Union + +from PIL import Image import numpy as np @@ -174,4 +178,62 @@ def update_extra_details_data_from_text_response(self, text_response: dict) -> N for item in text_response.get("choices", []): message = item.get("message", {}) self.extra_details_data["message"] = message - return # only one message is allowed \ No newline at end of file + return # only one message is allowed + + +LLM_META_DATA_TYPE = Union[Image, str] + + +class LMMOutput(Output): + def __init__(self, perf_mode: bool = False) -> None: + super().__init__(perf_mode) + self.content: list[LLM_META_DATA_TYPE] = [""] + self.HANDLER_MAP = { + Image.Image: self._handle_image, + str: self._handle_text, + } + + def get_prediction(self, save_dir: str) -> dict: + """Get the final prediction by combining content and reasoning. + + Returns: + dict: Combined prediction content + """ + output = [] + for i, item in enumerate(self.content): + output.append(self.HANDLER_MAP[type(item)](save_dir, i)) + if len(output) == 1: + return output[0] + else: + return output + + def _handle_image(self, save_dir: str, index: int) -> str: + """Handle image content. + + Args: + save_dir: Directory to save image + index: Index of image in content list + + Returns: + str: Last two levels of image path + """ + image = self.content[index] + image_path = os.path.join(save_dir, f"image_{self.uuid}_{index}.png") + if os.path.exists(image_path): + os.remove(image_path) + image.save(image_path) + return os.path.join(*image_path.split(os.sep)[-2:]) + + def _handle_text(self, save_dir: str, index: int) -> str: + """Handle text content. + + Args: + save_dir: Directory to save text + index: Index of text in content list + + Returns: + str: Text content + """ + return self.content[index] + + diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..48604ba9 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py @@ -0,0 +1,75 @@ +''' +Author: SJTUyh yh_silence@alumni.sjtu.edu.cn +Date: 2026-02-11 16:38:01 +LastEditors: SJTUyh yh_silence@alumni.sjtu.edu.cn +LastEditTime: 2026-02-12 18:39:02 +FilePath: \benchmark\ais_bench\benchmark\openicl\icl_inferencer\icl_lmm_gen_inferencer.py +Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE +''' +import uuid +from typing import List, Optional + +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.registry import ICL_INFERENCERS +from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever +from ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import LMMGenInferencerOutputHandler + + +@ICL_INFERENCERS.register_module() +class LMMGenInferencer(GenInferencer): + def __init__( + self, + model_cfg, + stopping_criteria: List[str] = [], + batch_size: Optional[int] = 1, + mode: Optional[str] = "infer", + gen_field_replace_token: Optional[str] = "", + output_json_filepath: Optional[str] = "./icl_inference_output", + save_every: Optional[int] = 1, + **kwargs, + ) -> None: + super().__init__( + model_cfg=model_cfg, + stopping_criteria=stopping_criteria, + batch_size=batch_size, + mode=mode, + gen_field_replace_token=gen_field_replace_token, + output_json_filepath=output_json_filepath, + save_every=save_every, + **kwargs, + ) + + self.output_handler = LMMGenInferencerOutputHandler(perf_mode=self.perf_mode, + save_every=self.save_every) + def inference(self, retriever: BaseRetriever, output_json_filepath: Optional[str] = None) -> List: + self.output_handler.set_output_path(output_json_filepath) + return super().inference(retriever, output_json_filepath) + + def batch_inference( + self, + datum, + ) -> None: + """Perform batch inference on the given dataloader. + + Args: + dataloader: DataLoader containing the inference data + + Returns: + List of inference results + """ + indexs = datum.pop("index") + inputs = datum.pop("prompt") + data_abbrs = datum.pop("data_abbr") + outputs = [LMMOutput(self.perf_mode) for _ in range(len(indexs))] + for output in outputs: + output.uuid = str(uuid.uuid4()).replace("-", "") + golds = datum.pop("gold", [None] * len(inputs)) + self.model.generate(inputs, outputs, **datum) + + for index, input, output, data_abbr, gold in zip( + indexs, inputs, outputs, data_abbrs, golds + ): + self.output_handler.report_cache_info_sync( + index, input, output, data_abbr, gold + ) \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py index 42cef866..9d2a450b 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py @@ -56,7 +56,14 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: self.save_every = save_every @abstractmethod - def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, Output], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + + ) -> dict: """ Get the prediction result. @@ -64,7 +71,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference - + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result """ @@ -74,6 +81,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] def get_result( self, conn: sqlite3.Connection, + data_abbr: str, input: Union[str, List[str]], output: Union[str, Output], gold: Optional[str] = None, @@ -113,7 +121,7 @@ def get_result( if gold: result_data["gold"] = gold else: - result_data = self.get_prediction_result(output, gold=gold, input=input) + result_data = self.get_prediction_result(output, gold=gold, input=input, data_abbr=data_abbr) if not result_data.get("success", True): self.all_success = False if isinstance(output, Output) and hasattr(output, "error_info"): @@ -365,7 +373,7 @@ def run_cache_consumer( try: uid = str(uuid.uuid4())[:8] - result_data = self.get_result(conn, *item[2:]) + result_data = self.get_result(conn, *item[1:]) id, data_abbr = item[0], item[1] json_data = { "data_abbr": data_abbr, diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py index 0b47eec6..8bf8fbb7 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py @@ -11,7 +11,13 @@ class BFCLV3OutputHandler(BaseInferencerOutputHandler): """ Output handler for BFCLV3 inference tasks. """ - def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: FunctionCallOutput, + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: """ Get the prediction result for BFCLV3 inference tasks. @@ -19,6 +25,8 @@ def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] output (FunctionCallOutput): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference (not used in this implementation) + data_abbr (Optional[str]): Abbreviation of the dataset (not used in this implementation) + Returns: dict: Prediction result containing success, uuid, prediction (tool_calls), and inference_log Raises: diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 111799d2..726f841c 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -27,6 +27,7 @@ def get_prediction_result( output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", ) -> dict: """ Get the prediction result for accuracy mode. @@ -35,6 +36,7 @@ def get_prediction_result( output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..70a05336 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -0,0 +1,72 @@ +from typing import List, Optional, Union +import sqlite3 +import uuid +from pathlib import Path + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES +from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError + +class LMMGenInferencerOutputHandler(BaseInferencerOutputHandler): + """ + Output handler for generation-based inference tasks. + + This handler specializes in processing generation model outputs, + supporting both performance measurement and accuracy evaluation modes. + It handles different data formats and provides appropriate result storage. + + Attributes: + all_success (bool): Flag indicating if all operations were successful + perf_mode (bool): Whether in performance measurement mode + cache_queue (queue.Queue): Queue for caching results before writing + """ + def set_output_path(self, output_path: str) -> None: + self.output_path = output_path + + def get_prediction_result( + self, + output: Union[str, LMMOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", + ) -> dict: + """ + Get the prediction result for accuracy mode. + + Args: + output (Union[str, LMMOutput]): Output result from inference + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ + try: + save_dir = Path(self.output_path) / f"{data_abbr}_out_file" + if not save_dir.exists(): + save_dir.mkdir(parents=True, exist_ok=True) + for item in input[0]['prompt']: + if item.get('image_url'): + item['image_url']['url'] = item['image_url']['url'][:256] + result_data = { + "success": ( + output.success if isinstance(output, LMMOutput) else True + ), + "uuid": output.uuid if isinstance(output, LMMOutput) else str(uuid.uuid4()).replace("-", ""), + "origin_prompt": input if input is not None else "", + "prediction": ( + output.get_prediction(save_dir) + if isinstance(output, LMMOutput) + else output + ), + } + if gold: + result_data["gold"] = gold + except Exception as e: + import traceback + print(f"[ERROR] LMMGenInferencerOutputHandler.get_prediction_result failed: {type(e).__name__}: {e}") + print(f"[ERROR] Traceback: {traceback.format_exc()}") + raise RuntimeError(f"Failed to get prediction result: {e}") + return result_data \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py index bf5ac30e..44da5be5 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py @@ -46,7 +46,25 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: super().__init__(save_every) self.perf_mode = perf_mode - def get_prediction_result(self, output: Union[str, PPLResponseOutput], gold: Optional[str] = None, input: Union[str, List[str]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, PPLResponseOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: + """ + Get the prediction result for performance mode. + + Args: + output (Union[str, PPLResponseOutput]): Model output + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ if not isinstance(output, PPLResponseOutput): raise AISBenchImplementationError(ICLI_CODES.UNKNOWN_ERROR, f"Output is not a PPLResponseOutput") result_data = { diff --git a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py index 83d42a4e..34562b5f 100644 --- a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py +++ b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py @@ -39,7 +39,7 @@ def check_mm_template(self): if "prompt_mm" not in data.keys(): return False return True - + def format_mm_url(self, template, entry): """ for mm_custom dataset @@ -103,6 +103,7 @@ def generate_item( template = self.format_mm_url(self.template, entry) template = self._encode_template(template, ice=False) template = template.format_mm(**entry) + for i, item in enumerate(template): if "prompt_mm" in item: template[i]["prompt_mm"] = self.get_mm_template(item) diff --git a/ais_bench/benchmark/utils/image_process.py b/ais_bench/benchmark/utils/image_process.py new file mode 100644 index 00000000..0863dcdf --- /dev/null +++ b/ais_bench/benchmark/utils/image_process.py @@ -0,0 +1,14 @@ +import base64 +from io import BytesIO +from PIL import Image + +def pil_to_base64(image, format="JPEG"): + """ + Convert PIL Image to base64 string + """ + if not isinstance(image, Image.Image): + raise ValueError("Input must be a PIL Image object") + buffered = BytesIO() + image.save(buffered, format) + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str \ No newline at end of file diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py new file mode 100644 index 00000000..467ed0b3 --- /dev/null +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -0,0 +1,31 @@ +from mmengine.config import read_base + +with read_base(): + from ais_bench.benchmark.configs.models.lmm_models.qwen_image_edit import models as qwen_image_edit_models + from ais_bench.benchmark.configs.summarizers.example import summarizer + from ais_bench.benchmark.configs.datasets.gedit.gedit_gen import gedit_datasets + +device_list = [0, 1, 2, 3] + +datasets = [] +models = [] +model_dataset_combinations = [] + +for i in device_list: + model_config = {k: v for k, v in qwen_image_edit_models[0].items()} + model_config['abbr'] = f"{model_config['abbr']}-{i}" + model_config['device_kwargs'] = dict(model_config['device_kwargs']) + model_config['device_kwargs']['device_id'] = i + models.append(model_config) + + dataset_config = {k: v for k, v in gedit_datasets[0].items()} + dataset_config['abbr'] = f"{dataset_config['abbr']}-{i}" + dataset_config['split_count'] = len(device_list) + dataset_config['split_index'] = i + datasets.append(dataset_config) + + # 关键:为每个设备创建一个独立的 model-dataset 组合 + model_dataset_combinations.append({ + 'models': [model_config], # 只包含当前模型 + 'datasets': [dataset_config] # 只包含当前数据集 + }) \ No newline at end of file From aae408cbe41576689d1174940290eb35e14ce040 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 13 Feb 2026 15:15:23 +0800 Subject: [PATCH 05/59] add qwen image edit dep --- .../local_models/qwenimage_edit/__init__.py | 0 .../local_models/qwenimage_edit/attn_layer.py | 201 ++++ .../qwenimage_edit/distributed/__init__.py | 0 .../qwenimage_edit/distributed/all_to_all.py | 156 +++ .../distributed/group_coordinator.py | 640 ++++++++++++ .../distributed/parallel_mgr.py | 404 ++++++++ .../qwenimage_edit/distributed/utils.py | 152 +++ .../pipeline_qwenimage_edit_plus.py | 964 ++++++++++++++++++ .../scheduling_flow_match_euler_discrete.py | 563 ++++++++++ .../qwenimage_edit/transformer_qwenimage.py | 792 ++++++++++++++ 10 files changed, 3872 insertions(+) create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py new file mode 100644 index 00000000..2d0e58e7 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py @@ -0,0 +1,201 @@ +import torch +from torch import Tensor +import torch_npu + +import torch.distributed as dist +from yunchang import LongContextAttention +try: + from yunchang.kernels import AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") + + +import math +import os +from typing import Any + +from mindiesd import attention_forward + + + +# from yunchang.comm.all_to_all import SeqAllToAll4D +# from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D +import logging + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group +) + + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + + +class xFuserLongContextAttention_new4(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + use_sync: bool = False, + attn_type: AttnType = AttnType.FA, + attn_processor: torch.nn.Module = None, + q_descale=None, + k_descale=None, + v_descale=None, + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + attn_type: AttnType = AttnType.FA, the attention type supported inside long context attention, including "TORCH", "FA", "FA3", "SAGE_FP16", "SAGE_FP8" + attn_processor: nn.Module = None, the attention processor can be passed in to replace the attention processor if attn_type is do not support it. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + use_sync=use_sync, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + + # 校验:仅"basic"类型的环形实现支持KV缓存 + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." + ) + + self.attn_processor = attn_processor + + @torch.compiler.disable + def forward( + self, + attn, + query: Tensor, # [B, S_image/ulysses_size, H, D] + key: Tensor, + value: Tensor, + *, + joint_tensor_query=None, # [B, S_text, H, D] + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + txt_pad_len = 0 + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): context output + """ + + + sp_world_size = get_sequence_parallel_world_size() # USP + sp_rank = get_sequence_parallel_rank() + + + + joint_tensor_query = torch.chunk(joint_tensor_query, sp_world_size, dim=2)[sp_rank] # [B, S_text, H, D] --> [B, S_text, H/ulysses_size, D] + joint_tensor_key = torch.chunk(joint_tensor_key, sp_world_size, dim=2)[sp_rank] + joint_tensor_value = torch.chunk(joint_tensor_value, sp_world_size, dim=2)[sp_rank] + + + + # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) + # scatter 2, gather 1 + if self.use_pack_qkv: + # (3*bs, seq_len/N, head_cnt, head_size) + qkv = torch.cat([query, key, value]).contiguous() + # (3*bs, seq_len, head_cnt/N, head_size) + qkv = SeqAllToAll4D.apply( + self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, + ) + qkv = torch.chunk(qkv, 3, dim=0) + query_layer, key_layer, value_layer = qkv + + else: + # 非打包模式:分别对Q/K/V进行通信拆分 + query_layer = SeqAllToAll4D.apply( + self.ulysses_pg, query, self.scatter_idx, self.gather_idx , # [B, S_image/ulysses_size, H, D] --> [B, S_image, H/ulysses_size, D] + ) + key_layer = SeqAllToAll4D.apply( + self.ulysses_pg, key, self.scatter_idx, self.gather_idx, + ) + value_layer = SeqAllToAll4D.apply( + self.ulysses_pg, value, self.scatter_idx, self.gather_idx, + ) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([joint_tensor_query, query_layer], dim=1) # (B, S_txt + S_img, H/ulysses_size, D_head) + joint_key = torch.cat([joint_tensor_key, key_layer], dim=1) + joint_value = torch.cat([joint_tensor_value, value_layer], dim=1) + + + out = attention_forward( + joint_query, + joint_key, + joint_value, + opt_mode="manual", + op_type="fused_attn_score", + layout="BNSD" + ) + + if type(out) == tuple: + context_layer, _, _ = out + else: + context_layer = out + + txt_seq_len = joint_tensor_query.shape[1] + + text_out = context_layer[:, :txt_seq_len, :, :].contiguous() # 强制连续 + image_out = context_layer[:, txt_seq_len:, :, :].contiguous() + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + image_out = SeqAllToAll4D.apply( + self.ulysses_pg, image_out, self.gather_idx, self.scatter_idx # [B, S_image, H/ulysses_size, D] --> [B, S_image/ulysses_size, H, D] + ) + + text_out = get_sp_group().all_gather(text_out, dim=2) # (B, S_txt , H/ulysses_size, D_head) --> (B, S_txt , H, D_head) + + output = torch.cat([text_out, image_out], dim=1) + # out e.g., [s/p::h] + return output + diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py new file mode 100644 index 00000000..2ffea2f1 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +import torch.distributed as dist + + +def all_to_all_4D( + input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + # 分支 1:scatter_idx=2 且 gather_idx=1(Ulysses 并行的 “拆分多头” 场景),按「多头维度(dim2)」拆分张量,同时将「序列维度(dim1)」重组为完整长度。 + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + comm_output = comm_output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + return comm_output + + return getter + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + # 分支 2:scatter_idx=1 且 gather_idx=2(Ulysses 并行的 “合并多头” 场景),与分支 1 相反,按「序列维度(dim1)」拆分张量,同时将「多头维度(dim2)」重组为完整多头数。 + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + comm_output = comm_output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + return comm_output + + return getter + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + use_sync: bool = False, + ) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply( + ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync + ), + None, + None, + None, + ) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py new file mode 100644 index 00000000..c48d22a6 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py @@ -0,0 +1,640 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import logging + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if "%" in key: + logging.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + + # 原代码(导致超时) + # cpu_group = torch.distributed.new_group(ranks, backend="gloo") + + # 修改后(使用HCCL后端) + cpu_group = torch.distributed.new_group(ranks, backend="hccl") # 适配昇腾环境 + + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False, async_op: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + + # All-gather. + if async_op: + current_input_size = input_size.copy() # 复制列表 + current_world_size = world_size + current_dim = dim + current_separate_tensors = separate_tensors + comm = torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group, async_op=async_op + ) + + def getter(): + comm.wait() + nonlocal output_tensor # 声明为非局部变量 + + if current_dim != 0: + # 使用捕获的变量,而不是外部变量 + temp_size = current_input_size + temp_size[0] //= current_world_size + output_tensor = output_tensor.reshape([current_world_size] + temp_size) + output_tensor = output_tensor.movedim(0, current_dim) + + if current_separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(current_world_size) + ] + return tensor_list + else: + current_input_size[current_dim] = current_input_size[current_dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(current_input_size) + return output_tensor + + return getter + else: + torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_group = kwargs.get("ulysses_group", None) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + + self.ring_group = kwargs.get("ring_group", None) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py new file mode 100644 index 00000000..0b6ef343 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py @@ -0,0 +1,404 @@ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +import logging +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +#--------- ljf ------------------- +import torch +import torch.distributed +try: + import torch_musa + from torch_musa.core.device import set_device, device_count +except ModuleNotFoundError: + pass +#--------------------------- + +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + + +@dataclass +class ParallelConfig: + tp_degree: int = 1 + sp_degree: int = 1 + ulysses_degree: int = 1 + ring_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: + logging.error( + "tp_degree * sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): + logging.error("world_size must be divisible by tp_degree * sp_degree * cfg_degree") + + +# * QUERY +def get_world_group() -> GroupCoordinator: + if _WORLD is None: + logging.error("world group is not initialized") + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + if _SP is None: + logging.error("pipeline model parallel group is not initialized") + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + if _CFG is None: + logging.error("classifier_free_guidance parallel group is not initialized") + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +# wan2.1 的 +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logging.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + if distributed_init_method is None: + logging.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('LOCAL_RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + if not _WORLD.world_size == dist.get_world_size(): + logging.error("world group already initialized with a different world size") + + +# def init_distributed_environment( +# world_size: int = -1, +# rank: int = -1, +# distributed_init_method: str = "env://", +# local_rank: int = -1, +# backend: str = "hccl", +# ): +# logging.debug( +# "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", +# world_size, +# rank, +# local_rank, +# distributed_init_method, +# backend, +# ) +# if not torch.distributed.is_initialized(): +# assert distributed_init_method is not None, ( +# "distributed_init_method must be provided when initializing " +# "distributed environment" +# ) +# # this backend is used for WORLD +# torch.distributed.init_process_group( +# backend=backend, +# init_method=distributed_init_method, +# world_size=world_size, +# rank=rank, +# ) +# set_device(torch.distributed.get_rank() % device_count()) +# # set the local rank +# # local_rank is not available in torch ProcessGroup, +# # see https://github.com/pytorch/pytorch/issues/122816 +# if local_rank == -1: +# # local rank not set, this usually happens in single-node +# # setting, where we can use rank as local rank +# if distributed_init_method == "env://": +# # local_rank = int(os.getenv('LOCAL_RANK', 0)) +# local_rank = dist.get_rank() +# print(f"init_distributed_environment 里面 local_rank {local_rank}") +# else: +# local_rank = rank +# global _WORLD +# if _WORLD is None: +# ranks = list(range(torch.distributed.get_world_size())) +# _WORLD = init_world_group(ranks, local_rank, backend) +# print(f"_WORLD 初始化") +# else: +# assert ( +# _WORLD.world_size == torch.distributed.get_world_size() +# ), "world group already initialized with a different world size" +# print(f"_WORLD 没有 初始化") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + if parallel_mode not in [ + "tensor", + "sequence", + "classifier_free_guidance", + ]: + logging.error(f"parallel_mode {parallel_mode} is not supported") + if parallel_mode == "sequence": # ulysses + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( # cfg + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + if not dist.is_initialized(): + logging.error("dist is not initialized") + world_size: int = dist.get_world_size() + backend = backend + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + * tensor_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x " + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x " + f"tensor_parallel_degree " + f"({tensor_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + classifier_free_guidance_degree, + "tp-sp-cfg", + ) + + global _CFG + if _CFG is not None: + logging.error("classifier_free_guidance group is already initialized") + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + if _SP is not None: + logging.error("sequence parallel group is already initialized") + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=world_size + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logging.warning("Model parallel is not initialized, initializing...") + init_distributed_environment( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + backend='hccl', + ) + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py new file mode 100644 index 00000000..c53ae68f --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py @@ -0,0 +1,152 @@ +from typing import List +import logging + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + if not ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ): + logging.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = tp * sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + "tp": self.tp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i, _ in enumerate(rank_group): + rank_group[i] += self.rank_offset + return ranks \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py new file mode 100644 index 00000000..3d5c4b3b --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,964 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import QwenImageLoraLoaderMixin +from diffusers.models import AutoencoderKLQwenImage + +# from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from ais_bench.benchmark.models.local_models.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput + +from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + init_distributed_environment, + initialize_model_parallel, + get_sequence_parallel_rank, + get_sp_group +) + +#------------------ljf------------------- +import os +import torch_npu +#------------------------------------- + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +#------------------ljf--------------- +USE_NPU = False +if torch.npu.is_available(): + USE_NPU = True + + +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) +#----------------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage_edit_plus.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + # print(f"device {device}, ljf 随机生成latents latents {latents}") + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, # ljf None + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if get_classifier_free_guidance_world_size() == 2: + if get_classifier_free_guidance_rank() == 0: + with self.transformer.cache_context("uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + + noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(noise_pred, separate_tensors=True) + + comb_pred = noise_pred_uncond + true_cfg_scale * (noise_pred_text - noise_pred_uncond) + + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) # 修正代码 + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + #------------ljf 原始代码--------------- + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + #---------------------------------------------------- + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000..0e67901d --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,563 @@ +# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, is_scipy_available, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") + + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self._shift = shift + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float): + self._shift = shift + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + timesteps: Optional[List[float]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + # 5. Convert sigmas and timesteps to tensors and move to specified device + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + if not is_timesteps_provided: + timesteps = sigmas * self.config.num_train_timesteps + else: + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps + + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + sigma_idx = self.step_index + sigma = self.sigmas[sigma_idx] + sigma_next = self.sigmas[sigma_idx + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.config.stochastic_sampling: + print("ljf 进入采样器,涉及随机") + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + print("ljf 进入采样器,无随机") + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py new file mode 100644 index 00000000..52b2a002 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py @@ -0,0 +1,792 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +# from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput + +from diffusers.models.attention import AttentionMixin, FeedForward +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm + +#------------ljf-------- +import torch_npu +from mindiesd import attention_forward +import os +ROPE_FUSE = bool(int(os.environ.get('ROPE_FUSE', 0))) +ADALN_FUSE = bool(int(os.environ.get('ADALN_FUSE', 0))) +#--------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class AdaLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + scale = (1 + scale.unsqueeze(1)) + shift = shift.unsqueeze(1) + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps), gate.unsqueeze(1) + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + if not ROPE_FUSE: #----------------- ljf -------------------- + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + else: + cos = freqs_cis.real + sin = freqs_cis.imag + seqlen = cos.shape[0] + + cos = cos.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + sin = sin.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + + x_out = torch_npu.npu_rotary_mul(x, cos, sin, 'interleave') + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + #-----------ljf------------------- + # if not torch.compiler.is_compiling(): + # if rope_key not in self.rope_cache: + # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + # video_freq = self.rope_cache[rope_key] + # else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + + + # print("ljf img_query ", img_query) + # exit() + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + # joint_hidden_states = dispatch_attention_fn( + # joint_query, + # joint_key, + # joint_value, + # attn_mask=attention_mask, + # dropout_p=0.0, + # is_causal=False, + # backend=self._attention_backend, + # parallel_config=self._parallel_config, + # ) + #--------------------ljf------------------------ + joint_hidden_states = attention_forward(joint_query, joint_key, joint_value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + #--------------------------------------------- + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +@maybe_allow_in_graph +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------ + if not ADALN_FUSE: + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------ + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + #--------------ljf------------------- + if not ADALN_FUSE: + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm2 = AdaLayerNorm(dim, eps=eps) + #-------------------------------- + + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------- + if not ADALN_FUSE: + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------------- + + # Text doesn't need separate attention - it's handled by img_attn joint computation + #---------------------------ljf-------------- + if not ADALN_FUSE: + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm2 = AdaLayerNorm(dim, eps=eps) + #---------------------------------------- + + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + # encoder_hidden_states: torch.Tensor, + encoder_hidden_states, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + txt_pad_len = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + #------------------ljf------------------ + if not ADALN_FUSE: + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + else: + img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1) + #---------------------------------------- + + # Process text stream - norm1 + modulation + #----------------------ljf--------------- + if not ADALN_FUSE: + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + else: + txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) + #---------------------------------- + + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + # txt_pad_len = txt_pad_len, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + # ljf (B, S_txt_split , H*D_head), (B, S_img_split , H*D_head) + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + #-----------------------ljf----------- + if not ADALN_FUSE: + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + else: + img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2) + #--------------------------------------- + + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + #----------------ljf------------------------ + if not ADALN_FUSE: + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + else: + txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2) + #-------------------------------- + + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + # return encoder_hidden_states, hidden_states + return hidden_states, encoder_hidden_states + + +class QwenImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + # _cp_plan = { + # "": { + # "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + # }, + # "pos_embed": { + # 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # }, + # "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + # } + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + #-----------------ljf------------- + self.cache_cond = None + self.cache_uncond = None + #------------------------------- + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + return_dict: bool = True, + use_cache: bool = False, #---------------ljf------------ + if_cond: bool = True, #-------------------ljf------------------ + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + ) + + else: + #--------------------ljf----------- + if not use_cache: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + if if_cond: + hidden_states, encoder_hidden_states = self.cache_cond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + hidden_states, encoder_hidden_states = self.cache_uncond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + #----------------------------------- + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file From ce167ed1a6a505086d4921417a3065309560428e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 13 Feb 2026 15:15:23 +0800 Subject: [PATCH 06/59] add qwen image edit dep --- .../local_models/qwenimage_edit/__init__.py | 0 .../local_models/qwenimage_edit/attn_layer.py | 201 ++++ .../qwenimage_edit/distributed/__init__.py | 0 .../qwenimage_edit/distributed/all_to_all.py | 156 +++ .../distributed/group_coordinator.py | 640 ++++++++++++ .../distributed/parallel_mgr.py | 404 ++++++++ .../qwenimage_edit/distributed/utils.py | 152 +++ .../pipeline_qwenimage_edit_plus.py | 964 ++++++++++++++++++ .../scheduling_flow_match_euler_discrete.py | 563 ++++++++++ .../qwenimage_edit/transformer_qwenimage.py | 792 ++++++++++++++ 10 files changed, 3872 insertions(+) create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py new file mode 100644 index 00000000..2d0e58e7 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py @@ -0,0 +1,201 @@ +import torch +from torch import Tensor +import torch_npu + +import torch.distributed as dist +from yunchang import LongContextAttention +try: + from yunchang.kernels import AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") + + +import math +import os +from typing import Any + +from mindiesd import attention_forward + + + +# from yunchang.comm.all_to_all import SeqAllToAll4D +# from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D +import logging + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group +) + + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + + +class xFuserLongContextAttention_new4(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + use_sync: bool = False, + attn_type: AttnType = AttnType.FA, + attn_processor: torch.nn.Module = None, + q_descale=None, + k_descale=None, + v_descale=None, + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + attn_type: AttnType = AttnType.FA, the attention type supported inside long context attention, including "TORCH", "FA", "FA3", "SAGE_FP16", "SAGE_FP8" + attn_processor: nn.Module = None, the attention processor can be passed in to replace the attention processor if attn_type is do not support it. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + use_sync=use_sync, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + + # 校验:仅"basic"类型的环形实现支持KV缓存 + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." + ) + + self.attn_processor = attn_processor + + @torch.compiler.disable + def forward( + self, + attn, + query: Tensor, # [B, S_image/ulysses_size, H, D] + key: Tensor, + value: Tensor, + *, + joint_tensor_query=None, # [B, S_text, H, D] + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + txt_pad_len = 0 + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): context output + """ + + + sp_world_size = get_sequence_parallel_world_size() # USP + sp_rank = get_sequence_parallel_rank() + + + + joint_tensor_query = torch.chunk(joint_tensor_query, sp_world_size, dim=2)[sp_rank] # [B, S_text, H, D] --> [B, S_text, H/ulysses_size, D] + joint_tensor_key = torch.chunk(joint_tensor_key, sp_world_size, dim=2)[sp_rank] + joint_tensor_value = torch.chunk(joint_tensor_value, sp_world_size, dim=2)[sp_rank] + + + + # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) + # scatter 2, gather 1 + if self.use_pack_qkv: + # (3*bs, seq_len/N, head_cnt, head_size) + qkv = torch.cat([query, key, value]).contiguous() + # (3*bs, seq_len, head_cnt/N, head_size) + qkv = SeqAllToAll4D.apply( + self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, + ) + qkv = torch.chunk(qkv, 3, dim=0) + query_layer, key_layer, value_layer = qkv + + else: + # 非打包模式:分别对Q/K/V进行通信拆分 + query_layer = SeqAllToAll4D.apply( + self.ulysses_pg, query, self.scatter_idx, self.gather_idx , # [B, S_image/ulysses_size, H, D] --> [B, S_image, H/ulysses_size, D] + ) + key_layer = SeqAllToAll4D.apply( + self.ulysses_pg, key, self.scatter_idx, self.gather_idx, + ) + value_layer = SeqAllToAll4D.apply( + self.ulysses_pg, value, self.scatter_idx, self.gather_idx, + ) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([joint_tensor_query, query_layer], dim=1) # (B, S_txt + S_img, H/ulysses_size, D_head) + joint_key = torch.cat([joint_tensor_key, key_layer], dim=1) + joint_value = torch.cat([joint_tensor_value, value_layer], dim=1) + + + out = attention_forward( + joint_query, + joint_key, + joint_value, + opt_mode="manual", + op_type="fused_attn_score", + layout="BNSD" + ) + + if type(out) == tuple: + context_layer, _, _ = out + else: + context_layer = out + + txt_seq_len = joint_tensor_query.shape[1] + + text_out = context_layer[:, :txt_seq_len, :, :].contiguous() # 强制连续 + image_out = context_layer[:, txt_seq_len:, :, :].contiguous() + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + image_out = SeqAllToAll4D.apply( + self.ulysses_pg, image_out, self.gather_idx, self.scatter_idx # [B, S_image, H/ulysses_size, D] --> [B, S_image/ulysses_size, H, D] + ) + + text_out = get_sp_group().all_gather(text_out, dim=2) # (B, S_txt , H/ulysses_size, D_head) --> (B, S_txt , H, D_head) + + output = torch.cat([text_out, image_out], dim=1) + # out e.g., [s/p::h] + return output + diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py new file mode 100644 index 00000000..2ffea2f1 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +import torch.distributed as dist + + +def all_to_all_4D( + input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + # 分支 1:scatter_idx=2 且 gather_idx=1(Ulysses 并行的 “拆分多头” 场景),按「多头维度(dim2)」拆分张量,同时将「序列维度(dim1)」重组为完整长度。 + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + comm_output = comm_output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + return comm_output + + return getter + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + # 分支 2:scatter_idx=1 且 gather_idx=2(Ulysses 并行的 “合并多头” 场景),与分支 1 相反,按「序列维度(dim1)」拆分张量,同时将「多头维度(dim2)」重组为完整多头数。 + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + comm_output = comm_output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + return comm_output + + return getter + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + use_sync: bool = False, + ) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply( + ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync + ), + None, + None, + None, + ) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py new file mode 100644 index 00000000..c48d22a6 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py @@ -0,0 +1,640 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import logging + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if "%" in key: + logging.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + + # 原代码(导致超时) + # cpu_group = torch.distributed.new_group(ranks, backend="gloo") + + # 修改后(使用HCCL后端) + cpu_group = torch.distributed.new_group(ranks, backend="hccl") # 适配昇腾环境 + + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False, async_op: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + + # All-gather. + if async_op: + current_input_size = input_size.copy() # 复制列表 + current_world_size = world_size + current_dim = dim + current_separate_tensors = separate_tensors + comm = torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group, async_op=async_op + ) + + def getter(): + comm.wait() + nonlocal output_tensor # 声明为非局部变量 + + if current_dim != 0: + # 使用捕获的变量,而不是外部变量 + temp_size = current_input_size + temp_size[0] //= current_world_size + output_tensor = output_tensor.reshape([current_world_size] + temp_size) + output_tensor = output_tensor.movedim(0, current_dim) + + if current_separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(current_world_size) + ] + return tensor_list + else: + current_input_size[current_dim] = current_input_size[current_dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(current_input_size) + return output_tensor + + return getter + else: + torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_group = kwargs.get("ulysses_group", None) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + + self.ring_group = kwargs.get("ring_group", None) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py new file mode 100644 index 00000000..0b6ef343 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py @@ -0,0 +1,404 @@ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +import logging +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +#--------- ljf ------------------- +import torch +import torch.distributed +try: + import torch_musa + from torch_musa.core.device import set_device, device_count +except ModuleNotFoundError: + pass +#--------------------------- + +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + + +@dataclass +class ParallelConfig: + tp_degree: int = 1 + sp_degree: int = 1 + ulysses_degree: int = 1 + ring_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: + logging.error( + "tp_degree * sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): + logging.error("world_size must be divisible by tp_degree * sp_degree * cfg_degree") + + +# * QUERY +def get_world_group() -> GroupCoordinator: + if _WORLD is None: + logging.error("world group is not initialized") + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + if _SP is None: + logging.error("pipeline model parallel group is not initialized") + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + if _CFG is None: + logging.error("classifier_free_guidance parallel group is not initialized") + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +# wan2.1 的 +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logging.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + if distributed_init_method is None: + logging.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('LOCAL_RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + if not _WORLD.world_size == dist.get_world_size(): + logging.error("world group already initialized with a different world size") + + +# def init_distributed_environment( +# world_size: int = -1, +# rank: int = -1, +# distributed_init_method: str = "env://", +# local_rank: int = -1, +# backend: str = "hccl", +# ): +# logging.debug( +# "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", +# world_size, +# rank, +# local_rank, +# distributed_init_method, +# backend, +# ) +# if not torch.distributed.is_initialized(): +# assert distributed_init_method is not None, ( +# "distributed_init_method must be provided when initializing " +# "distributed environment" +# ) +# # this backend is used for WORLD +# torch.distributed.init_process_group( +# backend=backend, +# init_method=distributed_init_method, +# world_size=world_size, +# rank=rank, +# ) +# set_device(torch.distributed.get_rank() % device_count()) +# # set the local rank +# # local_rank is not available in torch ProcessGroup, +# # see https://github.com/pytorch/pytorch/issues/122816 +# if local_rank == -1: +# # local rank not set, this usually happens in single-node +# # setting, where we can use rank as local rank +# if distributed_init_method == "env://": +# # local_rank = int(os.getenv('LOCAL_RANK', 0)) +# local_rank = dist.get_rank() +# print(f"init_distributed_environment 里面 local_rank {local_rank}") +# else: +# local_rank = rank +# global _WORLD +# if _WORLD is None: +# ranks = list(range(torch.distributed.get_world_size())) +# _WORLD = init_world_group(ranks, local_rank, backend) +# print(f"_WORLD 初始化") +# else: +# assert ( +# _WORLD.world_size == torch.distributed.get_world_size() +# ), "world group already initialized with a different world size" +# print(f"_WORLD 没有 初始化") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + if parallel_mode not in [ + "tensor", + "sequence", + "classifier_free_guidance", + ]: + logging.error(f"parallel_mode {parallel_mode} is not supported") + if parallel_mode == "sequence": # ulysses + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( # cfg + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + if not dist.is_initialized(): + logging.error("dist is not initialized") + world_size: int = dist.get_world_size() + backend = backend + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + * tensor_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x " + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x " + f"tensor_parallel_degree " + f"({tensor_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + classifier_free_guidance_degree, + "tp-sp-cfg", + ) + + global _CFG + if _CFG is not None: + logging.error("classifier_free_guidance group is already initialized") + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + if _SP is not None: + logging.error("sequence parallel group is already initialized") + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=world_size + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logging.warning("Model parallel is not initialized, initializing...") + init_distributed_environment( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + backend='hccl', + ) + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py new file mode 100644 index 00000000..c53ae68f --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py @@ -0,0 +1,152 @@ +from typing import List +import logging + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + if not ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ): + logging.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = tp * sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + "tp": self.tp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i, _ in enumerate(rank_group): + rank_group[i] += self.rank_offset + return ranks \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py new file mode 100644 index 00000000..3d5c4b3b --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,964 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import QwenImageLoraLoaderMixin +from diffusers.models import AutoencoderKLQwenImage + +# from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from ais_bench.benchmark.models.local_models.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput + +from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + init_distributed_environment, + initialize_model_parallel, + get_sequence_parallel_rank, + get_sp_group +) + +#------------------ljf------------------- +import os +import torch_npu +#------------------------------------- + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +#------------------ljf--------------- +USE_NPU = False +if torch.npu.is_available(): + USE_NPU = True + + +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) +#----------------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage_edit_plus.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + # print(f"device {device}, ljf 随机生成latents latents {latents}") + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, # ljf None + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if get_classifier_free_guidance_world_size() == 2: + if get_classifier_free_guidance_rank() == 0: + with self.transformer.cache_context("uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + + noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(noise_pred, separate_tensors=True) + + comb_pred = noise_pred_uncond + true_cfg_scale * (noise_pred_text - noise_pred_uncond) + + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) # 修正代码 + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + #------------ljf 原始代码--------------- + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + #---------------------------------------------------- + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000..0e67901d --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,563 @@ +# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, is_scipy_available, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") + + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self._shift = shift + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float): + self._shift = shift + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + timesteps: Optional[List[float]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + # 5. Convert sigmas and timesteps to tensors and move to specified device + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + if not is_timesteps_provided: + timesteps = sigmas * self.config.num_train_timesteps + else: + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps + + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + sigma_idx = self.step_index + sigma = self.sigmas[sigma_idx] + sigma_next = self.sigmas[sigma_idx + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.config.stochastic_sampling: + print("ljf 进入采样器,涉及随机") + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + print("ljf 进入采样器,无随机") + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py new file mode 100644 index 00000000..52b2a002 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py @@ -0,0 +1,792 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +# from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput + +from diffusers.models.attention import AttentionMixin, FeedForward +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm + +#------------ljf-------- +import torch_npu +from mindiesd import attention_forward +import os +ROPE_FUSE = bool(int(os.environ.get('ROPE_FUSE', 0))) +ADALN_FUSE = bool(int(os.environ.get('ADALN_FUSE', 0))) +#--------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class AdaLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + scale = (1 + scale.unsqueeze(1)) + shift = shift.unsqueeze(1) + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps), gate.unsqueeze(1) + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + if not ROPE_FUSE: #----------------- ljf -------------------- + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + else: + cos = freqs_cis.real + sin = freqs_cis.imag + seqlen = cos.shape[0] + + cos = cos.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + sin = sin.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + + x_out = torch_npu.npu_rotary_mul(x, cos, sin, 'interleave') + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + #-----------ljf------------------- + # if not torch.compiler.is_compiling(): + # if rope_key not in self.rope_cache: + # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + # video_freq = self.rope_cache[rope_key] + # else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + + + # print("ljf img_query ", img_query) + # exit() + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + # joint_hidden_states = dispatch_attention_fn( + # joint_query, + # joint_key, + # joint_value, + # attn_mask=attention_mask, + # dropout_p=0.0, + # is_causal=False, + # backend=self._attention_backend, + # parallel_config=self._parallel_config, + # ) + #--------------------ljf------------------------ + joint_hidden_states = attention_forward(joint_query, joint_key, joint_value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + #--------------------------------------------- + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +@maybe_allow_in_graph +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------ + if not ADALN_FUSE: + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------ + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + #--------------ljf------------------- + if not ADALN_FUSE: + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm2 = AdaLayerNorm(dim, eps=eps) + #-------------------------------- + + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------- + if not ADALN_FUSE: + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------------- + + # Text doesn't need separate attention - it's handled by img_attn joint computation + #---------------------------ljf-------------- + if not ADALN_FUSE: + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm2 = AdaLayerNorm(dim, eps=eps) + #---------------------------------------- + + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + # encoder_hidden_states: torch.Tensor, + encoder_hidden_states, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + txt_pad_len = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + #------------------ljf------------------ + if not ADALN_FUSE: + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + else: + img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1) + #---------------------------------------- + + # Process text stream - norm1 + modulation + #----------------------ljf--------------- + if not ADALN_FUSE: + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + else: + txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) + #---------------------------------- + + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + # txt_pad_len = txt_pad_len, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + # ljf (B, S_txt_split , H*D_head), (B, S_img_split , H*D_head) + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + #-----------------------ljf----------- + if not ADALN_FUSE: + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + else: + img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2) + #--------------------------------------- + + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + #----------------ljf------------------------ + if not ADALN_FUSE: + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + else: + txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2) + #-------------------------------- + + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + # return encoder_hidden_states, hidden_states + return hidden_states, encoder_hidden_states + + +class QwenImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + # _cp_plan = { + # "": { + # "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + # }, + # "pos_embed": { + # 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # }, + # "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + # } + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + #-----------------ljf------------- + self.cache_cond = None + self.cache_uncond = None + #------------------------------- + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + return_dict: bool = True, + use_cache: bool = False, #---------------ljf------------ + if_cond: bool = True, #-------------------ljf------------------ + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + ) + + else: + #--------------------ljf----------- + if not use_cache: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + if if_cond: + hidden_states, encoder_hidden_states = self.cache_cond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + hidden_states, encoder_hidden_states = self.cache_uncond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + #----------------------------------- + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file From 898385c703e2d372b9ecf04a51556eb46e889b4e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 14 Feb 2026 10:25:27 +0800 Subject: [PATCH 07/59] llm eval --- .../gedit/gedit_gen_0_shot_llmjudge.py | 110 ++++++++++++++++++ ais_bench/benchmark/datasets/g_edit.py | 4 +- .../benchmark/datasets/utils/llm_judge.py | 39 +++++++ 3 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py new file mode 100644 index 00000000..c6525445 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py @@ -0,0 +1,110 @@ +from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate +from ais_bench.benchmark.openicl.icl_prompt_template.icl_prompt_template_mm import MMPromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer +from ais_bench.benchmark.models import VLLMCustomAPIChat +from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content +from ais_bench.benchmark.datasets.g_edit import ( + GEditDataset, + GEditJDGDataset, +) +from ais_bench.benchmark.datasets.utils.llm_judge import get_a_or_b, LLMJudgeCorrectEvaluator + + +gedit_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='task_type' +) + + +gedit_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role="HUMAN", prompt_mm={ + "text": {"type": "text", "text": "{question}"}, + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=LMMGenInferencer) +) + +GRADER_TEMPLATE = """ +RULES: + +Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. +The objective is to evaluate how successfully the editing instruction has been executed in the second image. + +Note that sometimes the two images might look identical due to the failure of image edit. + +From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +Editing instruction: {question} +""".strip() + +gedit_judge_infer_cfg = dict( + judge_reader_cfg = dict(input_columns=["question", "model_answer", "image"], output_column="model_pred_uuid"), + judge_model=dict( + attr="service", + type=VLLMCustomAPIChat, + abbr="judge", # Be added after dataset abbr + path="", + model="", + stream=True, + request_rate=0, + use_timestamp=False, + retry=2, + api_key="", + host_ip="localhost", + host_port=8080, + url="", + max_out_len=512, + batch_size=1, + trust_remote_code=False, + generation_kwargs=dict( + temperature=0.01, + ignore_eos=False, + ), + pred_postprocessor=dict(type=extract_non_reasoning_content), + ), + judge_dataset_type=GEditJDGDataset, + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + "text": {"type": "text", "text": GRADER_TEMPLATE}, + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, # origin graph + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{prediction}"}}, # edited graph + }) + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=LMMGenInferencer), +) + +gedit_eval_cfg = dict( + evaluator=dict(type=LLMJudgeCorrectEvaluator), + pred_postprocessor=dict(type=get_a_or_b), +) + +gedit_datasets = [ + dict( + abbr="gedit", + type=GEditDataset, + path="ais_bench/datasets/gedit/gedit.jsonl", + reader_cfg=gedit_reader_cfg, + infer_cfg=gedit_infer_cfg, + judge_infer_cfg=gedit_judge_infer_cfg, + eval_cfg=gedit_eval_cfg, + ) +] diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index 9d9224b6..22310706 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -5,7 +5,7 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.openicl import BaseEvaluator from ais_bench.benchmark.datasets.utils.datasets import get_data_path -from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset +from ais_bench.benchmark.datasets.utils.llm_judge import LMMImgJDGDataset from ais_bench.benchmark.utils.image_process import pil_to_base64 from PIL import Image from tqdm import tqdm @@ -89,6 +89,6 @@ def process_example_to_dataset(example): return concatenate_datasets(processed_datasets) @LOAD_DATASET.register_module() -class GEditJDGDataset(LLMJudgeDataset): +class GEditJDGDataset(LMMImgJDGDataset): def _get_dataset_class(self): return GEditDataset \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index 8b6b18f0..45772634 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,5 +1,8 @@ import re import os +import base64 +from io import BytesIO +from PIL import Image from ais_bench.benchmark.utils.logging import AISLogger from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, @@ -32,6 +35,42 @@ def _load_from_predictions(self, prediction_path: str): preds.sort(key=lambda x: x.get('id',0)) return preds +class LMMImgJDGDataset(BaseJDGDataset): + def _load_from_predictions(self, prediction_path: str): + """从prediction中拿到对应图片相对路径,将这个路径的图片加载并转换为Base64字符串. + + Args: + prediction_path (str): The path to the prediction file. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + + # 遍历预测结果,加载图片并转换为Base64字符串 + for pred in preds: + # 假设pred中包含图片相对路径 + image_path = pred.get('prediction', '') + if image_path and os.path.exists(image_path): + try: + # 加载图片 + with Image.open(image_path) as img: + # 转换为RGB格式 + img = img.convert('RGB') + # 保存到BytesIO + buffered = BytesIO() + img.save(buffered, format="PNG") + # 转换为Base64字符串 + img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + # 更新pred中的image字段为Base64字符串 + pred['prediction'] = img_base64 + except Exception as e: + logger.error(f"Failed to load image {image_path}: {e}") + + preds.sort(key=lambda x: x.get('id', 0)) + return preds + @ICL_EVALUATORS.register_module() class LLMJudgeCorrectEvaluator(BaseEvaluator): From af1240f5289ed9554d0df2810cb537c385600b1b Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 14 Feb 2026 10:36:30 +0800 Subject: [PATCH 08/59] base judge ds class generalize --- ais_bench/benchmark/datasets/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 243a52b0..57a34359 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -131,7 +131,7 @@ def load(self, predictions_path: str, **kwargs): dataset_list = [] for item in predictions: item_dict = dataset_content[int(item["id"])] - item_dict["model_answer"] = item["prediction"] + self._modify_dataset_item(item_dict, item) item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold dataset_list.append(item_dict) elif isinstance(dataset_content, DatasetDict): @@ -139,7 +139,7 @@ def load(self, predictions_path: str, **kwargs): for key in dataset_content: for item in predictions: item_dict = dataset_content[key][int(item["id"])] - item_dict["model_answer"] = item["prediction"] + self._modify_dataset_item(item_dict, item) item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold dataset_list.append(item_dict) else: @@ -155,6 +155,9 @@ def _load_from_predictions(self, prediction_path: str) -> Dict: def _get_dataset_class(self): return BaseDataset + def _modify_dataset_item(self, dataset_item, pred_item): + dataset_item["model_answer"] = pred_item["prediction"] + def _init_org_datasets_instance( self, reader_cfg: Optional[Dict] = {}, From e18209ce3dc7f9d2051295eda25798828c192cc9 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 24 Feb 2026 09:36:40 +0800 Subject: [PATCH 09/59] llm eval --- ais_bench/benchmark/cli/argument_parser.py | 2 +- ais_bench/benchmark/cli/workers.py | 2 + .../gedit/gedit_gen_0_shot_llmjudge.py | 10 +-- ais_bench/benchmark/datasets/g_edit.py | 4 +- .../benchmark/datasets/utils/llm_judge.py | 38 --------- .../benchmark/datasets/utils/lmm_judge.py | 79 +++++++++++++++++++ .../multi_device_run_qwen_image_edit.py | 2 +- 7 files changed, 90 insertions(+), 47 deletions(-) create mode 100644 ais_bench/benchmark/datasets/utils/lmm_judge.py diff --git a/ais_bench/benchmark/cli/argument_parser.py b/ais_bench/benchmark/cli/argument_parser.py index 38d79851..38a02e7b 100644 --- a/ais_bench/benchmark/cli/argument_parser.py +++ b/ais_bench/benchmark/cli/argument_parser.py @@ -61,7 +61,7 @@ def _base_parser(self): help='Running mode. Choose "perf" for performance evaluation, "infer" to run inference only, ' '"eval" to evaluate existing inference results, or "viz" to visualize the results. ' 'The default mode is "all", which runs all steps.', - choices=['all', 'infer', 'eval', 'viz', 'perf', 'perf_viz'], + choices=['all', 'infer', 'eval', 'viz', 'perf', 'perf_viz', 'judge', 'infer_judge'], default='all', type=str ) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 689762a0..3aca0325 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -381,6 +381,8 @@ def do_work(self, cfg: ConfigDict) -> int: WORK_FLOW = dict( all=[Infer, JudgeInfer, Eval, AccViz], infer=[Infer], + judge=[JudgeInfer], + infer_judge=[Infer, JudgeInfer], eval=[JudgeInfer, Eval, AccViz], viz=[AccViz], perf=[Infer, PerfViz], diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py index c6525445..5229f3d4 100644 --- a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py @@ -63,11 +63,11 @@ use_timestamp=False, retry=2, api_key="", - host_ip="localhost", - host_port=8080, + host_ip="192.168.9.123", + host_port=5103, url="", max_out_len=512, - batch_size=1, + batch_size=16, trust_remote_code=False, generation_kwargs=dict( temperature=0.01, @@ -89,7 +89,7 @@ ), ), retriever=dict(type=ZeroRetriever), - inferencer=dict(type=LMMGenInferencer), + inferencer=dict(type=GenInferencer), ) gedit_eval_cfg = dict( @@ -101,7 +101,7 @@ dict( abbr="gedit", type=GEditDataset, - path="ais_bench/datasets/gedit/gedit.jsonl", + path="ais_bench/datasets/GEdit-Bench", reader_cfg=gedit_reader_cfg, infer_cfg=gedit_infer_cfg, judge_infer_cfg=gedit_judge_infer_cfg, diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index 22310706..89951bdd 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -5,7 +5,7 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.openicl import BaseEvaluator from ais_bench.benchmark.datasets.utils.datasets import get_data_path -from ais_bench.benchmark.datasets.utils.llm_judge import LMMImgJDGDataset +from ais_bench.benchmark.datasets.utils.lmm_judge import LMMImgJDGDataset from ais_bench.benchmark.utils.image_process import pil_to_base64 from PIL import Image from tqdm import tqdm @@ -13,7 +13,7 @@ from ais_bench.benchmark.datasets.base import BaseDataset from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START -GEDIT_COUNT = 10 +GEDIT_COUNT = 1 class GEditEvaluator(BaseEvaluator): def score(self, predictions, references): diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index 45772634..07c8e0df 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,7 +1,5 @@ import re import os -import base64 -from io import BytesIO from PIL import Image from ais_bench.benchmark.utils.logging import AISLogger @@ -35,42 +33,6 @@ def _load_from_predictions(self, prediction_path: str): preds.sort(key=lambda x: x.get('id',0)) return preds -class LMMImgJDGDataset(BaseJDGDataset): - def _load_from_predictions(self, prediction_path: str): - """从prediction中拿到对应图片相对路径,将这个路径的图片加载并转换为Base64字符串. - - Args: - prediction_path (str): The path to the prediction file. - - Returns: - Dataset: The merged dataset with predictions. - """ - if os.path.exists(prediction_path): - preds = load_jsonl(prediction_path) - - # 遍历预测结果,加载图片并转换为Base64字符串 - for pred in preds: - # 假设pred中包含图片相对路径 - image_path = pred.get('prediction', '') - if image_path and os.path.exists(image_path): - try: - # 加载图片 - with Image.open(image_path) as img: - # 转换为RGB格式 - img = img.convert('RGB') - # 保存到BytesIO - buffered = BytesIO() - img.save(buffered, format="PNG") - # 转换为Base64字符串 - img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') - # 更新pred中的image字段为Base64字符串 - pred['prediction'] = img_base64 - except Exception as e: - logger.error(f"Failed to load image {image_path}: {e}") - - preds.sort(key=lambda x: x.get('id', 0)) - return preds - @ICL_EVALUATORS.register_module() class LLMJudgeCorrectEvaluator(BaseEvaluator): diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py new file mode 100644 index 00000000..0e910fb0 --- /dev/null +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -0,0 +1,79 @@ +import re +import os +import base64 +import concurrent.futures +from io import BytesIO +from PIL import Image +from tqdm import tqdm + +from ais_bench.benchmark.datasets.needlebench_v2 import origin +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, + TEXT_POSTPROCESSORS) +from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +from ais_bench.benchmark.datasets.base import BaseJDGDataset +from ais_bench.benchmark.utils.file.file import load_jsonl +from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError +from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES +logger = AISLogger() + +class LMMImgJDGDataset(BaseJDGDataset): + def _load_from_predictions(self, prediction_path: str): + """从prediction中拿到对应图片相对路径,将这个路径的图片加载并转换为Base64字符串. + + Args: + prediction_path (str): The path to the prediction file. + + Returns: + Dataset: The merged dataset with predictions. + """ + if not os.path.exists(prediction_path): + return [] + + preds = load_jsonl(prediction_path) + base_path = os.path.dirname(prediction_path) + + # 定义图片处理函数 + def process_image(pred_item): + image_path = os.path.join(base_path, pred_item.get('prediction', '')) + if image_path and os.path.exists(image_path): + try: + # 加载图片 + with Image.open(image_path) as img: + # 转换为RGB格式 + img = img.convert('RGB') + # 保存到BytesIO + buffered = BytesIO() + img.save(buffered, format="PNG") + # 转换为Base64字符串 + img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + # 更新pred中的image字段为Base64字符串 + pred_item['prediction'] = img_base64 + except Exception as e: + raise AISBenchRuntimeError(DSET_CODES.UNKNOWN_ERROR, f"Failed to process image {image_path}: {e}") + return pred_item + + # 使用并行处理加速图片处理 + max_workers = min(8, os.cpu_count()) # 根据CPU核心数调整 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 使用tqdm显示进度 + processed_preds = list(tqdm( + executor.map(process_image, preds), + total=len(preds), + desc="Processing images", + unit="image" + )) + + processed_preds.sort(key=lambda x: x.get('id', 0)) + return processed_preds + + def _modify_dataset_item(self, dataset_item, pred_item): + for item in dataset_item["content"].split(AIS_CONTENT_TAG): + if item.startswith(AIS_TEXT_START): + question = item.replace(AIS_TEXT_START, "") + elif item.startswith(AIS_IMAGE_START): + org_image_url = item.replace(AIS_IMAGE_START, "") + dataset_item["content"] = AIS_TEXT_START + question + AIS_CONTENT_TAG \ + + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG \ + + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG \ No newline at end of file diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py index 467ed0b3..f1f881e2 100644 --- a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -3,7 +3,7 @@ with read_base(): from ais_bench.benchmark.configs.models.lmm_models.qwen_image_edit import models as qwen_image_edit_models from ais_bench.benchmark.configs.summarizers.example import summarizer - from ais_bench.benchmark.configs.datasets.gedit.gedit_gen import gedit_datasets + from ais_bench.benchmark.configs.datasets.gedit.gedit_gen_0_shot_llmjudge import gedit_datasets device_list = [0, 1, 2, 3] From c543451ccdaf04b841baf942b2666d3fa5f7717a Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 24 Feb 2026 09:33:57 +0800 Subject: [PATCH 10/59] fix judge worker bug --- ais_bench/benchmark/cli/workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 3aca0325..0fcbe6ab 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -116,7 +116,7 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): class JudgeInfer(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: def get_task_type() -> str: - if cfg["models"][0]["attr"] == "service": + if task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service": return get_config_type(OpenICLApiInferTask) else: return get_config_type(OpenICLInferTask) From 5724e5810022ea851232adaadfce5a14dae2d454 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 24 Feb 2026 09:40:26 +0800 Subject: [PATCH 11/59] fix judge worker bug --- ais_bench/benchmark/cli/workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 0fcbe6ab..9672c804 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -116,7 +116,7 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): class JudgeInfer(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: def get_task_type() -> str: - if task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service": + if cfg["datasets"][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service": return get_config_type(OpenICLApiInferTask) else: return get_config_type(OpenICLInferTask) From 77c8fc5f2bd77389049f2624d9abe956b89388e0 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 10:33:57 +0800 Subject: [PATCH 12/59] lmm eval fix --- .../gedit/gedit_gen_0_shot_llmjudge.py | 175 +++++++++++------- ais_bench/benchmark/datasets/g_edit.py | 11 +- .../benchmark/datasets/utils/lmm_judge.py | 75 +++++++- .../icl_inferencer/icl_lmm_gen_inferencer.py | 8 - .../gen_inferencer_output_handler.py | 3 + .../lmm_gen_inferencer_output_handler.py | 4 +- ais_bench/benchmark/utils/prompt/prompt.py | 8 +- 7 files changed, 196 insertions(+), 88 deletions(-) diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py index 5229f3d4..75081615 100644 --- a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py @@ -7,9 +7,56 @@ from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content from ais_bench.benchmark.datasets.g_edit import ( GEditDataset, - GEditJDGDataset, + GEditSCJDGDataset, + GEditPQJDGDataset, ) -from ais_bench.benchmark.datasets.utils.llm_judge import get_a_or_b, LLMJudgeCorrectEvaluator +from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list, LMMJudgeImageEditEvaluator + +SC_GRADER_TEMPLATE = """ +RULES: + +Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. +The objective is to evaluate how successfully the editing instruction has been executed in the second image. + +Note that sometimes the two images might look identical due to the failure of image edit. + +From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +Editing instruction: {question} +""".strip() + +PQ_GRADER_TEMPLATE = """ +RULES: + +The image is an AI-generated image. +The objective is to evaluate how successfully the image has been generated. + +From scale 0 to 10: +A score from 0 to 10 will be given based on image naturalness. +( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. +) +A second score from 0 to 10 will rate the image artifacts. +( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. +) +Put the score in a list such that output score = [naturalness, artifacts] +""".strip() + +JDG_DATASETS_CLASS_MAP = { + "SC": GEditSCJDGDataset, + "PQ": GEditPQJDGDataset, +} + +JDG_TEMPLATE_MAP = { + "SC": SC_GRADER_TEMPLATE, + "PQ": PQ_GRADER_TEMPLATE, +} gedit_reader_cfg = dict( @@ -34,77 +81,63 @@ inferencer=dict(type=LMMGenInferencer) ) -GRADER_TEMPLATE = """ -RULES: +gedit_datasets = [] -Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. -The objective is to evaluate how successfully the editing instruction has been executed in the second image. - -Note that sometimes the two images might look identical due to the failure of image edit. - -From scale 0 to 10: -A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) -A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) -Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. - -Editing instruction: {question} -""".strip() - -gedit_judge_infer_cfg = dict( - judge_reader_cfg = dict(input_columns=["question", "model_answer", "image"], output_column="model_pred_uuid"), - judge_model=dict( - attr="service", - type=VLLMCustomAPIChat, - abbr="judge", # Be added after dataset abbr - path="", - model="", - stream=True, - request_rate=0, - use_timestamp=False, - retry=2, - api_key="", - host_ip="192.168.9.123", - host_port=5103, - url="", - max_out_len=512, - batch_size=16, - trust_remote_code=False, - generation_kwargs=dict( - temperature=0.01, - ignore_eos=False, +for metric in ["SC", "PQ"]: + gedit_judge_infer_cfg = dict( + judge_reader_cfg = dict(input_columns=["question", "model_answer", "image"], output_column="model_pred_uuid"), + judge_model=dict( + attr="service", + type=VLLMCustomAPIChat, + abbr=f"{metric}_judge", # Be added after dataset abbr + path="", + model="", + stream=True, + request_rate=0, + use_timestamp=False, + retry=2, + api_key="", + host_ip="192.168.9.123", + host_port=5103, + url="", + max_out_len=512, + batch_size=16, + trust_remote_code=False, + generation_kwargs=dict( + temperature=0.01, + ignore_eos=False, + ), + pred_postprocessor=dict(type=extract_non_reasoning_content), ), - pred_postprocessor=dict(type=extract_non_reasoning_content), - ), - judge_dataset_type=GEditJDGDataset, - prompt_template=dict( - type=MMPromptTemplate, - template=dict( - round=[ - dict(role='HUMAN', prompt_mm={ - "text": {"type": "text", "text": GRADER_TEMPLATE}, - "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, # origin graph - "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{prediction}"}}, # edited graph - }) - ], + judge_dataset_type=JDG_DATASETS_CLASS_MAP[metric], + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + "text": {"type": "text", "text": JDG_TEMPLATE_MAP[metric]}, + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, + }) + ], + ), ), - ), - retriever=dict(type=ZeroRetriever), - inferencer=dict(type=GenInferencer), -) - -gedit_eval_cfg = dict( - evaluator=dict(type=LLMJudgeCorrectEvaluator), - pred_postprocessor=dict(type=get_a_or_b), -) + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), + ) -gedit_datasets = [ - dict( - abbr="gedit", - type=GEditDataset, - path="ais_bench/datasets/GEdit-Bench", - reader_cfg=gedit_reader_cfg, - infer_cfg=gedit_infer_cfg, - judge_infer_cfg=gedit_judge_infer_cfg, - eval_cfg=gedit_eval_cfg, + gedit_eval_cfg = dict( + evaluator=dict(type=LMMJudgeImageEditEvaluator, metric=metric), + pred_postprocessor=dict(type=get_lmm_point_list), ) -] + + gedit_datasets.append( + dict( + abbr=f"gedit", + type=GEditDataset, + path="ais_bench/datasets/GEdit-Bench", + reader_cfg=gedit_reader_cfg, + infer_cfg=gedit_infer_cfg, + judge_infer_cfg=gedit_judge_infer_cfg, + eval_cfg=gedit_eval_cfg, + ) + ) \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index 89951bdd..f762f654 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -5,7 +5,7 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.openicl import BaseEvaluator from ais_bench.benchmark.datasets.utils.datasets import get_data_path -from ais_bench.benchmark.datasets.utils.lmm_judge import LMMImgJDGDataset +from ais_bench.benchmark.datasets.utils.lmm_judge import ImgSCJDGDataset, ImgPQJDGDataset from ais_bench.benchmark.utils.image_process import pil_to_base64 from PIL import Image from tqdm import tqdm @@ -88,7 +88,14 @@ def process_example_to_dataset(example): # 合并所有 Dataset return concatenate_datasets(processed_datasets) + +@LOAD_DATASET.register_module() +class GEditSCJDGDataset(ImgSCJDGDataset): + def _get_dataset_class(self): + return GEditDataset + + @LOAD_DATASET.register_module() -class GEditJDGDataset(LMMImgJDGDataset): +class GEditPQJDGDataset(ImgPQJDGDataset): def _get_dataset_class(self): return GEditDataset \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index 0e910fb0..8cd32282 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -1,5 +1,6 @@ import re import os +import json import base64 import concurrent.futures from io import BytesIO @@ -18,6 +19,14 @@ from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES logger = AISLogger() + +@TEXT_POSTPROCESSORS.register_module("get_a_or_b") +def get_lmm_point_list(pred: str) -> str: + """从模型回复中提取列表的字符串""" + match = re.search(r'\[.*?\]', pred) + return match.group(0) if match else '' + + class LMMImgJDGDataset(BaseJDGDataset): def _load_from_predictions(self, prediction_path: str): """从prediction中拿到对应图片相对路径,将这个路径的图片加载并转换为Base64字符串. @@ -74,6 +83,70 @@ def _modify_dataset_item(self, dataset_item, pred_item): question = item.replace(AIS_TEXT_START, "") elif item.startswith(AIS_IMAGE_START): org_image_url = item.replace(AIS_IMAGE_START, "") + self.logger.debug(f"org_image_url: {org_image_url[:64]} \n pred_image_url: {pred_item['prediction'][:64]}") dataset_item["content"] = AIS_TEXT_START + question + AIS_CONTENT_TAG \ + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG \ - + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG \ No newline at end of file + + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG + + +class ImgSCJDGDataset(LMMImgJDGDataset): + def _modify_dataset_item(self, dataset_item, pred_item): + for item in dataset_item["content"].split(AIS_CONTENT_TAG): + if item.startswith(AIS_TEXT_START): + question = item.replace(AIS_TEXT_START, "") + elif item.startswith(AIS_IMAGE_START): + org_image_url = item.replace(AIS_IMAGE_START, "") + self.logger.debug(f"org_image_url: {org_image_url[:64]} \n pred_image_url: {pred_item['prediction'][:64]}") + dataset_item["content"] = AIS_TEXT_START + question + AIS_CONTENT_TAG \ + + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG \ + + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG + + +class ImgPQJDGDataset(LMMImgJDGDataset): + def _modify_dataset_item(self, dataset_item, pred_item): + for item in dataset_item["content"].split(AIS_CONTENT_TAG): + if item.startswith(AIS_TEXT_START): + question = item.replace(AIS_TEXT_START, "") + elif item.startswith(AIS_IMAGE_START): + org_image_url = item.replace(AIS_IMAGE_START, "") + self.logger.debug(f"org_image_url: {org_image_url[:64]} \n pred_image_url: {pred_item['prediction'][:64]}") + dataset_item["content"] = AIS_TEXT_START + question + AIS_CONTENT_TAG \ + + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG + + +POINT_KEY_LIST_MAP = { + "SC": ["editing success", "over editing"], + "PQ": ["naturalness", "artifacts"] +} + + +@ICL_EVALUATORS.register_module() +class LMMJudgeImageEditEvaluator(BaseEvaluator): + def __init__(self, metric: str = "SC"): + self.metric = metric + self.point_key_list = POINT_KEY_LIST_MAP[metric] + super().__init__() + + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} + + # 将get_lmm_point_list获取的字符串格式的list转换成list格式 + if not all(isinstance(pred, str) for pred in predictions): + return {'error': 'preds must be strings'} + predictions = [json.loads(get_lmm_point_list(pred)) for pred in predictions] + + total_score = 0 + count = 0 + details = [] + for pred, ref in zip(predictions, references): + if len(pred) != len(self.point_key_list): + detail = {'pred': pred, 'org_uuid': ref, 'eval_success': False, 'failed reason': 'prediction format error, length of point list not equal to point key list'} + else: + detail = {'pred': {key: score for key, score in zip(self.point_key_list, pred)}, 'org_uuid': ref, 'eval_success': True} + count += 1 + if detail['eval_success']: + total_score += min(pred) + details.append(detail) + result = {f"{self.metric}": total_score / count, 'details': details} + return result \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py index 48604ba9..2da7ddc8 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py @@ -1,11 +1,3 @@ -''' -Author: SJTUyh yh_silence@alumni.sjtu.edu.cn -Date: 2026-02-11 16:38:01 -LastEditors: SJTUyh yh_silence@alumni.sjtu.edu.cn -LastEditTime: 2026-02-12 18:39:02 -FilePath: \benchmark\ais_bench\benchmark\openicl\icl_inferencer\icl_lmm_gen_inferencer.py -Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE -''' import uuid from typing import List, Optional diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 726f841c..705340c8 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -41,6 +41,9 @@ def get_prediction_result( Returns: dict: Prediction result """ + for item in input[0]['prompt']: + if item.get('image_url') and len(item['image_url']['url']) > 256: + item['image_url']['url'] = item['image_url']['url'][:256] + " ..." result_data = { "success": ( output.success if isinstance(output, Output) else True diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py index 70a05336..cdc4a21b 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -48,8 +48,8 @@ def get_prediction_result( if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) for item in input[0]['prompt']: - if item.get('image_url'): - item['image_url']['url'] = item['image_url']['url'][:256] + if item.get('image_url') and len(item['image_url']['url']) > 256: + item['image_url']['url'] = item['image_url']['url'][:256] + " ..." result_data = { "success": ( output.success if isinstance(output, LMMOutput) else True diff --git a/ais_bench/benchmark/utils/prompt/prompt.py b/ais_bench/benchmark/utils/prompt/prompt.py index 6f0d1aa2..67351a26 100644 --- a/ais_bench/benchmark/utils/prompt/prompt.py +++ b/ais_bench/benchmark/utils/prompt/prompt.py @@ -157,13 +157,13 @@ def format_mm(self, **kwargs) -> PromptList: if item.startswith(AIS_TEXT_START): question = item.replace(AIS_TEXT_START, "") question = {"question": question} - text_content = mm['text'].copy() + text_content = mm['text'].deepcopy() text_content['text'] = safe_format(text_content['text'], **question) contents.append(text_content) elif item.startswith(AIS_IMAGE_START): image = item.replace(AIS_IMAGE_START, "") image = {"image": image} - image_content = mm['image'].copy() + image_content = mm['image'].deepcopy() if isinstance(image_content['image_url'], dict): image_content['image_url']['url'] = safe_format(image_content['image_url']['url'], **image) else: @@ -172,7 +172,7 @@ def format_mm(self, **kwargs) -> PromptList: elif item.startswith(AIS_VIDEO_START): video = item.replace(AIS_VIDEO_START, "") video = {"video": video} - video_content = mm['video'].copy() + video_content = mm['video'].deepcopy() if isinstance(video_content['video_url'], dict): video_content['video_url']['url'] = safe_format(video_content['video_url']['url'], **video) else: @@ -181,7 +181,7 @@ def format_mm(self, **kwargs) -> PromptList: elif item.startswith(AIS_AUDIO_START): audio = item.replace(AIS_AUDIO_START, "") audio = {"audio": audio} - audio_content = mm['audio'].copy() + audio_content = mm['audio'].deepcopy() if isinstance(audio_content['audio_url'], dict): audio_content['audio_url']['url'] = safe_format(audio_content['audio_url']['url'], **audio) else: From 68a15968793cd7b5a383a647e06a0fbb9cbb124c Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 10:49:26 +0800 Subject: [PATCH 13/59] support multi judge dataset tasks --- ais_bench/benchmark/cli/workers.py | 37 ++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 9672c804..cebe2588 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -143,6 +143,7 @@ def get_task_type() -> str: def do_work(self, cfg: ConfigDict): partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner) logger.info("Starting inference tasks...") + self._cfg_pre_process(cfg) tasks = partitioner(cfg) # delete the tasks without judge_infer_cfg @@ -190,6 +191,16 @@ def _merge_datasets(self, tasks): new_tasks.append(new_task) return new_tasks + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + self.org_dataset_abbrs = {} + for i, dataset in enumerate(cfg.datasets): + if dataset.get("judge_infer_cfg"): + org_dataset_abbr = cfg.datasets[i]["abbr"] + new_dataset_abbr = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' + cfg.datasets[i]["abbr"] = new_dataset_abbr + self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + return cfg + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # update parameters to correct sub cfg if hasattr(cfg, "attack"): @@ -199,10 +210,9 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # update judge cfgs to model cfgs and data for task in tasks: - task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') if not osp.exists(task["datasets"][0][0]["predictions_path"]): raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' model_abbr = task["models"][0]["abbr"] task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") task["models"][0]["abbr"] = model_abbr @@ -215,7 +225,7 @@ def _result_post_process(self, tasks, cfg: ConfigDict): for task in tasks: model_org_prediction_path = task["datasets"][0][0]["predictions_path"] model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)} - judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') judge_preds: list = load_jsonl(judge_org_prediction_path) for i, pred in enumerate(judge_preds): uuid = pred["gold"] @@ -252,6 +262,8 @@ def update_cfg(self, cfg: ConfigDict) -> None: def do_work(self, cfg: ConfigDict): partitioner = PARTITIONERS.build(cfg.eval.partitioner) logger.info("Starting evaluation tasks...") + self._cfg_pre_process(cfg) + tasks = partitioner(cfg) # Update tasks cfg before run @@ -267,24 +279,29 @@ def do_work(self, cfg: ConfigDict): self._result_post_process(tasks, cfg) logger.info("Evaluation tasks completed.") + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + self.org_dataset_abbrs = {} + for i, dataset in enumerate(cfg.datasets): + if dataset.get("judge_infer_cfg"): + org_dataset_abbr = cfg.datasets[i]["abbr"] + new_dataset_abbr = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' + cfg.datasets[i]["abbr"] = new_dataset_abbr + self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + return cfg + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # Replace default model config to judge model config - self.judge_result_paths = {} for task in tasks: if task["datasets"][0][0].get("judge_infer_cfg"): - new_dataset_abbr = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' - org_dataset_abbr = task["datasets"][0][0]["abbr"] - self.judge_result_paths[new_dataset_abbr] = org_dataset_abbr - task["datasets"][0][0]["abbr"] = new_dataset_abbr task["datasets"][0][0].pop("judge_infer_cfg") def _result_post_process(self, tasks, cfg: ConfigDict): # Copy judge infer result to normal name for task in tasks: - if task["datasets"][0][0]["abbr"] in self.judge_result_paths.keys(): + if task["datasets"][0][0]["abbr"] in self.org_dataset_abbrs.keys(): cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') - final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.judge_result_paths[task["datasets"][0][0]["abbr"]]}.jsonl') + final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') if os.path.exists(final_org_results_path): os.remove(final_org_results_path) From 04fa57c141968b895f332c6cfde06c0355bf9473 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 11:19:01 +0800 Subject: [PATCH 14/59] support multi judge dataset tasks --- ais_bench/benchmark/cli/workers.py | 32 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index cebe2588..8dd497f3 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -193,12 +193,18 @@ def _merge_datasets(self, tasks): def _cfg_pre_process(self, cfg: ConfigDict) -> None: self.org_dataset_abbrs = {} - for i, dataset in enumerate(cfg.datasets): - if dataset.get("judge_infer_cfg"): - org_dataset_abbr = cfg.datasets[i]["abbr"] - new_dataset_abbr = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' - cfg.datasets[i]["abbr"] = new_dataset_abbr + def change_judge_dataset_abbr(item): + if item.get("judge_infer_cfg"): + org_dataset_abbr = item["abbr"] + new_dataset_abbr = f'{item["abbr"]}-{item["judge_infer_cfg"]["judge_model"]["abbr"]}' + item["abbr"] = new_dataset_abbr self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + if cfg.get('model_dataset_combinations', None) is not None: + for item in cfg.model_dataset_combinations: + for dataset in item["datasets"]: + change_judge_dataset_abbr(dataset) + for dataset in cfg.datasets: + change_judge_dataset_abbr(dataset) return cfg def _update_tasks_cfg(self, tasks, cfg: ConfigDict): @@ -281,12 +287,18 @@ def do_work(self, cfg: ConfigDict): def _cfg_pre_process(self, cfg: ConfigDict) -> None: self.org_dataset_abbrs = {} - for i, dataset in enumerate(cfg.datasets): - if dataset.get("judge_infer_cfg"): - org_dataset_abbr = cfg.datasets[i]["abbr"] - new_dataset_abbr = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' - cfg.datasets[i]["abbr"] = new_dataset_abbr + def change_eval_dataset_abbr(item): + if item.get("judge_infer_cfg"): + org_dataset_abbr = item["abbr"] + new_dataset_abbr = f'{item["abbr"]}-{item["judge_infer_cfg"]["judge_model"]["abbr"]}' + item["abbr"] = new_dataset_abbr self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + if cfg.get('model_dataset_combinations', None) is not None: + for item in cfg.model_dataset_combinations: + for dataset in item["datasets"]: + change_eval_dataset_abbr(dataset) + for dataset in cfg.datasets: + change_eval_dataset_abbr(dataset) return cfg def _update_tasks_cfg(self, tasks, cfg: ConfigDict): From 7355dcb3c40bc3a56a946e2b5e246a23ee7e7200 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 11:27:39 +0800 Subject: [PATCH 15/59] fix custom config --- .../multi_device_run_qwen_image_edit.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py index f1f881e2..33431000 100644 --- a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -5,7 +5,7 @@ from ais_bench.benchmark.configs.summarizers.example import summarizer from ais_bench.benchmark.configs.datasets.gedit.gedit_gen_0_shot_llmjudge import gedit_datasets -device_list = [0, 1, 2, 3] +device_list = [0] # [0, 1, 2, 3] datasets = [] models = [] @@ -18,14 +18,17 @@ model_config['device_kwargs']['device_id'] = i models.append(model_config) - dataset_config = {k: v for k, v in gedit_datasets[0].items()} - dataset_config['abbr'] = f"{dataset_config['abbr']}-{i}" - dataset_config['split_count'] = len(device_list) - dataset_config['split_index'] = i - datasets.append(dataset_config) + dataset_configs = [] + for dataset in gedit_datasets: + dataset_config = {k: v for k, v in dataset.items()} + dataset_config['abbr'] = f"{dataset_config['abbr']}-{i}" + dataset_config['split_count'] = len(device_list) + dataset_config['split_index'] = i + dataset_configs.append(dataset_config) + datasets.extend(dataset_configs) # 关键:为每个设备创建一个独立的 model-dataset 组合 model_dataset_combinations.append({ 'models': [model_config], # 只包含当前模型 - 'datasets': [dataset_config] # 只包含当前数据集 + 'datasets': dataset_configs # 只包含当前数据集 }) \ No newline at end of file From 963a3b4e5a79e497ee78c91bff729314c0da1b01 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 11:55:05 +0800 Subject: [PATCH 16/59] support multi judge dataset tasks --- ais_bench/benchmark/cli/workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 8dd497f3..1299d9fa 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -231,7 +231,7 @@ def _result_post_process(self, tasks, cfg: ConfigDict): for task in tasks: model_org_prediction_path = task["datasets"][0][0]["predictions_path"] model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)} - judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') + judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') judge_preds: list = load_jsonl(judge_org_prediction_path) for i, pred in enumerate(judge_preds): uuid = pred["gold"] From f42cdb7a15c8636f753dd27dc3a25da0a1b203cf Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 12:01:30 +0800 Subject: [PATCH 17/59] support multi judge dataset tasks --- ais_bench/benchmark/cli/workers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 1299d9fa..a64ed52b 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -426,4 +426,5 @@ def __init__(self, cfg, workflow) -> None: def execute(self) -> None: for worker in self.workflow: - worker.do_work(self.cfg) + cfg = copy.deepcopy(self.cfg) + worker.do_work(cfg) From 9d26569dfbb84594a0e86ae792848ffe8f86600b Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 15:10:44 +0800 Subject: [PATCH 18/59] judge fix --- ais_bench/benchmark/datasets/utils/lmm_judge.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index 8cd32282..ce3c3f91 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -77,17 +77,6 @@ def process_image(pred_item): processed_preds.sort(key=lambda x: x.get('id', 0)) return processed_preds - def _modify_dataset_item(self, dataset_item, pred_item): - for item in dataset_item["content"].split(AIS_CONTENT_TAG): - if item.startswith(AIS_TEXT_START): - question = item.replace(AIS_TEXT_START, "") - elif item.startswith(AIS_IMAGE_START): - org_image_url = item.replace(AIS_IMAGE_START, "") - self.logger.debug(f"org_image_url: {org_image_url[:64]} \n pred_image_url: {pred_item['prediction'][:64]}") - dataset_item["content"] = AIS_TEXT_START + question + AIS_CONTENT_TAG \ - + AIS_IMAGE_START + org_image_url + AIS_CONTENT_TAG \ - + AIS_IMAGE_START + pred_item['prediction'] + AIS_CONTENT_TAG - class ImgSCJDGDataset(LMMImgJDGDataset): def _modify_dataset_item(self, dataset_item, pred_item): From 5c92fd79462fccadc54feef04e80270ef7665ca5 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 15:36:15 +0800 Subject: [PATCH 19/59] asnyc process predictions --- ais_bench/benchmark/datasets/base.py | 44 +++++++++++++++++++++------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 57a34359..d063d79a 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,5 +1,7 @@ from abc import abstractmethod from typing import List, Dict, Optional, Union, Type +from concurrent.futures import ThreadPoolExecutor, as_completed +from tqdm import tqdm from datasets import Dataset, DatasetDict from datasets.utils.logging import disable_progress_bar @@ -129,19 +131,33 @@ def load(self, predictions_path: str, **kwargs): # 为数据集添加 model_answer 列 if isinstance(dataset_content, Dataset): dataset_list = [] - for item in predictions: - item_dict = dataset_content[int(item["id"])] - self._modify_dataset_item(item_dict, item) - item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold - dataset_list.append(item_dict) + with ThreadPoolExecutor() as executor: + futures = [] + for item in predictions: + future = executor.submit(self._process_single_item, dataset_content, item) + futures.append(future) + + with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: + for future in as_completed(futures): + result = future.result() + dataset_list.append(result) + pbar.update(1) + pbar.refresh() elif isinstance(dataset_content, DatasetDict): dataset_list = [] - for key in dataset_content: - for item in predictions: - item_dict = dataset_content[key][int(item["id"])] - self._modify_dataset_item(item_dict, item) - item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold - dataset_list.append(item_dict) + with ThreadPoolExecutor() as executor: + futures = [] + for key in dataset_content: + for item in predictions: + future = executor.submit(self._process_single_item, dataset_content[key], item) + futures.append(future) + + with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: + for future in as_completed(futures): + result = future.result() + dataset_list.append(result) + pbar.update(1) + pbar.refresh() else: raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") @@ -158,6 +174,12 @@ def _get_dataset_class(self): def _modify_dataset_item(self, dataset_item, pred_item): dataset_item["model_answer"] = pred_item["prediction"] + def _process_single_item(self, dataset_content, pred_item): + item_dict = dataset_content[int(pred_item["id"])] + self._modify_dataset_item(item_dict, pred_item) + item_dict["model_pred_uuid"] = pred_item["uuid"] + return item_dict + def _init_org_datasets_instance( self, reader_cfg: Optional[Dict] = {}, From dcfc50e19273c2ee826812760bfb006d979bdd5b Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 17:04:25 +0800 Subject: [PATCH 20/59] fast trans to dataset --- ais_bench/benchmark/datasets/base.py | 35 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index d063d79a..6e7704af 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -129,8 +129,11 @@ def load(self, predictions_path: str, **kwargs): predictions: list = self._load_from_predictions(predictions_path) # 为数据集添加 model_answer 列 + batch_size = 10 # 批处理大小,可以根据实际情况调整 + dataset_batches = [] + current_batch = [] + if isinstance(dataset_content, Dataset): - dataset_list = [] with ThreadPoolExecutor() as executor: futures = [] for item in predictions: @@ -140,11 +143,16 @@ def load(self, predictions_path: str, **kwargs): with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: for future in as_completed(futures): result = future.result() - dataset_list.append(result) + current_batch.append(result) + + # 当批次达到指定大小时,转换为Dataset并添加到批次列表 + if len(current_batch) >= batch_size: + dataset_batches.append(Dataset.from_list(current_batch)) + current_batch = [] + pbar.update(1) pbar.refresh() elif isinstance(dataset_content, DatasetDict): - dataset_list = [] with ThreadPoolExecutor() as executor: futures = [] for key in dataset_content: @@ -155,13 +163,30 @@ def load(self, predictions_path: str, **kwargs): with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: for future in as_completed(futures): result = future.result() - dataset_list.append(result) + current_batch.append(result) + + # 当批次达到指定大小时,转换为Dataset并添加到批次列表 + if len(current_batch) >= batch_size: + dataset_batches.append(Dataset.from_list(current_batch)) + current_batch = [] + pbar.update(1) pbar.refresh() else: raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") - return Dataset.from_list(dataset_list) + # 处理最后一个不完整的批次 + if current_batch: + dataset_batches.append(Dataset.from_list(current_batch)) + + # 合并所有批次的Dataset + if dataset_batches: + if len(dataset_batches) == 1: + return dataset_batches[0] + else: + return Dataset.concatenate_datasets(dataset_batches) + else: + return Dataset.from_list([]) @abstractmethod def _load_from_predictions(self, prediction_path: str) -> Dict: From 880bc8cb85d0b346b99aaa4e5bf3f21c9879e7cb Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 25 Feb 2026 17:28:31 +0800 Subject: [PATCH 21/59] fast trans to dataset --- ais_bench/benchmark/datasets/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 6e7704af..a22f99e8 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -184,7 +184,8 @@ def load(self, predictions_path: str, **kwargs): if len(dataset_batches) == 1: return dataset_batches[0] else: - return Dataset.concatenate_datasets(dataset_batches) + from datasets import concatenate_datasets + return concatenate_datasets(dataset_batches) else: return Dataset.from_list([]) From 6d13d2a3be85ccde9aa7984c4ce3383bbedac42e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 09:52:09 +0800 Subject: [PATCH 22/59] add a gedit display tool --- ais_bench/benchmark/cli/config_manager.py | 3 +- .../benchmark/datasets/utils/lmm_judge.py | 4 +- .../gedit/display_results.py | 174 ++++++++++++++++++ 3 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 tools/dataset_processors/gedit/display_results.py diff --git a/ais_bench/benchmark/cli/config_manager.py b/ais_bench/benchmark/cli/config_manager.py index f7c3ab5f..dde76527 100644 --- a/ais_bench/benchmark/cli/config_manager.py +++ b/ais_bench/benchmark/cli/config_manager.py @@ -1,4 +1,3 @@ - import os import os.path as osp import tabulate @@ -104,7 +103,7 @@ def load_config(self, workflow): self._update_cfg_of_workflow(workflow) self._dump_and_reload_config() return self.cfg - + def _fill_dataset_configs(self): for dataset_cfg in self.cfg["datasets"]: fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg) diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index ce3c3f91..e7d88602 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -20,10 +20,10 @@ logger = AISLogger() -@TEXT_POSTPROCESSORS.register_module("get_a_or_b") +@TEXT_POSTPROCESSORS.register_module("get_lmm_point_list") def get_lmm_point_list(pred: str) -> str: """从模型回复中提取列表的字符串""" - match = re.search(r'\[.*?\]', pred) + match = re.search(r'\[\s*\d+(?:\s*,\s*\d+)*\s*\]', pred) return match.group(0) if match else '' diff --git a/tools/dataset_processors/gedit/display_results.py b/tools/dataset_processors/gedit/display_results.py new file mode 100644 index 00000000..4c5d7388 --- /dev/null +++ b/tools/dataset_processors/gedit/display_results.py @@ -0,0 +1,174 @@ +import os +import math +import copy +import argparse +import json +import csv +import tabulate + +from ais_bench.benchmark.configs.datasets.needlebench_v2.needlebench_v2_4k.needlebench_v2_multi_reasoning_4k import language +from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError +from ais_bench.benchmark.cli.config_manager import CustomConfigChecker +from ais_bench.benchmark.utils.logging.logger import AISLogger +from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list +from mmengine.config import Config + +logger = AISLogger(__name__) + + +def load_config(config_path: str) -> Config: + """加载配置文件并进行校验""" + if not os.path.exists(config_path): + raise ParameterValueError(f"Config path: {config_path} is not exist!") + try: + config = Config.fromfile(config_path, format_python_code=False) + except BaseException as e: + raise RuntimeError(f"Fail to load config {config_path}, failed reason: {e}") + CustomConfigChecker(config, config_path).check() + return config + +class GEditEvalResultParser: + def __init__(self, args): + self.config = load_config(args.config_path) + self.output_dir = args.timestamp_path + self.paths_map = dict( + org_pred_path = [], + sc_judge_pred_path = [], + pq_judge_pred_path = [], + ) + for comb in self.config["model_dataset_combinations"]: + model_abbr = comb["models"][0]["abbr"] + dataset_org_abbr = comb["datasets"][0]["abbr"] + dataset_sc_abbr = f'{comb["datasets"][0]["abbr"]}-{comb["datasets"][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' + dataset_pq_abbr = f'{comb["datasets"][1]["abbr"]}-{comb["datasets"][1]["judge_infer_cfg"]["judge_model"]["abbr"]}' + self.paths_map["org_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_org_abbr}.jsonl")) + self.paths_map["sc_judge_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_sc_abbr}.jsonl")) + self.paths_map["pq_judge_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_pq_abbr}.jsonl")) + + def parse_results(self): + logger.info(f"Start parse judge infer result from: {self.output_dir}") + org_pred_data_list = self._load_and_merge_jsonl("org_pred_path") + org_pred_data_dict = {item["uuid"]: item for item in org_pred_data_list} + sc_judge_pred_data_list = self._load_and_merge_jsonl("sc_judge_pred_path") + sc_judge_pred_data_dict = {item["gold"]: item for item in sc_judge_pred_data_list} + pq_judge_pred_data_list = self._load_and_merge_jsonl("pq_judge_pred_path") + pq_judge_pred_data_dict = {item["gold"]: item for item in pq_judge_pred_data_list} + + self.all_data_results = {} + + for uuid in org_pred_data_dict.keys(): + id = org_pred_data_dict[uuid]["id"] + output_img_path = org_pred_data_dict[uuid]["prediction"] + sc_point = self._calc_meta_point(sc_judge_pred_data_dict[uuid]["prediction"]) + pq_point = self._calc_meta_point(pq_judge_pred_data_dict[uuid]["prediction"]) + o_point = math.sqrt(sc_point * pq_point) + question, case_language = self._get_question_and_language(org_pred_data_dict[uuid]["origin_prompt"][0]["prompt"]) + self.all_data_results[id] = { + "uuid": uuid, + "question": question, + "language": case_language, + "output_img_path": output_img_path, + "SC_point": sc_point, + "PQ_point": pq_point, + "O_point": o_point, + } + logger.info(f"Finish parsing results") + + def dump_result_csv(self): + save_path = os.path.join(self.output_dir, "results") + logger.info(f"Start dumping details ......") + if not os.path.exists(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, "gedit_gathered_result.csv"), "w", encoding="utf-8", newline="") as f: + fieldnames = ["id", "uuid", "question", "language", "output_img_path", "SC_point", "PQ_point", "O_point"] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for id in self.all_data_results.keys(): + writer.writerow(self.all_data_results[id]) + logger.info(f"Finish dumping csv to: {os.path.join(save_path, 'gedit_gathered_result.csv')}") + + def display_results(self): + evaluate_result_list = [] + + lang_count = {"zh": 0, "en": 0} + for id in self.all_data_results.keys(): + lang_count[self.all_data_results[id]["language"]] += 1 + + for lang in ["zh", "en"]: + sc_point_sum = 0 + pq_point_sum = 0 + o_point_sum = 0 + count = 0 + for id in self.all_data_results.keys(): + if self.all_data_results[id]["language"] == lang: + sc_point_sum += self.all_data_results[id]["SC_point"] + pq_point_sum += self.all_data_results[id]["PQ_point"] + o_point_sum += self.all_data_results[id]["O_point"] + count += 1 + if count > 0: + evaluate_result_list.append(copy.deepcopy([lang, sc_point_sum / count, pq_point_sum / count, o_point_sum / count])) + + sc_point_sum = 0 + pq_point_sum = 0 + o_point_sum = 0 + count = len(self.all_data_results) + for id in self.all_data_results.keys(): + sc_point_sum += self.all_data_results[id]["SC_point"] + pq_point_sum += self.all_data_results[id]["PQ_point"] + o_point_sum += self.all_data_results[id]["O_point"] + evaluate_result_list.append(copy.deepcopy(["all case", sc_point_sum / count, pq_point_sum / count, o_point_sum / count])) + + print(tabulate.tabulate(evaluate_result_list, + headers=["language", "SC_point", "PQ_point", "O_point"], + floatfmt=".4f")) + + def _calc_meta_point(self, pred): + results_list = get_lmm_point_list(pred) + if not results_list: + return 0 + try: + point_list = json.loads(results_list) + except BaseException as e: + raise RuntimeError(f"Illegal prediction: {pred}") + return min(point_list) + + def _get_question_and_language(self, input_prompt): + question = "" + for item in input_prompt: + if item.get("type") == "text": + question = item.get("text") + # 判断question是否包含中文 + if any("\u4e00" <= char <= "\u9fff" for char in question): + return question, "zh" + else: + return question, "en" + + def _load_and_merge_jsonl(self, path_kind: "org_pred_path"): + merged_data = [] + start_index = 0 + for path in self.paths_map[path_kind]: + offset_index = copy.deepcopy(start_index) + with open(path, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + data["id"] = data["id"] + offset_index + start_index += 1 + merged_data.append(data) + return merged_data + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="显示gedit数据集的推理结果") + parser.add_argument("--config_path", help="配置文件路径") + parser.add_argument("--timestamp_path", help="结果时间戳路径") + + args = parser.parse_args() + eval_parser = GEditEvalResultParser(args) + eval_parser.parse_results() + eval_parser.dump_result_csv() + eval_parser.display_results() + + +if __name__ == "__main__": + main() From 6ecd383731a63c89cfaf16eeeb7a0cb79207439d Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 10:17:17 +0800 Subject: [PATCH 23/59] add task_state_manager to base dataset --- ais_bench/benchmark/datasets/base.py | 7 +++++++ ais_bench/benchmark/tasks/openicl_api_infer.py | 1 + ais_bench/benchmark/tasks/openicl_eval.py | 7 ++++--- ais_bench/benchmark/tasks/openicl_infer.py | 1 + 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index a22f99e8..dbf77b96 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -10,6 +10,7 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError +from ais_bench.benchmark.tasks.base import TaskStateManager disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -19,9 +20,13 @@ def __init__(self, reader_cfg: Optional[Dict] = {}, k: Union[int, List[int]] = 1, n: int = 1, + task_state_manager: Optional[TaskStateManager] = None, **kwargs): # Validate k and n parameters self.logger = AISLogger() + if task_state_manager is not None: + self._init_task_state_manager(task_state_manager) + max_k = max(k) if isinstance(k, List) else k if max_k > n: raise ParameterValueError( @@ -37,6 +42,8 @@ def __init__(self, self._init_reader(**reader_cfg) self.repeated_dataset(self.abbr, n) # this process will update self.dataset and self.reader.dataset + def _init_task_state_manager(self, task_state_manager): + self.task_state_manager = task_state_manager def _init_reader(self, **kwargs): self.reader = DatasetReader(self.dataset, **kwargs) diff --git a/ais_bench/benchmark/tasks/openicl_api_infer.py b/ais_bench/benchmark/tasks/openicl_api_infer.py index 442f0632..b38f88c8 100644 --- a/ais_bench/benchmark/tasks/openicl_api_infer.py +++ b/ais_bench/benchmark/tasks/openicl_api_infer.py @@ -164,6 +164,7 @@ def _get_data_list(self) -> tuple[List, List]: data_abbr = dataset_cfg["abbr"] cur_data_cache = finish_cache_data.get(data_abbr, {}) infer_cfg = dataset_cfg["infer_cfg"] + dataset_cfg["task_state_manager"] = self.task_state_manager dataset = build_dataset_from_cfg(dataset_cfg) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset diff --git a/ais_bench/benchmark/tasks/openicl_eval.py b/ais_bench/benchmark/tasks/openicl_eval.py index 6cbb3a3c..48d9bd42 100644 --- a/ais_bench/benchmark/tasks/openicl_eval.py +++ b/ais_bench/benchmark/tasks/openicl_eval.py @@ -71,7 +71,8 @@ def get_command(self, cfg_path, template): command = f'{python} {script_path} {cfg_path}' return template.format(task_cmd=command) - def run(self): + def run(self, task_state_manager: TaskStateManager): + self.task_state_manager = task_state_manager for dataset_cfg in self.dataset_cfgs: self.dataset_cfg = dataset_cfg # Load Dataset @@ -111,7 +112,7 @@ def _score(self): "k":k, "n":n }) - + self.dataset_cfg["task_state_manager"] = self.task_state_manager test_set = build_dataset_from_cfg(self.dataset_cfg).test # Postprocess dataset if necessary if 'dataset_postprocessor' in self.eval_cfg: @@ -515,7 +516,7 @@ def parse_args(): start_time = time.perf_counter() try: evaluator: OpenICLEvalTask = OpenICLEvalTask(cfg) - evaluator.run() + evaluator.run(task_state_manager) except Exception as e: task_state_manager.update_task_state({"status": "error"}) raise e diff --git a/ais_bench/benchmark/tasks/openicl_infer.py b/ais_bench/benchmark/tasks/openicl_infer.py index 5d20d67e..93a52a57 100644 --- a/ais_bench/benchmark/tasks/openicl_infer.py +++ b/ais_bench/benchmark/tasks/openicl_infer.py @@ -123,6 +123,7 @@ def _inference(self): retrievers = [] for dataset_cfg in self.dataset_cfgs: infer_cfg = dataset_cfg["infer_cfg"] + dataset_cfg["task_state_manager"] = self.task_state_manager dataset = build_dataset_from_cfg(dataset_cfg) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset From bcb61ec149a8e3d2b17463a7ec1ff3bdc471b119 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 10:55:09 +0800 Subject: [PATCH 24/59] add task_state_manager to base dataset --- ais_bench/benchmark/datasets/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index dbf77b96..9bddfb06 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from pickle import DICT from typing import List, Dict, Optional, Union, Type from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -48,6 +49,11 @@ def _init_task_state_manager(self, task_state_manager): def _init_reader(self, **kwargs): self.reader = DatasetReader(self.dataset, **kwargs) + def update_task_state(self, state: Dict): + if self.task_state_manager is not None: + self.task_state_manager.update(state) + else: + self.logger.warning("Task state manager is not initialized, cannot update task state") def repeated_dataset(self, abbr: str, n: int): # Create repeated indices in batches to avoid generating an oversized index list at once From fcee08ffb70d76b4c3018415445b58948cc7892f Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 11:45:16 +0800 Subject: [PATCH 25/59] add process bar to task state manager --- ais_bench/benchmark/datasets/g_edit.py | 21 ++++++++++++++++++- .../benchmark/datasets/utils/lmm_judge.py | 16 ++++++++++---- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index f762f654..b4413c03 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -31,6 +31,11 @@ class GEditDataset(BaseDataset): @staticmethod def load(path, use_raw=False, split_count=1, split_index=0, **kwargs): path = get_data_path(path) + self.update_task_state( + { + "state": "loading dataset", + } + ) dataset = load_from_disk(path) # 数据集切分:分成 split_count 份,取第 split_index 份 @@ -71,17 +76,31 @@ def process_example_to_dataset(example): with ThreadPoolExecutor(max_workers=max_workers) as executor: # 提交所有任务 - with tqdm(total=len(dataset), desc=f"Submitting tasks split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar: + with tqdm(total=len(dataset), desc=f"Convert GEdit dataset to base64, split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar: futures = {} for i, example in enumerate(dataset): future = executor.submit(process_example_to_dataset, example) futures[future] = i + self.update_task_state( + { + "total_count": len(dataset), + "progress_description": f"Convert GEdit dataset to base64", + "finish_count": i, + } + ) submit_pbar.update(1) # 收集处理完成的 Dataset with tqdm(total=len(dataset), desc="Processing GEdit dataset", unit="example") as pbar: for future in as_completed(futures): idx = futures[future] + self.update_task_state( + { + "total_count": len(dataset), + "progress_description": f"Processing GEdit dataset", + "finish_count": idx, + } + ) processed_datasets[idx] = future.result() pbar.update(1) diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index e7d88602..8fd0967e 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -44,7 +44,15 @@ def _load_from_predictions(self, prediction_path: str): base_path = os.path.dirname(prediction_path) # 定义图片处理函数 - def process_image(pred_item): + def process_image(index, pred_item): + # 现在可以使用index来知道pred_item是preds中的第几个 + self.update_task_state( + { + "total_count": len(preds), + "progress_description": f"Convert prediction images to base64", + "finish_count": index, + } + ) image_path = os.path.join(base_path, pred_item.get('prediction', '')) if image_path and os.path.exists(image_path): try: @@ -60,7 +68,7 @@ def process_image(pred_item): # 更新pred中的image字段为Base64字符串 pred_item['prediction'] = img_base64 except Exception as e: - raise AISBenchRuntimeError(DSET_CODES.UNKNOWN_ERROR, f"Failed to process image {image_path}: {e}") + raise AISBenchRuntimeError(DSET_CODES.UNKNOWN_ERROR, f"Failed to process image {image_path} at index {index}: {e}") return pred_item # 使用并行处理加速图片处理 @@ -68,9 +76,9 @@ def process_image(pred_item): with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # 使用tqdm显示进度 processed_preds = list(tqdm( - executor.map(process_image, preds), + executor.map(lambda x: process_image(x[0], x[1]), enumerate(preds)), total=len(preds), - desc="Processing images", + desc="Convert prediction images to base64", unit="image" )) From d4861fe544c53d5c03dabc2b3f439b55c22b3c3e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 11:52:55 +0800 Subject: [PATCH 26/59] add base jdg process bar to task state manager --- ais_bench/benchmark/datasets/base.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 9bddfb06..96af1a60 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -154,7 +154,7 @@ def load(self, predictions_path: str, **kwargs): futures.append(future) with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: - for future in as_completed(futures): + for i, future in enumerate(as_completed(futures)): result = future.result() current_batch.append(result) @@ -164,6 +164,13 @@ def load(self, predictions_path: str, **kwargs): current_batch = [] pbar.update(1) + self.update_task_state( + { + "total_count": len(futures), + "progress_description": "Infer progress", + "finish_count": i + 1, + } + ) pbar.refresh() elif isinstance(dataset_content, DatasetDict): with ThreadPoolExecutor() as executor: @@ -174,7 +181,7 @@ def load(self, predictions_path: str, **kwargs): futures.append(future) with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: - for future in as_completed(futures): + for i, future in enumerate(as_completed(futures)): result = future.result() current_batch.append(result) @@ -184,6 +191,13 @@ def load(self, predictions_path: str, **kwargs): current_batch = [] pbar.update(1) + self.update_task_state( + { + "total_count": len(futures), + "progress_description": "Infer progress", + "finish_count": i + 1, + } + ) pbar.refresh() else: raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") From d7be38bc1884534e621343d99fa62e3248e68c80 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 11:55:19 +0800 Subject: [PATCH 27/59] add lmm jdg process bar to task state manager --- ais_bench/benchmark/datasets/g_edit.py | 4 ++-- ais_bench/benchmark/datasets/utils/lmm_judge.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index b4413c03..d7c42fe9 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -85,7 +85,7 @@ def process_example_to_dataset(example): { "total_count": len(dataset), "progress_description": f"Convert GEdit dataset to base64", - "finish_count": i, + "finish_count": i + 1, } ) submit_pbar.update(1) @@ -98,7 +98,7 @@ def process_example_to_dataset(example): { "total_count": len(dataset), "progress_description": f"Processing GEdit dataset", - "finish_count": idx, + "finish_count": idx + 1, } ) processed_datasets[idx] = future.result() diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index 8fd0967e..f35ecf53 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -46,13 +46,6 @@ def _load_from_predictions(self, prediction_path: str): # 定义图片处理函数 def process_image(index, pred_item): # 现在可以使用index来知道pred_item是preds中的第几个 - self.update_task_state( - { - "total_count": len(preds), - "progress_description": f"Convert prediction images to base64", - "finish_count": index, - } - ) image_path = os.path.join(base_path, pred_item.get('prediction', '')) if image_path and os.path.exists(image_path): try: @@ -67,6 +60,13 @@ def process_image(index, pred_item): img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') # 更新pred中的image字段为Base64字符串 pred_item['prediction'] = img_base64 + self.update_task_state( + { + "total_count": len(preds), + "progress_description": f"Convert prediction images to base64", + "finish_count": index, + } + ) except Exception as e: raise AISBenchRuntimeError(DSET_CODES.UNKNOWN_ERROR, f"Failed to process image {image_path} at index {index}: {e}") return pred_item From d0279e88d647d136fced76b3f1c1b8c4ffc082e7 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 14:13:33 +0800 Subject: [PATCH 28/59] self task state manager in api infer task --- ais_bench/benchmark/tasks/openicl_api_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ais_bench/benchmark/tasks/openicl_api_infer.py b/ais_bench/benchmark/tasks/openicl_api_infer.py index b38f88c8..d187c7cb 100644 --- a/ais_bench/benchmark/tasks/openicl_api_infer.py +++ b/ais_bench/benchmark/tasks/openicl_api_infer.py @@ -474,6 +474,7 @@ def warm_up(self, data_list: List, task_state_manager: TaskStateManager): def run(self, task_state_manager: TaskStateManager): self.logger.info(f"Task [{task_abbr_from_cfg(self.cfg)}]") + self.task_state_manager = task_state_manager self.inferencer:BaseApiInferencer = ICL_INFERENCERS.build(self.inferencer_cfg) self.clean_failed_results() From b3291c06bdbb234b8c97546ab49d052b48024faf Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 14:19:05 +0800 Subject: [PATCH 29/59] load function from static to member --- ais_bench/benchmark/datasets/g_edit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index d7c42fe9..8a797db2 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -28,8 +28,7 @@ def score(self, predictions, references): @LOAD_DATASET.register_module() class GEditDataset(BaseDataset): - @staticmethod - def load(path, use_raw=False, split_count=1, split_index=0, **kwargs): + def load(self, path, use_raw=False, split_count=1, split_index=0, **kwargs): path = get_data_path(path) self.update_task_state( { From 5d66eb98c5359d01d44d368c180a001f4aea3cf7 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 14:23:20 +0800 Subject: [PATCH 30/59] update function fix --- ais_bench/benchmark/datasets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 96af1a60..263f8707 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -51,7 +51,7 @@ def _init_reader(self, **kwargs): def update_task_state(self, state: Dict): if self.task_state_manager is not None: - self.task_state_manager.update(state) + self.task_state_manager.update_task_state(state) else: self.logger.warning("Task state manager is not initialized, cannot update task state") From d714b030382f7155ff768a306015867e7d890d2f Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 15:15:52 +0800 Subject: [PATCH 31/59] task manager effect in jdg class --- ais_bench/benchmark/datasets/base.py | 8 +++++--- ais_bench/benchmark/tasks/openicl_api_infer.py | 3 +-- ais_bench/benchmark/tasks/openicl_eval.py | 3 +-- ais_bench/benchmark/tasks/openicl_infer.py | 3 +-- ais_bench/benchmark/utils/config/build.py | 5 ++++- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 263f8707..1368bbf9 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -130,9 +130,10 @@ def __init__(self, reader_cfg: Optional[Dict] = {}, k: Union[int, List[int]] = 1, n: int = 1, + task_state_manager: Optional[TaskStateManager] = None, **kwargs): - self.dataset_instance = self._init_org_datasets_instance(reader_cfg, k, n, **kwargs) - super().__init__(reader_cfg, k, n, **kwargs) + self.dataset_instance = self._init_org_datasets_instance(reader_cfg, k, n, task_state_manager, **kwargs) + super().__init__(reader_cfg, k, n, task_state_manager, **kwargs) def load(self, predictions_path: str, **kwargs): @@ -238,7 +239,8 @@ def _init_org_datasets_instance( reader_cfg: Optional[Dict] = {}, k: Union[int, List[int]] = 1, n: int = 1, + task_state_manager: Optional[TaskStateManager] = None, **kwargs): dataset_class = self._get_dataset_class() - return dataset_class(reader_cfg, k, n, **kwargs) + return dataset_class(reader_cfg, k, n, task_state_manager, **kwargs) diff --git a/ais_bench/benchmark/tasks/openicl_api_infer.py b/ais_bench/benchmark/tasks/openicl_api_infer.py index d187c7cb..6092aa0e 100644 --- a/ais_bench/benchmark/tasks/openicl_api_infer.py +++ b/ais_bench/benchmark/tasks/openicl_api_infer.py @@ -164,8 +164,7 @@ def _get_data_list(self) -> tuple[List, List]: data_abbr = dataset_cfg["abbr"] cur_data_cache = finish_cache_data.get(data_abbr, {}) infer_cfg = dataset_cfg["infer_cfg"] - dataset_cfg["task_state_manager"] = self.task_state_manager - dataset = build_dataset_from_cfg(dataset_cfg) + dataset = build_dataset_from_cfg(dataset_cfg, task_state_manager=self.task_state_manager) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset retriever_cfg["prompt_template"] = infer_cfg.get("prompt_template", None) diff --git a/ais_bench/benchmark/tasks/openicl_eval.py b/ais_bench/benchmark/tasks/openicl_eval.py index 48d9bd42..d0f2ff60 100644 --- a/ais_bench/benchmark/tasks/openicl_eval.py +++ b/ais_bench/benchmark/tasks/openicl_eval.py @@ -112,8 +112,7 @@ def _score(self): "k":k, "n":n }) - self.dataset_cfg["task_state_manager"] = self.task_state_manager - test_set = build_dataset_from_cfg(self.dataset_cfg).test + test_set = build_dataset_from_cfg(self.dataset_cfg, task_state_manager=self.task_state_manager).test # Postprocess dataset if necessary if 'dataset_postprocessor' in self.eval_cfg: self.logger.debug(f"Dataset postprocessor: {self.eval_cfg['dataset_postprocessor']}") diff --git a/ais_bench/benchmark/tasks/openicl_infer.py b/ais_bench/benchmark/tasks/openicl_infer.py index 93a52a57..8937ca5d 100644 --- a/ais_bench/benchmark/tasks/openicl_infer.py +++ b/ais_bench/benchmark/tasks/openicl_infer.py @@ -123,8 +123,7 @@ def _inference(self): retrievers = [] for dataset_cfg in self.dataset_cfgs: infer_cfg = dataset_cfg["infer_cfg"] - dataset_cfg["task_state_manager"] = self.task_state_manager - dataset = build_dataset_from_cfg(dataset_cfg) + dataset = build_dataset_from_cfg(dataset_cfg, task_state_manager=self.task_state_manager) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset retriever_cfg["prompt_template"] = infer_cfg.get("prompt_template", None) diff --git a/ais_bench/benchmark/utils/config/build.py b/ais_bench/benchmark/utils/config/build.py index 2634fc86..bd7f33db 100644 --- a/ais_bench/benchmark/utils/config/build.py +++ b/ais_bench/benchmark/utils/config/build.py @@ -1,6 +1,7 @@ import os import copy import ipaddress +from typing import Any from mmengine.config import ConfigDict @@ -113,11 +114,13 @@ def check(condition, key, message): return errors -def build_dataset_from_cfg(dataset_cfg: ConfigDict): +def build_dataset_from_cfg(dataset_cfg: ConfigDict, task_state_manager: Any = None): logger.debug(f"Building dataset from config: type={dataset_cfg.get('type')} abbr={dataset_cfg.get('abbr')}") dataset_cfg = copy.deepcopy(dataset_cfg) dataset_cfg.pop("infer_cfg", None) dataset_cfg.pop("eval_cfg", None) + if task_state_manager is not None: + dataset_cfg["task_state_manager"] = task_state_manager return LOAD_DATASET.build(dataset_cfg) From ace5d272d3c0302253d8379245e34a2af03db5c7 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 16:11:21 +0800 Subject: [PATCH 32/59] fix status --- ais_bench/benchmark/datasets/g_edit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index 8a797db2..d23e63ec 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -32,7 +32,7 @@ def load(self, path, use_raw=False, split_count=1, split_index=0, **kwargs): path = get_data_path(path) self.update_task_state( { - "state": "loading dataset", + "status": "loading dataset", } ) dataset = load_from_disk(path) From c00788acb74c3492448b66de427a25c60a3e1611 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 17:01:54 +0800 Subject: [PATCH 33/59] mv third party --- .../models/local_models/qwen_image_edit_mindie_sd.py | 4 ++-- .../qwenimage_edit => third_party/mindie_sd}/__init__.py | 0 .../mindie_sd/qwenimage_edit}/__init__.py | 0 .../mindie_sd}/qwenimage_edit/attn_layer.py | 0 .../mindie_sd/qwenimage_edit/distributed/__init__.py | 0 .../mindie_sd}/qwenimage_edit/distributed/all_to_all.py | 0 .../qwenimage_edit/distributed/group_coordinator.py | 0 .../mindie_sd}/qwenimage_edit/distributed/parallel_mgr.py | 0 .../mindie_sd}/qwenimage_edit/distributed/utils.py | 0 .../mindie_sd}/qwenimage_edit/pipeline_qwenimage_edit_plus.py | 0 .../qwenimage_edit/scheduling_flow_match_euler_discrete.py | 0 .../mindie_sd}/qwenimage_edit/transformer_qwenimage.py | 0 12 files changed, 2 insertions(+), 2 deletions(-) rename ais_bench/{benchmark/models/local_models/qwenimage_edit => third_party/mindie_sd}/__init__.py (100%) rename ais_bench/{benchmark/models/local_models/qwenimage_edit/distributed => third_party/mindie_sd/qwenimage_edit}/__init__.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/attn_layer.py (100%) create mode 100644 ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/__init__.py rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/all_to_all.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/group_coordinator.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/parallel_mgr.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/utils.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/pipeline_qwenimage_edit_plus.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/scheduling_flow_match_euler_discrete.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/transformer_qwenimage.py (100%) diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py index b4115533..505d6346 100644 --- a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -61,8 +61,8 @@ def decorator(func): # 导入 qwen_image_edit 相关模块 try: - from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel - from ais_bench.benchmark.models.local_models.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline + from ais_bench.third_party.mindie_sd.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel + from ais_bench.third_party.mindie_sd.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline from mindiesd import CacheConfig, CacheAgent except ImportError as e: raise ImportError(f"请确保 qwenimage_edit 模块在 Python 路径中: {e}") diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py b/ais_bench/third_party/mindie_sd/__init__.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py rename to ais_bench/third_party/mindie_sd/__init__.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/__init__.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/__init__.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py diff --git a/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/__init__.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/all_to_all.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/all_to_all.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/group_coordinator.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/group_coordinator.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/utils.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/utils.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/scheduling_flow_match_euler_discrete.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/scheduling_flow_match_euler_discrete.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/transformer_qwenimage.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/transformer_qwenimage.py From 2fd6ca877ab30dc06eefdb00b5cf34f146c71f2c Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 17:17:28 +0800 Subject: [PATCH 34/59] mv tool to inner --- ais_bench/tools/__init__.py | 0 ais_bench/tools/dataset_processors/__init__.py | 0 .../tools}/dataset_processors/gedit/display_results.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 ais_bench/tools/__init__.py create mode 100644 ais_bench/tools/dataset_processors/__init__.py rename {tools => ais_bench/tools}/dataset_processors/gedit/display_results.py (98%) diff --git a/ais_bench/tools/__init__.py b/ais_bench/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/tools/dataset_processors/__init__.py b/ais_bench/tools/dataset_processors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/dataset_processors/gedit/display_results.py b/ais_bench/tools/dataset_processors/gedit/display_results.py similarity index 98% rename from tools/dataset_processors/gedit/display_results.py rename to ais_bench/tools/dataset_processors/gedit/display_results.py index 4c5d7388..a0e983d1 100644 --- a/tools/dataset_processors/gedit/display_results.py +++ b/ais_bench/tools/dataset_processors/gedit/display_results.py @@ -160,7 +160,7 @@ def _load_and_merge_jsonl(self, path_kind: "org_pred_path"): def main(): """主函数""" parser = argparse.ArgumentParser(description="显示gedit数据集的推理结果") - parser.add_argument("--config_path", help="配置文件路径") + parser.add_argument("--config_path", default="./multi_device_run_qwen_image_edit.py", help="配置文件路径") parser.add_argument("--timestamp_path", help="结果时间戳路径") args = parser.parse_args() From e1027dcde202d4c99cbbcfde0a8bf1389a16f918 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 27 Feb 2026 17:26:41 +0800 Subject: [PATCH 35/59] adaptor new third party path --- .../third_party/mindie_sd/qwenimage_edit/attn_layer.py | 4 ++-- .../qwenimage_edit/distributed/parallel_mgr.py | 8 ++++---- .../qwenimage_edit/pipeline_qwenimage_edit_plus.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py index 2d0e58e7..653f2408 100644 --- a/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py @@ -21,10 +21,10 @@ # from yunchang.comm.all_to_all import SeqAllToAll4D # from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D import logging -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.parallel_mgr import ( get_sequence_parallel_world_size, get_sequence_parallel_rank, get_sp_group diff --git a/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py index 0b6ef343..8a802e7e 100644 --- a/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py @@ -4,8 +4,8 @@ import torch.distributed as dist import torch_npu import logging -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator #--------- ljf ------------------- import torch @@ -308,14 +308,14 @@ def initialize_model_parallel( f"tensor_parallel_degree " f"({tensor_parallel_degree})" ) - + rank_generator: RankGenerator = RankGenerator( tensor_parallel_degree, sequence_parallel_degree, classifier_free_guidance_degree, "tp-sp-cfg", ) - + global _CFG if _CFG is not None: logging.error("classifier_free_guidance group is already initialized") diff --git a/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py index 3d5c4b3b..2ede9089 100644 --- a/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py @@ -25,15 +25,15 @@ from diffusers.models import AutoencoderKLQwenImage # from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from ais_bench.benchmark.models.local_models.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from ais_bench.third_party.mindie_sd.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput -from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( +from ais_bench.third_party.mindie_sd.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.parallel_mgr import ( get_sequence_parallel_world_size, get_classifier_free_guidance_world_size, get_classifier_free_guidance_rank, @@ -845,7 +845,7 @@ def __call__( if_cond=False, #----------------ljf------------- )[0] noise_pred = noise_pred[:, : latents.size(1)] - + else: with self.transformer.cache_context("cond"): noise_pred = self.transformer( @@ -914,7 +914,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 5375295c8034669b9ac954ab0fea9c4d876cef6c Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 28 Feb 2026 16:02:01 +0800 Subject: [PATCH 36/59] add result converter --- .../dataset_processors/gedit/convert_preds.py | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 ais_bench/tools/dataset_processors/gedit/convert_preds.py diff --git a/ais_bench/tools/dataset_processors/gedit/convert_preds.py b/ais_bench/tools/dataset_processors/gedit/convert_preds.py new file mode 100644 index 00000000..99396675 --- /dev/null +++ b/ais_bench/tools/dataset_processors/gedit/convert_preds.py @@ -0,0 +1,110 @@ +import os +import math +import copy +import argparse +import json +import csv +import tabulate + +from datasets import Dataset, load_from_disk + +from ais_bench.benchmark.configs.datasets.needlebench_v2.needlebench_v2_4k.needlebench_v2_multi_reasoning_4k import language +from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError +from ais_bench.benchmark.cli.config_manager import CustomConfigChecker +from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.utils.logging.logger import AISLogger +from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list +from ais_bench.benchmark.datasets.g_edit import GEditDataset +from mmengine.config import Config + +logger = AISLogger(__name__) + +def load_gedit_dataset(path): + path = get_data_path(path) + return load_from_disk(path) + + +def load_config(config_path: str) -> Config: + """加载配置文件并进行校验""" + if not os.path.exists(config_path): + raise ParameterValueError(f"Config path: {config_path} is not exist!") + try: + config = Config.fromfile(config_path, format_python_code=False) + except BaseException as e: + raise RuntimeError(f"Fail to load config {config_path}, failed reason: {e}") + CustomConfigChecker(config, config_path).check() + return config + +class GEditPredsParser: + def __init__(self, args): + self.config = load_config(args.config_path) + self.output_dir = args.timestamp_path + self.dataset = load_gedit_dataset(args.dataset_path) + self.paths_map = dict( + org_pred_path = [], + ) + for comb in self.config["model_dataset_combinations"]: + model_abbr = comb["models"][0]["abbr"] + dataset_org_abbr = comb["datasets"][0]["abbr"] + self.paths_map["org_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_org_abbr}.jsonl")) + + def parse_results(self): + logger.info(f"Start parse infer result from: {self.output_dir}") + org_pred_data_list = self._load_and_merge_jsonl("org_pred_path") + org_pred_data_dict = {item["uuid"]: item for item in org_pred_data_list} + + self.all_data_results = {} + + for uuid in org_pred_data_dict.keys(): + id = org_pred_data_dict[uuid]["id"] + output_img_path = org_pred_data_dict[uuid]["prediction"] + self.all_data_results[id] = { + "key": self.dataset[id]["key"], + "task_type": self.dataset[id]["task_type"], + "instruction_language": self.dataset[id]["instruction_language"], + "output_img_path": output_img_path, + } + + logger.info(f"Finish parsing results") + + def dump_gedit_format_result(self): + save_path = os.path.join(self.output_dir, "results", "fullset") + logger.info(f"Start dumping gedit format result ......") + for id, item in self.all_data_results.items(): + dump_dir = os.path.join(save_path, item["task_type"], item["instruction_language"]) + if not os.path.exists(dump_dir): + os.makedirs(dump_dir) + # 将output_img_path copy到dump_dir + os.system(f"cp {item['output_img_path']} {os.path.join(dump_dir, item['key'] + '.png')}") + logger.info(f"Finish dumping gedit format result ......") + + def _load_and_merge_jsonl(self, path_kind: "org_pred_path"): + merged_data = [] + start_index = 0 + for path in self.paths_map[path_kind]: + offset_index = copy.deepcopy(start_index) + with open(path, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + data["prediction"] = os.path.join(os.path.dirname(path), data["prediction"]) + data["id"] = data["id"] + offset_index + start_index += 1 + merged_data.append(data) + return merged_data + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="显示gedit数据集的推理结果") + parser.add_argument("--config_path", default="./multi_device_run_qwen_image_edit.py", help="配置文件路径") + parser.add_argument("--timestamp_path", help="结果时间戳路径") + parser.add_argument("--dataset_path", default="ais_bench/datasets/GEdit-Bench") + + args = parser.parse_args() + eval_parser = GEditPredsParser(args) + eval_parser.parse_results() + eval_parser.dump_gedit_format_result() + + +if __name__ == "__main__": + main() From 3b367e68bbf3b2407a872db21a20a1bd9d44f59d Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 28 Feb 2026 16:06:31 +0800 Subject: [PATCH 37/59] add result converter --- ais_bench/tools/dataset_processors/gedit/convert_preds.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ais_bench/tools/dataset_processors/gedit/convert_preds.py b/ais_bench/tools/dataset_processors/gedit/convert_preds.py index 99396675..982fa4ea 100644 --- a/ais_bench/tools/dataset_processors/gedit/convert_preds.py +++ b/ais_bench/tools/dataset_processors/gedit/convert_preds.py @@ -5,6 +5,7 @@ import json import csv import tabulate +from tqdm import tqdm from datasets import Dataset, load_from_disk @@ -55,7 +56,7 @@ def parse_results(self): self.all_data_results = {} - for uuid in org_pred_data_dict.keys(): + for uuid in tqdm(org_pred_data_dict.keys(), desc="Parsing results"): id = org_pred_data_dict[uuid]["id"] output_img_path = org_pred_data_dict[uuid]["prediction"] self.all_data_results[id] = { @@ -70,7 +71,7 @@ def parse_results(self): def dump_gedit_format_result(self): save_path = os.path.join(self.output_dir, "results", "fullset") logger.info(f"Start dumping gedit format result ......") - for id, item in self.all_data_results.items(): + for id, item in tqdm(self.all_data_results.items(), desc="Dumping gedit format results"): dump_dir = os.path.join(save_path, item["task_type"], item["instruction_language"]) if not os.path.exists(dump_dir): os.makedirs(dump_dir) From c6328cb4dd5866e79195171f3ebafa755029de25 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 28 Feb 2026 16:10:07 +0800 Subject: [PATCH 38/59] add result converter --- .../dataset_processors/gedit/convert_preds.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ais_bench/tools/dataset_processors/gedit/convert_preds.py b/ais_bench/tools/dataset_processors/gedit/convert_preds.py index 982fa4ea..b625e787 100644 --- a/ais_bench/tools/dataset_processors/gedit/convert_preds.py +++ b/ais_bench/tools/dataset_processors/gedit/convert_preds.py @@ -5,6 +5,7 @@ import json import csv import tabulate +import shutil from tqdm import tqdm from datasets import Dataset, load_from_disk @@ -14,8 +15,6 @@ from ais_bench.benchmark.cli.config_manager import CustomConfigChecker from ais_bench.benchmark.datasets.utils.datasets import get_data_path from ais_bench.benchmark.utils.logging.logger import AISLogger -from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list -from ais_bench.benchmark.datasets.g_edit import GEditDataset from mmengine.config import Config logger = AISLogger(__name__) @@ -40,7 +39,12 @@ class GEditPredsParser: def __init__(self, args): self.config = load_config(args.config_path) self.output_dir = args.timestamp_path - self.dataset = load_gedit_dataset(args.dataset_path) + dataset = load_gedit_dataset(args.dataset_path) + # 将Dataset转换为字典以提高访问速度 + self.dataset = {} + for i in range(len(dataset)): + item = dataset[i] + self.dataset[item["id"]] = item self.paths_map = dict( org_pred_path = [], ) @@ -73,10 +77,9 @@ def dump_gedit_format_result(self): logger.info(f"Start dumping gedit format result ......") for id, item in tqdm(self.all_data_results.items(), desc="Dumping gedit format results"): dump_dir = os.path.join(save_path, item["task_type"], item["instruction_language"]) - if not os.path.exists(dump_dir): - os.makedirs(dump_dir) + os.makedirs(dump_dir, exist_ok=True) # 将output_img_path copy到dump_dir - os.system(f"cp {item['output_img_path']} {os.path.join(dump_dir, item['key'] + '.png')}") + shutil.copy(item['output_img_path'], os.path.join(dump_dir, item['key'] + '.png')) logger.info(f"Finish dumping gedit format result ......") def _load_and_merge_jsonl(self, path_kind: "org_pred_path"): From ee788671dd6f9f143a333e24f3cf388a09e86557 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 28 Feb 2026 16:13:07 +0800 Subject: [PATCH 39/59] add result converter --- ais_bench/tools/dataset_processors/gedit/convert_preds.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ais_bench/tools/dataset_processors/gedit/convert_preds.py b/ais_bench/tools/dataset_processors/gedit/convert_preds.py index b625e787..5be5521e 100644 --- a/ais_bench/tools/dataset_processors/gedit/convert_preds.py +++ b/ais_bench/tools/dataset_processors/gedit/convert_preds.py @@ -44,7 +44,8 @@ def __init__(self, args): self.dataset = {} for i in range(len(dataset)): item = dataset[i] - self.dataset[item["id"]] = item + # 使用索引作为id,因为Dataset中可能没有'id'键 + self.dataset[i] = item self.paths_map = dict( org_pred_path = [], ) From f82836e64135023b95f3f85e0027314d1f7b231e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Sat, 28 Feb 2026 16:14:40 +0800 Subject: [PATCH 40/59] add result converter --- ais_bench/tools/dataset_processors/gedit/convert_preds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/tools/dataset_processors/gedit/convert_preds.py b/ais_bench/tools/dataset_processors/gedit/convert_preds.py index 5be5521e..46795bcc 100644 --- a/ais_bench/tools/dataset_processors/gedit/convert_preds.py +++ b/ais_bench/tools/dataset_processors/gedit/convert_preds.py @@ -42,7 +42,7 @@ def __init__(self, args): dataset = load_gedit_dataset(args.dataset_path) # 将Dataset转换为字典以提高访问速度 self.dataset = {} - for i in range(len(dataset)): + for i in tqdm(range(len(dataset)), desc="Converting dataset to dictionary"): item = dataset[i] # 使用索引作为id,因为Dataset中可能没有'id'键 self.dataset[i] = item From bd04af104bb528e4c518742cc0ebeb0c84b6797a Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Mon, 2 Mar 2026 19:31:39 +0800 Subject: [PATCH 41/59] useful config --- .../configs/lmm_exmaple/multi_device_run_qwen_image_edit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py index 33431000..5857c4ba 100644 --- a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -5,7 +5,11 @@ from ais_bench.benchmark.configs.summarizers.example import summarizer from ais_bench.benchmark.configs.datasets.gedit.gedit_gen_0_shot_llmjudge import gedit_datasets +# ====== 用户需要配置参数 ========= +qwen_image_edit_models[0]["path"] = "/path/to/Qwen-Image-Edit-2509/" # 请根据实际情况修改权重路径 +qwen_image_edit_models[0]["infer_kwargs"]["num_inference_steps"] = 50 # 请根据实际情况修改推理步数 device_list = [0] # [0, 1, 2, 3] +# ====== 用户需要配置参数 ========= datasets = [] models = [] From 723807731baa0667cf298697c1c71b5724463d42 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 10:27:36 +0800 Subject: [PATCH 42/59] mv org third party --- .../mindie_sd}/qwenimage_edit/__init__.py | 0 .../mindie_sd}/qwenimage_edit/attn_layer.py | 4 ++-- .../mindie_sd}/qwenimage_edit/distributed/__init__.py | 0 .../qwenimage_edit/distributed/all_to_all.py | 0 .../qwenimage_edit/distributed/group_coordinator.py | 0 .../qwenimage_edit/distributed/parallel_mgr.py | 8 ++++---- .../mindie_sd}/qwenimage_edit/distributed/utils.py | 0 .../qwenimage_edit/pipeline_qwenimage_edit_plus.py | 10 +++++----- .../scheduling_flow_match_euler_discrete.py | 0 .../mindie_sd}/qwenimage_edit/transformer_qwenimage.py | 0 10 files changed, 11 insertions(+), 11 deletions(-) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/__init__.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/attn_layer.py (97%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/__init__.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/all_to_all.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/group_coordinator.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/parallel_mgr.py (97%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/distributed/utils.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/pipeline_qwenimage_edit_plus.py (99%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/scheduling_flow_match_euler_discrete.py (100%) rename ais_bench/{benchmark/models/local_models => third_party/mindie_sd}/qwenimage_edit/transformer_qwenimage.py (100%) diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/__init__.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/__init__.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py similarity index 97% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py index 2d0e58e7..653f2408 100644 --- a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/attn_layer.py @@ -21,10 +21,10 @@ # from yunchang.comm.all_to_all import SeqAllToAll4D # from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D import logging -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.parallel_mgr import ( get_sequence_parallel_world_size, get_sequence_parallel_rank, get_sp_group diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/__init__.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/__init__.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/all_to_all.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/all_to_all.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/group_coordinator.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/group_coordinator.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py similarity index 97% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py index 0b6ef343..8a802e7e 100644 --- a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/parallel_mgr.py @@ -4,8 +4,8 @@ import torch.distributed as dist import torch_npu import logging -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator #--------- ljf ------------------- import torch @@ -308,14 +308,14 @@ def initialize_model_parallel( f"tensor_parallel_degree " f"({tensor_parallel_degree})" ) - + rank_generator: RankGenerator = RankGenerator( tensor_parallel_degree, sequence_parallel_degree, classifier_free_guidance_degree, "tp-sp-cfg", ) - + global _CFG if _CFG is not None: logging.error("classifier_free_guidance group is already initialized") diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/utils.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/distributed/utils.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py similarity index 99% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py index 3d5c4b3b..2ede9089 100644 --- a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py +++ b/ais_bench/third_party/mindie_sd/qwenimage_edit/pipeline_qwenimage_edit_plus.py @@ -25,15 +25,15 @@ from diffusers.models import AutoencoderKLQwenImage # from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from ais_bench.benchmark.models.local_models.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from ais_bench.third_party.mindie_sd.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput -from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel -from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( +from ais_bench.third_party.mindie_sd.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel +from ais_bench.third_party.mindie_sd.qwenimage_edit.distributed.parallel_mgr import ( get_sequence_parallel_world_size, get_classifier_free_guidance_world_size, get_classifier_free_guidance_rank, @@ -845,7 +845,7 @@ def __call__( if_cond=False, #----------------ljf------------- )[0] noise_pred = noise_pred[:, : latents.size(1)] - + else: with self.transformer.cache_context("cond"): noise_pred = self.transformer( @@ -914,7 +914,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/scheduling_flow_match_euler_discrete.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/scheduling_flow_match_euler_discrete.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py b/ais_bench/third_party/mindie_sd/qwenimage_edit/transformer_qwenimage.py similarity index 100% rename from ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py rename to ais_bench/third_party/mindie_sd/qwenimage_edit/transformer_qwenimage.py From af1e570b6e02d5c0ed10af37da7137ddc6df0aec Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 10:56:41 +0800 Subject: [PATCH 43/59] process description fix --- ais_bench/benchmark/datasets/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index 1368bbf9..d014ff63 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -168,7 +168,7 @@ def load(self, predictions_path: str, **kwargs): self.update_task_state( { "total_count": len(futures), - "progress_description": "Infer progress", + "progress_description": "Processing predictions", "finish_count": i + 1, } ) @@ -195,7 +195,7 @@ def load(self, predictions_path: str, **kwargs): self.update_task_state( { "total_count": len(futures), - "progress_description": "Infer progress", + "progress_description": "Processing predictions", "finish_count": i + 1, } ) From 8da57058e5b30a6efe0d5bcf6dcf4028c045b93e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 18:26:05 +0800 Subject: [PATCH 44/59] remove copy from org result --- ais_bench/benchmark/cli/workers.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 945e3f1e..4c28f174 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -282,7 +282,6 @@ def do_work(self, cfg: ConfigDict): runner(task_part) else: runner(tasks) - self._result_post_process(tasks, cfg) logger.info("Evaluation tasks completed.") def _cfg_pre_process(self, cfg: ConfigDict) -> None: @@ -308,20 +307,6 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): if task["datasets"][0][0].get("judge_infer_cfg"): task["datasets"][0][0].pop("judge_infer_cfg") - def _result_post_process(self, tasks, cfg: ConfigDict): - # Copy judge infer result to normal name - - for task in tasks: - if task["datasets"][0][0]["abbr"] in self.org_dataset_abbrs.keys(): - cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') - final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') - if os.path.exists(final_org_results_path): - os.remove(final_org_results_path) - - if os.path.exists(cur_results_path): - # 基于cur_results_path的文件复制一份final_org_results_path - shutil.copy(cur_results_path, final_org_results_path) - class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: From 8de36e4a83050394f44709fe7ede21879e1193b4 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 18:34:58 +0800 Subject: [PATCH 45/59] fix conifg device --- .../configs/lmm_exmaple/multi_device_run_qwen_image_edit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py index 5857c4ba..969f35ab 100644 --- a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -15,11 +15,11 @@ models = [] model_dataset_combinations = [] -for i in device_list: +for i, device_id in enumerate(device_list): model_config = {k: v for k, v in qwen_image_edit_models[0].items()} model_config['abbr'] = f"{model_config['abbr']}-{i}" model_config['device_kwargs'] = dict(model_config['device_kwargs']) - model_config['device_kwargs']['device_id'] = i + model_config['device_kwargs']['device_id'] = device_id models.append(model_config) dataset_configs = [] From 6d24a50bdda6e471d91b07b4ea9452c5d5b3548f Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 19:47:53 +0800 Subject: [PATCH 46/59] fix --- ais_bench/benchmark/models/output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais_bench/benchmark/models/output.py b/ais_bench/benchmark/models/output.py index 935466f5..aaf7379b 100644 --- a/ais_bench/benchmark/models/output.py +++ b/ais_bench/benchmark/models/output.py @@ -181,7 +181,7 @@ def update_extra_details_data_from_text_response(self, text_response: dict) -> N return # only one message is allowed -LLM_META_DATA_TYPE = Union[Image, str] +LLM_META_DATA_TYPE = Union[Image.Image, str] class LMMOutput(Output): From 2346a5211ab563a7bb632589a6d6cfd4b96496fd Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 19:57:02 +0800 Subject: [PATCH 47/59] fix --- ais_bench/benchmark/datasets/utils/llm_judge.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index 027afd44..07c8e0df 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,11 +1,3 @@ -''' -Author: SJTUyh yh_silence@alumni.sjtu.edu.cn -Date: 2026-03-03 10:30:00 -LastEditors: SJTUyh yh_silence@alumni.sjtu.edu.cn -LastEditTime: 2026-03-03 10:32:07 -FilePath: \benchmark\ais_bench\benchmark\datasets\utils\llm_judge.py -Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE -''' import re import os from PIL import Image From 3186fcb7432e55dd5629a5f34364775c6f90c515 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 20:04:08 +0800 Subject: [PATCH 48/59] fix ut --- ais_bench/benchmark/models/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ais_bench/benchmark/models/__init__.py b/ais_bench/benchmark/models/__init__.py index 12230bf1..5908d946 100644 --- a/ais_bench/benchmark/models/__init__.py +++ b/ais_bench/benchmark/models/__init__.py @@ -14,5 +14,4 @@ from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401 -from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel -from ais_bench.benchmark.models.local_models.qwen_image_edit_mindie_sd import QwenImageEditModel \ No newline at end of file +from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel \ No newline at end of file From 4ed77d35d3c938152424a29b5452e70514139d86 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 20:51:25 +0800 Subject: [PATCH 49/59] fix ut --- ais_bench/benchmark/utils/prompt/prompt.py | 8 +- tests/UT/tasks/test_openicl_api_infer.py | 3 + tests/UT/tasks/test_openicl_eval.py | 623 +++++++++++---------- 3 files changed, 323 insertions(+), 311 deletions(-) diff --git a/ais_bench/benchmark/utils/prompt/prompt.py b/ais_bench/benchmark/utils/prompt/prompt.py index 67351a26..b485854d 100644 --- a/ais_bench/benchmark/utils/prompt/prompt.py +++ b/ais_bench/benchmark/utils/prompt/prompt.py @@ -157,13 +157,13 @@ def format_mm(self, **kwargs) -> PromptList: if item.startswith(AIS_TEXT_START): question = item.replace(AIS_TEXT_START, "") question = {"question": question} - text_content = mm['text'].deepcopy() + text_content = deepcopy(mm['text']) text_content['text'] = safe_format(text_content['text'], **question) contents.append(text_content) elif item.startswith(AIS_IMAGE_START): image = item.replace(AIS_IMAGE_START, "") image = {"image": image} - image_content = mm['image'].deepcopy() + image_content = deepcopy(mm['image']) if isinstance(image_content['image_url'], dict): image_content['image_url']['url'] = safe_format(image_content['image_url']['url'], **image) else: @@ -172,7 +172,7 @@ def format_mm(self, **kwargs) -> PromptList: elif item.startswith(AIS_VIDEO_START): video = item.replace(AIS_VIDEO_START, "") video = {"video": video} - video_content = mm['video'].deepcopy() + video_content = deepcopy(mm['video']) if isinstance(video_content['video_url'], dict): video_content['video_url']['url'] = safe_format(video_content['video_url']['url'], **video) else: @@ -181,7 +181,7 @@ def format_mm(self, **kwargs) -> PromptList: elif item.startswith(AIS_AUDIO_START): audio = item.replace(AIS_AUDIO_START, "") audio = {"audio": audio} - audio_content = mm['audio'].deepcopy() + audio_content = deepcopy(mm['audio']) if isinstance(audio_content['audio_url'], dict): audio_content['audio_url']['url'] = safe_format(audio_content['audio_url']['url'], **audio) else: diff --git a/tests/UT/tasks/test_openicl_api_infer.py b/tests/UT/tasks/test_openicl_api_infer.py index b03cd567..98d8ef0c 100644 --- a/tests/UT/tasks/test_openicl_api_infer.py +++ b/tests/UT/tasks/test_openicl_api_infer.py @@ -126,6 +126,9 @@ def _create_task(self, cfg=None): if task.repeat > 1: task.logger.info(f'num_return_sequences is greater than 1, each data will be infer independently {task.repeat} times') + # 设置默认的task_state_manager,因为_get_data_list方法需要它 + task.task_state_manager = MagicMock() + return task @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') diff --git a/tests/UT/tasks/test_openicl_eval.py b/tests/UT/tasks/test_openicl_eval.py index 36a14643..ee2f52be 100644 --- a/tests/UT/tasks/test_openicl_eval.py +++ b/tests/UT/tasks/test_openicl_eval.py @@ -56,26 +56,26 @@ def tearDown(self): def _create_task(self, cfg=None): """创建OpenICLEvalTask实例的辅助方法 - + 修复dataset_cfgs类型问题:BaseTask中dataset_cfgs被设置为cfg["datasets"][0](ConfigDict), 但OpenICLEvalTask的源码中使用了sum([self.dataset_cfgs], [])和for循环,期望它是列表。 这是源码实现问题,但测试代码需要适配。 """ if cfg is None: cfg = self.cfg - + # 先调用BaseTask的初始化 task = OpenICLEvalTask.__new__(OpenICLEvalTask) # 调用BaseTask.__init__ from ais_bench.benchmark.tasks.base import BaseTask BaseTask.__init__(task, cfg) - + # 修复:将dataset_cfgs设置为列表,因为源码中使用了sum([self.dataset_cfgs], []) # 注意:这是源码实现问题,BaseTask中dataset_cfgs被设置为单个ConfigDict, # 但子类中使用了sum([self.dataset_cfgs], [])和for循环,说明源码期望它是列表 original_dataset_cfg = task.dataset_cfgs task.dataset_cfgs = [original_dataset_cfg] if not isinstance(original_dataset_cfg, list) else original_dataset_cfg - + # 继续OpenICLEvalTask的初始化 task.num_gpus = max( c.get('eval_cfg', {}).get('num_gpus', 0) @@ -85,7 +85,10 @@ def _create_task(self, cfg=None): task.cal_extract_rate = cfg.get('eval', {}).get('runner', {}).get( 'task', {}).get('cal_extract_rate', False) task.logger.debug(f"Dump details: {task.dump_details}, calculate extract rate: {task.cal_extract_rate}") - + + # 设置默认的task_state_manager,因为_score方法需要它 + task.task_state_manager = MagicMock() + return task @patch('ais_bench.benchmark.tasks.openicl_eval.AISLogger') @@ -93,9 +96,9 @@ def test_init(self, mock_logger_class): """测试OpenICLEvalTask初始化""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + self.assertEqual(task.name_prefix, "OpenICLEval") self.assertEqual(task.log_subdir, "logs/eval") self.assertEqual(task.output_subdir, "results") @@ -108,7 +111,7 @@ def test_init_with_details(self, mock_logger_class): """测试使用dump_details初始化""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + cfg = ConfigDict({ "models": [{"type": "test_model"}], "datasets": [{ @@ -128,9 +131,9 @@ def test_init_with_details(self, mock_logger_class): } } }) - + task = self._create_task(cfg) - + self.assertTrue(task.dump_details) self.assertTrue(task.cal_extract_rate) @@ -139,12 +142,12 @@ def test_get_command(self, mock_logger_class): """测试get_command方法""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + with patch('ais_bench.benchmark.tasks.openicl_eval.sys.executable', '/usr/bin/python'): cmd = task.get_command("/path/to/config.py", "CUDA_VISIBLE_DEVICES=0 {task_cmd}") - + self.assertIn("/usr/bin/python", cmd) self.assertIn("/path/to/config.py", cmd) @@ -156,12 +159,12 @@ def test_score(self, mock_signature, mock_build_dataset, mock_evaluators, mock_l """测试_score方法""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock signature函数,返回一个只包含preds中已有键的参数签名 from inspect import Parameter, Signature mock_sig = Signature([ @@ -171,33 +174,33 @@ def test_score(self, mock_signature, mock_build_dataset, mock_evaluators, mock_l Parameter('origin_prompt', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_dataset.test = MagicMock() mock_dataset.test.__len__ = lambda x: 100 mock_dataset.test.select = lambda x: mock_dataset.test mock_dataset.test.__getitem__ = lambda x, y: {"answer": "test"} mock_build_dataset.return_value = mock_dataset - + task = self._create_task() task.logger = mock_logger # 设置logger为mock - + # 修复:设置dataset_cfg,因为_score方法需要使用它 task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 修复:model_cfg需要abbr字段,否则model_abbr_from_cfg会需要path字段 if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 修复:task_abbr_from_cfg期望datasets是嵌套列表结构 [[dataset1, dataset2]] # 但BaseTask中datasets[0]是单个ConfigDict,所以需要修复cfg结构 # 注意:cfg["datasets"]已经是[[{...}]]结构,所以不需要修改 # 但需要确保task.cfg["datasets"]的结构正确 if not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -209,9 +212,10 @@ def test_score(self, mock_signature, mock_build_dataset, mock_evaluators, mock_l os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') - + + task.task_state_manager = MagicMock() task._score() - + mock_evaluators.build.assert_called_once() @patch('ais_bench.benchmark.tasks.openicl_eval.AISLogger') @@ -219,7 +223,7 @@ def test_score_with_invalid_k_n(self, mock_logger_class): """测试_score方法,无效的k和n值""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + cfg = ConfigDict({ "models": [{"type": "test_model"}], "datasets": [{ @@ -234,17 +238,18 @@ def test_score_with_invalid_k_n(self, mock_logger_class): "cli_args": {}, "eval": {"runner": {"task": {}}} }) - + task = self._create_task(cfg) - + # 修复:设置dataset_cfg,因为_score方法需要使用它 task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + + task.task_state_manager = MagicMock() with self.assertRaises(ParameterValueError) as context: task._score() - + error_code = context.exception.error_code_str self.assertEqual(error_code, TEVAL_CODES.N_K_ILLEGAL.full_code) @@ -253,7 +258,7 @@ def test_score_with_k_greater_than_n(self, mock_logger_class): """测试_score方法,k大于n""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + cfg = ConfigDict({ "models": [{"type": "test_model"}], "datasets": [{ @@ -268,17 +273,17 @@ def test_score_with_k_greater_than_n(self, mock_logger_class): "cli_args": {}, "eval": {"runner": {"task": {}}} }) - + task = self._create_task(cfg) - + # 修复:设置dataset_cfg,因为_score方法需要使用它 task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + with self.assertRaises(ParameterValueError) as context: task._score() - + error_code = context.exception.error_code_str self.assertEqual(error_code, TEVAL_CODES.N_K_ILLEGAL.full_code) @@ -290,23 +295,23 @@ def test_score_no_predictions(self, mock_task_abbr, mock_build_dataset, mock_log mock_logger = MagicMock() mock_logger_class.return_value = mock_logger mock_task_abbr.return_value = "test_task" - + mock_dataset = MagicMock() mock_dataset.test = MagicMock() mock_dataset.test.__len__ = lambda x: 100 mock_build_dataset.return_value = mock_dataset - + task = self._create_task() - + # 修复:设置dataset_cfg,因为_score方法需要使用它 task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 修复:model_cfg需要abbr字段,否则model_abbr_from_cfg会需要path字段 if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 修复:task_abbr_from_cfg期望datasets是嵌套列表结构 [[dataset1, dataset2]] # 但BaseTask中datasets[0]是单个ConfigDict,所以需要修复cfg结构 if isinstance(task.cfg["datasets"][0][0], dict) and not isinstance(task.cfg["datasets"][0], list): @@ -318,9 +323,9 @@ def test_extract_rate(self, mock_logger_class): """测试extract_rate方法""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + results = { "details": { "0": {"predictions": "test"}, @@ -328,9 +333,9 @@ def test_extract_rate(self, mock_logger_class): "2": {"predictions": "test2"} } } - + rate = task.extract_rate(results) - + self.assertIsInstance(rate, (int, float)) self.assertGreaterEqual(rate, 0) self.assertLessEqual(rate, 100) @@ -340,9 +345,9 @@ def test_format_details(self, mock_logger_class): """测试format_details方法""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1", "pred2"] references = ["ref1", "ref2"] details = [ @@ -353,7 +358,7 @@ def test_format_details(self, mock_logger_class): {"id": 0, "prediction": "pred1", "origin_prompt": "prompt1"}, {"id": 1, "prediction": "pred2", "origin_prompt": "prompt2"} ] - + result = task.format_details( predictions, [], @@ -362,7 +367,7 @@ def test_format_details(self, mock_logger_class): None, pred_dicts ) - + self.assertIsInstance(result, dict) self.assertIn("0", result) self.assertIn("1", result) @@ -373,9 +378,9 @@ def test_format_details_ppl(self, mock_logger_class): """测试format_details方法,PPL类型""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1"] references = ["ref1"] pred_dicts = [ @@ -385,7 +390,7 @@ def test_format_details_ppl(self, mock_logger_class): "label: test": {"BPB": 1.0} } ] - + result = task.format_details( predictions, [], @@ -394,7 +399,7 @@ def test_format_details_ppl(self, mock_logger_class): None, pred_dicts ) - + self.assertEqual(result["type"], "PPL") self.assertIn("0", result) @@ -403,9 +408,9 @@ def test_calculate_bpb(self, mock_logger_class): """测试calculate_bpb方法""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + pred_dicts = [ { "label: option1": {"BPB": 1.0}, @@ -418,9 +423,9 @@ def test_calculate_bpb(self, mock_logger_class): "label: option3": {"BPB": 3.5} } ] - + correct_bpb, incorrect_bpb = task.calculate_bpb(pred_dicts) - + self.assertIsInstance(correct_bpb, (int, float)) self.assertIsInstance(incorrect_bpb, (int, float)) self.assertGreater(correct_bpb, 0) @@ -431,30 +436,30 @@ def test_run_with_model_postprocessor(self, mock_logger_class): """测试run方法中使用model postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 确保model_cfg有abbr字段,避免KeyError: 'path' if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置model postprocessor task.model_cfg["pred_postprocessor"] = { "test_dataset": {"type": "test_postprocessor"} } - + # Mock dataset_abbr_from_cfg with patch('ais_bench.benchmark.tasks.openicl_eval.dataset_abbr_from_cfg') as mock_abbr: mock_abbr.return_value = "test_dataset" - + # Mock _score with patch.object(task, '_score') as mock_score: if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] - - task.run() - + + task.run(MagicMock()) + # 验证调用了_score mock_score.assert_called() @@ -463,25 +468,29 @@ def test_run_with_existing_output_file(self, mock_logger_class): """测试run方法中输出文件已存在的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 确保model_cfg有abbr字段,避免KeyError: 'path' if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # Mock文件存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists: mock_exists.return_value = True - + # Mock _score with patch.object(task, '_score') as mock_score: if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] - - task.run() - + + # 创建TaskStateManager的mock + mock_task_state_manager = MagicMock() + task.run(mock_task_state_manager) + + # 验证task_state_manager被正确设置 + self.assertEqual(task.task_state_manager, mock_task_state_manager) # 验证记录了警告日志 mock_logger.warning.assert_called() @@ -489,7 +498,7 @@ def test_parse_args(self): """测试parse_args函数""" import sys from unittest.mock import patch - + test_args = ['test_script', 'config.py'] with patch.object(sys, 'argv', test_args): from ais_bench.benchmark.tasks.openicl_eval import parse_args @@ -503,47 +512,47 @@ def test_score_with_num_prompts(self, mock_logger_class, mock_evaluators, mock_b """测试_score方法中num_prompts处理的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + # 设置reader_cfg中的test_range来模拟num_prompts的效果 cfg = self.cfg.copy() cfg["datasets"][0][0]["reader_cfg"]["test_range"] = "[:5]" - + task = self._create_task(cfg) task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset - test_range会在build_dataset_from_cfg时处理,所以test_set已经是限制后的 mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=5) # 限制后的数量 mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock prediction file不存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists: mock_exists.return_value = False - + task._score() - + # 验证build_dataset_from_cfg被调用(test_range会在那里处理) mock_build_dataset.assert_called() @@ -555,10 +564,10 @@ def test_score_with_dataset_postprocessor(self, mock_logger_class, mock_evaluato """测试_score方法中使用dataset_postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] @@ -566,15 +575,15 @@ def test_score_with_dataset_postprocessor(self, mock_logger_class, mock_evaluato task.eval_cfg = task.dataset_cfg.get('eval_cfg', {}) task.eval_cfg['dataset_postprocessor'] = {"type": "test_postprocessor"} task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=10) @@ -582,22 +591,22 @@ def test_score_with_dataset_postprocessor(self, mock_logger_class, mock_evaluato mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock postprocessor mock_postprocessor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_postprocessor - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock prediction file不存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists: mock_exists.return_value = False - + task._score() - + # 验证test_set.map被调用(dataset_postprocessor处理) mock_test_set.map.assert_called() @@ -609,37 +618,37 @@ def test_score_with_partial_filename(self, mock_logger_class, mock_evaluators, m """测试_score方法中使用partial_filename的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=10) mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock signature - 需要mock signature函数 from inspect import Signature, Parameter mock_sig = Signature([ @@ -647,18 +656,18 @@ def test_score_with_partial_filename(self, mock_logger_class, mock_evaluators, m Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), Parameter('test_set', Parameter.POSITIONAL_OR_KEYWORD), ]) - + # Mock prediction file: 主文件不存在,但partial文件存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists, \ patch('ais_bench.benchmark.tasks.openicl_eval.mmengine.load') as mock_load, \ patch('ais_bench.benchmark.tasks.openicl_eval.signature') as mock_signature_func: - + # 设置signature返回值 - 需要返回一个Signature对象,其parameters属性包含需要的参数 def signature_side_effect(func): return mock_sig - + mock_signature_func.side_effect = signature_side_effect - + # 模拟partial文件存在的情况 # 注意:源码中先检查主文件,然后检查partial_filename (_0.jsonl) # 如果主文件不存在但partial文件存在,会进入else分支 @@ -674,7 +683,7 @@ def exists_side_effect(path): return True # 下一个文件不存在(_1.jsonl),退出循环 return False - + mock_exists.side_effect = exists_side_effect # mock_load需要返回包含prediction字段的字典 # 注意:源码中会遍历sub_preds的键,所以需要确保键是字符串数字 @@ -682,9 +691,9 @@ def exists_side_effect(path): "0": {"prediction": "pred1", "id": 0}, "1": {"prediction": "pred2", "id": 1} } - + task._score() - + # 验证mmengine.load被调用(加载partial文件) mock_load.assert_called() @@ -693,10 +702,10 @@ def test_extract_rate_with_keyerror(self, mock_logger_class): """测试extract_rate方法中KeyError异常处理""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 创建包含无效数据的details results = { "details": { @@ -704,7 +713,7 @@ def test_extract_rate_with_keyerror(self, mock_logger_class): "1": {} # 缺少predictions键,会触发KeyError } } - + # 应该抛出KeyError with self.assertRaises(KeyError): task.extract_rate(results) @@ -714,9 +723,9 @@ def test_format_details_with_model_details(self, mock_logger_class): """测试format_details方法中有model_details的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1", "pred2"] model_pred_strs = ["model_pred1", "model_pred2"] references = ["ref1", "ref2"] @@ -732,7 +741,7 @@ def test_format_details_with_model_details(self, mock_logger_class): {"origin_prompt": "prompt1", "prediction": "pred1"}, {"origin_prompt": "prompt2", "prediction": "pred2"} ] - + result = task.format_details( predictions, model_pred_strs, @@ -741,7 +750,7 @@ def test_format_details_with_model_details(self, mock_logger_class): model_details, pred_dicts ) - + self.assertIsInstance(result, dict) self.assertIn("0", result) self.assertIn("1", result) @@ -757,47 +766,47 @@ def test_score_with_model_pred_postprocessor(self, mock_logger_class, mock_evalu """测试_score方法中使用model pred_postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr和pred_postprocessor if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" task.model_cfg["pred_postprocessor"] = {"type": "test_postprocessor"} - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=10) mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock prediction file不存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists, \ patch('ais_bench.benchmark.tasks.openicl_eval.TEXT_POSTPROCESSORS') as mock_postprocessors: mock_exists.return_value = False mock_postprocessor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_postprocessor - + task._score() - + # 验证记录了debug日志 mock_logger.debug.assert_called() @@ -808,10 +817,10 @@ def test_score_with_eval_pred_postprocessor(self, mock_logger_class, mock_evalua """测试_score方法中使用eval pred_postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] @@ -819,36 +828,36 @@ def test_score_with_eval_pred_postprocessor(self, mock_logger_class, mock_evalua task.eval_cfg = task.dataset_cfg.get('eval_cfg', {}) task.eval_cfg['pred_postprocessor'] = {"type": "test_postprocessor"} task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=10) mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock prediction file不存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists, \ patch('ais_bench.benchmark.tasks.openicl_eval.TEXT_POSTPROCESSORS') as mock_postprocessors: mock_exists.return_value = False mock_postprocessor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_postprocessor - + task._score() - + # 验证记录了debug日志 mock_logger.debug.assert_called() @@ -859,10 +868,10 @@ def test_score_with_sc_size(self, mock_logger_class, mock_evaluators, mock_build """测试_score方法中使用self-consistency (sc_size)的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() task.logger = mock_logger - + # 设置必要的属性 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] @@ -870,33 +879,33 @@ def test_score_with_sc_size(self, mock_logger_class, mock_evaluators, mock_build task.eval_cfg = task.dataset_cfg.get('eval_cfg', {}) task.eval_cfg['sc_size'] = 3 # 设置self-consistency size task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 设置model_cfg的abbr if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 设置datasets结构 if isinstance(task.cfg["datasets"][0], dict) and not isinstance(task.cfg["datasets"][0], list): task.cfg["datasets"] = [task.cfg["datasets"]] - + # Mock dataset mock_test_set = MagicMock() mock_test_set.__len__ = MagicMock(return_value=10) mock_dataset = MagicMock() mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + # Mock evaluator mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + # Mock prediction file不存在 with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists: mock_exists.return_value = False - + task._score() - + # 验证eval_cfg中有sc_size self.assertEqual(task.eval_cfg.get('sc_size'), 3) @@ -905,16 +914,16 @@ def test_format_details_without_details(self, mock_logger_class): """测试format_details方法中没有details的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1", "pred2"] references = ["ref1", "ref2"] pred_dicts = [ {"origin_prompt": "prompt1", "prediction": "pred1"}, {"origin_prompt": "prompt2", "prediction": "pred2"} ] - + result = task.format_details( predictions, [], @@ -923,7 +932,7 @@ def test_format_details_without_details(self, mock_logger_class): None, pred_dicts ) - + self.assertIsInstance(result, dict) self.assertIn("0", result) self.assertIn("1", result) @@ -936,7 +945,7 @@ def test_init_with_details(self, mock_logger_class): """测试__init__方法中dump_details和cal_extract_rate的设置""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + cfg = ConfigDict({ "models": [{ "type": "test_model" @@ -956,13 +965,13 @@ def test_init_with_details(self, mock_logger_class): } } }) - + task = OpenICLEvalTask(cfg) # 修复:确保dataset_cfgs是列表,因为__init__中使用了sum([self.dataset_cfgs], []) # BaseTask会将datasets[0]设置为dataset_cfgs,如果datasets[0]是列表,则dataset_cfgs就是列表 if not isinstance(task.dataset_cfgs, list): task.dataset_cfgs = [task.dataset_cfgs] - + self.assertTrue(task.dump_details) self.assertTrue(task.cal_extract_rate) self.assertEqual(task.num_gpus, 2) @@ -975,10 +984,10 @@ def test_parse_args(self, mock_parser_class): mock_args = MagicMock() mock_args.config = "test_config.py" mock_parser.parse_args.return_value = mock_args - + from ais_bench.benchmark.tasks.openicl_eval import parse_args args = parse_args() - + mock_parser.add_argument.assert_called_once_with('config', help='Config file path') mock_parser.parse_args.assert_called_once() self.assertEqual(args.config, "test_config.py") @@ -993,28 +1002,28 @@ def test_score_with_dataset_postprocessor(self, mock_postprocessors, mock_signat """测试_score方法中使用dataset_postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + # Mock postprocessor mock_processor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 mock_test_set.select = lambda x: mock_test_set - + # Mock map方法,使其调用传入的函数 def mock_map(func): # 模拟map的行为:调用函数处理每个样本 @@ -1022,12 +1031,12 @@ def mock_map(func): sample = {"answer": "test"} func(sample) # 调用postprocess函数,这会调用proc(即mock_processor) return mock_test_set - + mock_test_set.map = mock_map mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["dataset_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1035,10 +1044,10 @@ def mock_map(func): task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1051,9 +1060,9 @@ def mock_map(func): with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证postprocessor被调用 # TEXT_POSTPROCESSORS.get应该被调用 mock_postprocessors.get.assert_called_once_with("test_processor") @@ -1069,22 +1078,22 @@ def test_score_with_model_pred_postprocessor(self, mock_postprocessors, mock_sig """测试_score方法中使用model_cfg['pred_postprocessor']的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1092,26 +1101,26 @@ def test_score_with_model_pred_postprocessor(self, mock_postprocessors, mock_sig mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["models"][0]["pred_postprocessor"] = {"type": "test_processor", "param": "value"} task = self._create_task(cfg) task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 创建预测文件(使用二进制模式,匹配代码中的读取方式) pred_file = os.path.join(self.temp_dir, "predictions", "test_model", "test_dataset.jsonl") os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证postprocessor被调用 self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -1124,22 +1133,22 @@ def test_score_with_eval_pred_postprocessor(self, mock_postprocessors, mock_sign """测试_score方法中使用eval_cfg['pred_postprocessor']的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value="processed") mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1147,7 +1156,7 @@ def test_score_with_eval_pred_postprocessor(self, mock_postprocessors, mock_sign mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["pred_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1155,10 +1164,10 @@ def test_score_with_eval_pred_postprocessor(self, mock_postprocessors, mock_sign task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1171,9 +1180,9 @@ def test_score_with_eval_pred_postprocessor(self, mock_postprocessors, mock_sign with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证postprocessor被调用 self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -1186,22 +1195,22 @@ def test_score_with_model_postprocessor(self, mock_postprocessors, mock_signatur """测试_score方法中使用model_postprocessor的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": []} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": []} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=["processed1", "processed2"]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1209,7 +1218,7 @@ def test_score_with_model_postprocessor(self, mock_postprocessors, mock_signatur mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["model_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1217,13 +1226,13 @@ def test_score_with_model_postprocessor(self, mock_postprocessors, mock_signatur task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # Mock test_set[self.output_column]返回列表 mock_test_set.__getitem__ = lambda x, y: ["ref1", "ref2"] if y == task.output_column else {"answer": "test"} - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1236,9 +1245,9 @@ def test_score_with_model_postprocessor(self, mock_postprocessors, mock_signatur with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证model_postprocessor被调用 self.assertTrue(mock_postprocessors.get.called or mock_processor.called) # 验证model_result被处理 @@ -1253,24 +1262,24 @@ def test_score_with_sc_size(self, mock_counter, mock_signature, mock_build_datas """测试_score方法中使用sc_size (self-consistency)的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + # Mock Counter for self-consistency mock_counter_instance = MagicMock() mock_counter_instance.most_common.return_value = [("test", 3)] mock_counter.return_value = mock_counter_instance - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1278,7 +1287,7 @@ def test_score_with_sc_size(self, mock_counter, mock_signature, mock_build_datas mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["sc_size"] = 3 task = self._create_task(cfg) @@ -1286,10 +1295,10 @@ def test_score_with_sc_size(self, mock_counter, mock_signature, mock_build_datas task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1302,9 +1311,9 @@ def test_score_with_sc_size(self, mock_counter, mock_signature, mock_build_datas with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": ["test1", "test2", "test3"]}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": ["test1", "test1", "test1"]}) + b'\n') - + task._score() - + # 验证Counter被调用(用于self-consistency) self.assertTrue(mock_counter.called) @@ -1316,19 +1325,19 @@ def test_score_with_returns_tool_calls(self, mock_signature, mock_build_dataset, """测试_score方法中使用returns_tool_calls的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1336,26 +1345,26 @@ def test_score_with_returns_tool_calls(self, mock_signature, mock_build_dataset, mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["models"][0]["returns_tool_calls"] = True task = self._create_task(cfg) task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 创建预测文件(使用二进制模式,匹配代码中的读取方式) pred_file = os.path.join(self.temp_dir, "predictions", "test_model", "test_dataset.jsonl") os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证evaluator配置中is_fc_model被设置 self.assertEqual(mock_evaluators.build.call_args[0][0].get('is_fc_model'), True) @@ -1367,19 +1376,19 @@ def test_score_with_origin_prompt_typeerror(self, mock_signature, mock_build_dat """测试_score方法中origin_prompt的TypeError处理""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1387,21 +1396,21 @@ def test_score_with_origin_prompt_typeerror(self, mock_signature, mock_build_dat mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + task = self._create_task() task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 创建预测文件(使用二进制模式,匹配代码中的读取方式),pred_strs是None,导致TypeError pred_file = os.path.join(self.temp_dir, "predictions", "test_model", "test_dataset.jsonl") os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0}) + b'\n') # 没有prediction字段 - + # 模拟pred_strs为None的情况 with patch.object(task, '_score', wraps=task._score) as mock_score: try: @@ -1418,19 +1427,19 @@ def test_score_with_dump_details(self, mock_postprocessors, mock_signature, mock """测试_score方法中dump_details为True的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": [{"pred": "test", "ref": "test"}]} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": [{"pred": "processed", "ref": "test"}]} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1438,7 +1447,7 @@ def test_score_with_dump_details(self, mock_postprocessors, mock_signature, mock mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["eval"]["runner"]["task"]["dump_details"] = True task = self._create_task(cfg) @@ -1446,10 +1455,10 @@ def test_score_with_dump_details(self, mock_postprocessors, mock_signature, mock task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1462,9 +1471,9 @@ def test_score_with_dump_details(self, mock_postprocessors, mock_signature, mock with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test", "origin_prompt": "prompt1"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2", "origin_prompt": "prompt2"}) + b'\n') - + task._score() - + # 验证dump_details相关逻辑被执行 self.assertTrue(mock_logger.warning.called or mock_logger.info.called) @@ -1476,20 +1485,20 @@ def test_score_with_bfcl_dataset(self, mock_signature, mock_build_dataset, mock_ """测试_score方法中BFCL数据集的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() # 确保evaluate返回包含details的结果,这样BFCL检查才能执行 mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": {}} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1497,7 +1506,7 @@ def test_score_with_bfcl_dataset(self, mock_signature, mock_build_dataset, mock_ mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["type"] = "BFCLDataset" cfg["eval"]["runner"]["task"]["dump_details"] = True @@ -1507,15 +1516,15 @@ def test_score_with_bfcl_dataset(self, mock_signature, mock_build_dataset, mock_ task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + # 确保dump_details正确设置 self.assertTrue(task.dump_details, "dump_details should be True") # 确保dataset_cfg的type正确设置 self.assertEqual(task.dataset_cfg.get("type"), "BFCLDataset", "dataset_cfg type should be BFCLDataset") - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1528,9 +1537,9 @@ def test_score_with_bfcl_dataset(self, mock_signature, mock_build_dataset, mock_ with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2"}) + b'\n') - + task._score() - + # 验证BFCL特殊处理逻辑被执行 # 检查logger.info是否被调用,并且包含BFCL相关的日志 # 注意:logger.info会在多个地方被调用(BFCL检查、任务结果记录等) @@ -1548,22 +1557,22 @@ def test_score_with_model_result(self, mock_postprocessors, mock_signature, mock """测试_score方法中model_result不为None的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": []} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": []} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=["processed1", "processed2"]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1571,7 +1580,7 @@ def test_score_with_model_result(self, mock_postprocessors, mock_signature, mock mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["model_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1579,13 +1588,13 @@ def test_score_with_model_result(self, mock_postprocessors, mock_signature, mock task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # Mock test_set[self.output_column]返回列表 mock_test_set.__getitem__ = lambda x, y: ["ref1", "ref2"] if y == task.output_column else {"answer": "test"} - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1598,9 +1607,9 @@ def test_score_with_model_result(self, mock_postprocessors, mock_signature, mock with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test", "origin_prompt": "prompt1"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2", "origin_prompt": "prompt2"}) + b'\n') - + task._score() - + # 验证model_result相关日志被调用 model_result_logs = [call for call in mock_logger.info.call_args_list if "Model Postprocess" in str(call)] self.assertTrue(len(model_result_logs) > 0) @@ -1614,22 +1623,22 @@ def test_score_with_pred_list_flag(self, mock_postprocessors, mock_signature, mo """测试_score方法中pred_list_flag为True的情况(prediction是列表)""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=["processed"]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1637,7 +1646,7 @@ def test_score_with_pred_list_flag(self, mock_postprocessors, mock_signature, mo mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["pred_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1645,10 +1654,10 @@ def test_score_with_pred_list_flag(self, mock_postprocessors, mock_signature, mo task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1661,9 +1670,9 @@ def test_score_with_pred_list_flag(self, mock_postprocessors, mock_signature, mo with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": ["test1", "test2"]}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": ["test3", "test4"]}) + b'\n') - + task._score() - + # 验证postprocessor被调用(列表形式) self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -1676,19 +1685,19 @@ def test_score_with_partial_filename(self, mock_mmengine, mock_signature, mock_b """测试_score方法中使用partial_filename的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1696,16 +1705,16 @@ def test_score_with_partial_filename(self, mock_mmengine, mock_signature, mock_b mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + task = self._create_task() task.logger = mock_logger # 设置logger为mock task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # Mock文件存在检查:主文件不存在,但partial文件存在 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1716,12 +1725,12 @@ def test_score_with_partial_filename(self, mock_mmengine, mock_signature, mock_b ) root, ext = os.path.splitext(pred_file) partial_filename = root + '_0' + ext - + # 创建partial文件 os.makedirs(os.path.dirname(partial_filename), exist_ok=True) # Mock mmengine.load返回partial predictions mock_mmengine.load.return_value = {"0": {"id": 0, "prediction": "test"}, "1": {"id": 1, "prediction": "test2"}} - + # Mock osp.exists with patch('ais_bench.benchmark.tasks.openicl_eval.osp.exists') as mock_exists: def exists_side_effect(path): @@ -1734,10 +1743,10 @@ def exists_side_effect(path): return True # 下一个文件不存在 return False - + mock_exists.side_effect = exists_side_effect task._score() - + # 验证mmengine.load被调用(用于加载partial文件) self.assertTrue(mock_mmengine.load.called) @@ -1750,19 +1759,19 @@ def test_score_with_dump_details_extract_rate(self, mock_postprocessors, mock_si """测试_score方法中dump_details和cal_extract_rate为True的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": [{"pred": "test", "ref": "test"}]} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": []} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1770,7 +1779,7 @@ def test_score_with_dump_details_extract_rate(self, mock_postprocessors, mock_si mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["eval"]["runner"]["task"]["dump_details"] = True cfg["eval"]["runner"]["task"]["cal_extract_rate"] = True @@ -1779,10 +1788,10 @@ def test_score_with_dump_details_extract_rate(self, mock_postprocessors, mock_si task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1795,9 +1804,9 @@ def test_score_with_dump_details_extract_rate(self, mock_postprocessors, mock_si with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test", "origin_prompt": "prompt1"}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2", "origin_prompt": "prompt2"}) + b'\n') - + task._score() - + # 验证extract_rate相关逻辑被执行 self.assertTrue(mock_logger.warning.called or mock_logger.info.called) @@ -1809,19 +1818,19 @@ def test_score_with_ppl_inferencer(self, mock_signature, mock_build_dataset, moc """测试_score方法中使用PPL inferencer的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": [{"pred": "test", "ref": "test"}]} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": []} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1829,7 +1838,7 @@ def test_score_with_ppl_inferencer(self, mock_signature, mock_build_dataset, moc mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["infer_cfg"]["inferencer"] = {"type": "PPLInferencer"} cfg["eval"]["runner"]["task"]["dump_details"] = True @@ -1838,10 +1847,10 @@ def test_score_with_ppl_inferencer(self, mock_signature, mock_build_dataset, moc task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1854,9 +1863,9 @@ def test_score_with_ppl_inferencer(self, mock_signature, mock_build_dataset, moc with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": "test", "origin_prediction": "orig", "label: option1": {"BPB": 1.0}}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": "test2", "origin_prediction": "orig2", "label: option1": {"BPB": 2.0}}) + b'\n') - + task._score() - + # 验证PPL相关逻辑被执行(calculate_bpb) self.assertTrue(mock_logger.warning.called or mock_logger.info.called) @@ -1869,22 +1878,22 @@ def test_score_with_model_postprocessor_pred_list_flag(self, mock_postprocessors """测试_score方法中使用model_postprocessor且pred_list_flag为True的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9, "details": []} mock_evaluator.score.return_value = {"accuracy": 0.95, "details": []} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=[["processed1"], ["processed2"]]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1892,7 +1901,7 @@ def test_score_with_model_postprocessor_pred_list_flag(self, mock_postprocessors mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["model_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) @@ -1900,13 +1909,13 @@ def test_score_with_model_postprocessor_pred_list_flag(self, mock_postprocessors task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # Mock test_set[self.output_column]返回列表 mock_test_set.__getitem__ = lambda x, y: ["ref1", "ref2"] if y == task.output_column else {"answer": "test"} - + # 使用get_infer_output_path获取正确的预测文件路径 from ais_bench.benchmark.utils.core.abbr import get_infer_output_path pred_file = get_infer_output_path( @@ -1919,9 +1928,9 @@ def test_score_with_model_postprocessor_pred_list_flag(self, mock_postprocessors with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": ["test1", "test2"]}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": ["test3", "test4"]}) + b'\n') - + task._score() - + # 验证model_postprocessor被调用(列表形式) self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -1934,22 +1943,22 @@ def test_score_with_model_pred_postprocessor_pred_list_flag(self, mock_postproce """测试_score方法中使用model_cfg['pred_postprocessor']且pred_list_flag为True的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=[["processed"]]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -1957,26 +1966,26 @@ def test_score_with_model_pred_postprocessor_pred_list_flag(self, mock_postproce mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["models"][0]["pred_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 创建预测文件(使用二进制模式,匹配代码中的读取方式),prediction是列表 pred_file = os.path.join(self.temp_dir, "predictions", "test_model", "test_dataset.jsonl") os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": ["test1", "test2"]}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": ["test3", "test4"]}) + b'\n') - + task._score() - + # 验证postprocessor被调用(列表形式) self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -1989,22 +1998,22 @@ def test_score_with_eval_pred_postprocessor_pred_list_flag(self, mock_postproces """测试_score方法中使用eval_cfg['pred_postprocessor']且pred_list_flag为True的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + mock_evaluator = MagicMock() mock_evaluator.evaluate.return_value = {"accuracy": 0.9} mock_evaluator.score.return_value = {"accuracy": 0.9} mock_evaluators.build.return_value = mock_evaluator - + from inspect import Parameter, Signature mock_sig = Signature([ Parameter('predictions', Parameter.POSITIONAL_OR_KEYWORD), Parameter('references', Parameter.POSITIONAL_OR_KEYWORD), ]) mock_signature.return_value = mock_sig - + mock_processor = MagicMock(return_value=[["processed"]]) mock_postprocessors.get.return_value = mock_processor - + mock_dataset = MagicMock() mock_test_set = MagicMock() mock_test_set.__len__ = lambda x: 2 @@ -2012,26 +2021,26 @@ def test_score_with_eval_pred_postprocessor_pred_list_flag(self, mock_postproces mock_test_set.__getitem__ = lambda x, y: {"answer": "test"} mock_dataset.test = mock_test_set mock_build_dataset.return_value = mock_dataset - + cfg = self.cfg.copy() cfg["datasets"][0][0]["eval_cfg"]["pred_postprocessor"] = {"type": "test_processor"} task = self._create_task(cfg) task.dataset_cfg = task.dataset_cfgs[0] task.eval_cfg = task.dataset_cfg.get('eval_cfg') task.output_column = task.dataset_cfg['reader_cfg']['output_column'] - + if "abbr" not in task.model_cfg: task.model_cfg["abbr"] = "test_model" - + # 创建预测文件(使用二进制模式,匹配代码中的读取方式),prediction是列表 pred_file = os.path.join(self.temp_dir, "predictions", "test_model", "test_dataset.jsonl") os.makedirs(os.path.dirname(pred_file), exist_ok=True) with open(pred_file, 'wb') as f: f.write(orjson.dumps({"id": 0, "prediction": ["test1", "test2"]}) + b'\n') f.write(orjson.dumps({"id": 1, "prediction": ["test3", "test4"]}) + b'\n') - + task._score() - + # 验证postprocessor被调用(列表形式) self.assertTrue(mock_postprocessors.get.called or mock_processor.called) @@ -2040,16 +2049,16 @@ def test_format_details_model_pred_strs_empty(self, mock_logger_class): """测试format_details方法中model_pred_strs为空时抛出异常的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1", "pred2"] model_pred_strs = [] # 空列表 references = ["ref1", "ref2"] details = [{"pred": "pred1", "answer": "ref1", "correct": True}] model_details = [{"pred": "model_pred1", "correct": True}] pred_dicts = [{"origin_prompt": "prompt1", "prediction": "pred1"}] - + # 应该抛出ParameterValueError from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError with self.assertRaises(ParameterValueError): @@ -2060,9 +2069,9 @@ def test_format_details_details_only(self, mock_logger_class): """测试format_details方法中只有details没有model_details的情况""" mock_logger = MagicMock() mock_logger_class.return_value = mock_logger - + task = self._create_task() - + predictions = ["pred1", "pred2"] model_pred_strs = None references = ["ref1", "ref2"] @@ -2072,9 +2081,9 @@ def test_format_details_details_only(self, mock_logger_class): {"origin_prompt": "prompt1", "prediction": "pred1"}, {"origin_prompt": "prompt2", "prediction": "pred2"} ] - + result = task.format_details(predictions, model_pred_strs, references, details, model_details, pred_dicts) - + # 验证返回结果 self.assertIsNotNone(result) self.assertEqual(result['type'], 'GEN') From 9284ee02c282e4b9a8d0b32f066b26111de28e01 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Tue, 3 Mar 2026 21:02:42 +0800 Subject: [PATCH 50/59] fix ut --- tests/UT/cli/test_workers.py | 168 +++++++++--------- .../test_bfcl_v3_output_handler.py | 7 +- .../test_gen_inferencer_output_handler.py | 71 ++++++-- .../test_ppl_inferencer_output_handler.py | 7 +- 4 files changed, 150 insertions(+), 103 deletions(-) diff --git a/tests/UT/cli/test_workers.py b/tests/UT/cli/test_workers.py index c1e4bc62..b8f3c8b1 100644 --- a/tests/UT/cli/test_workers.py +++ b/tests/UT/cli/test_workers.py @@ -29,18 +29,18 @@ def __init__(self, *args, **kwargs): self[key] = MockConfigDict(value) elif isinstance(value, list): self[key] = [MockConfigDict(item) if isinstance(item, dict) else item for item in value] - + def __getattr__(self, name): if name in self: return self[name] raise AttributeError(f"'MockConfigDict' object has no attribute '{name}'") - + def __setattr__(self, name, value): if isinstance(value, dict): self[name] = MockConfigDict(value) else: self[name] = value - + def merge_from_dict(self, data): for key, value in data.items(): if isinstance(value, dict) and key in self and isinstance(self[key], dict): @@ -54,7 +54,7 @@ def merge_from_dict(self, data): self[key] = MockConfigDict(value) else: self[key] = value - + def get(self, key, default=None): return super().get(key, default) @@ -69,7 +69,7 @@ def update_cfg(self, cfg): pass def do_work(self, cfg): pass - + worker = ConcreteWorker(mock_args) assert worker.args == mock_args @@ -89,7 +89,7 @@ def test_update_cfg_service_model(self, mock_fill_model_path, mock_get_config_ty """测试update_cfg方法,使用service模型""" # 设置mock返回值 mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLApiInferTask', 'MockLocalRunner'] - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'models': [{'attr': 'service', 'abbr': 'test_model'}], @@ -104,11 +104,11 @@ def test_update_cfg_service_model(self, mock_fill_model_path, mock_get_config_ty 'work_dir': '/test/workdir', 'cli_args': MagicMock(debug=False) }) - + # 执行测试 with patch('os.path.join', return_value='/test/workdir/predictions/'): result = self.infer_worker.update_cfg(cfg) - + # 验证结果 assert result == cfg assert cfg['infer']['partitioner']['type'] == 'MockNaivePartitioner' @@ -121,7 +121,7 @@ def test_update_cfg_service_model(self, mock_fill_model_path, mock_get_config_ty # 注意:在Infer.update_cfg中,prompt_template和ice_template不会被设置到retriever中 # 它们是在_fill_dataset_configs中设置的,而_fill_dataset_configs是在ConfigManager.load_config中调用的 # 所以这里不应该验证这些字段 - + # 注意:fill_model_path_if_datasets_need是在_fill_dataset_configs中调用的,不是在update_cfg中 # 所以这里不应该验证它被调用 @@ -131,7 +131,7 @@ def test_update_cfg_local_model(self, mock_fill_model_path, mock_get_config_type """测试update_cfg方法,使用local模型""" # 设置mock返回值 mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLInferTask', 'MockLocalRunner'] - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'models': [{'attr': 'local', 'abbr': 'test_model'}], @@ -144,11 +144,11 @@ def test_update_cfg_local_model(self, mock_fill_model_path, mock_get_config_type 'work_dir': '/test/workdir', 'cli_args': MagicMock(debug=True) }) - + # 执行测试 with patch('os.path.join', return_value='/test/workdir/predictions/'): self.infer_worker.update_cfg(cfg) - + # 验证结果 assert cfg['infer']['runner']['task']['type'] == 'MockOpenICLInferTask' assert cfg['infer']['runner']['debug'] == True # 应该从cli_args获取debug值 @@ -163,10 +163,10 @@ def test_do_work_no_merge(self, mock_logger, mock_runners, mock_partitioners): mock_partitioners.build.return_value = mock_partitioner mock_tasks = [MagicMock()] mock_partitioner.return_value = mock_tasks - + mock_runner = MagicMock() mock_runners.build.return_value = mock_runner - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'infer': { @@ -175,19 +175,19 @@ def test_do_work_no_merge(self, mock_logger, mock_runners, mock_partitioners): }, 'cli_args': MagicMock(merge_ds=False, mode='infer') }) - + # 模拟_update_tasks_cfg方法 with patch.object(self.infer_worker, '_update_tasks_cfg') as mock_update_tasks_cfg: # 执行测试 self.infer_worker.do_work(cfg) - + # 验证结果 mock_partitioners.build.assert_called_once_with(cfg['infer']['partitioner']) mock_partitioner.assert_called_once_with(cfg) mock_runners.build.assert_called_once_with(cfg['infer']['runner']) mock_runner.assert_called_once_with(mock_tasks) mock_update_tasks_cfg.assert_called_once_with(mock_tasks, cfg) - + # 验证正确的日志调用 logs_called = [call for call in mock_logger.info.call_args_list] assert call("Starting inference tasks...") in logs_called @@ -201,7 +201,7 @@ def test_do_work_merge_datasets(self, mock_logger, mock_runners, mock_partitione # 设置mock对象 mock_partitioner = MagicMock() mock_partitioners.build.return_value = mock_partitioner - + # 创建模拟任务 task1 = { 'models': [{'abbr': 'model1'}], @@ -213,10 +213,10 @@ def test_do_work_merge_datasets(self, mock_logger, mock_runners, mock_partitione } mock_tasks = [task1, task2] mock_partitioner.return_value = mock_tasks - + mock_runner = MagicMock() mock_runners.build.return_value = mock_runner - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'infer': { @@ -225,12 +225,12 @@ def test_do_work_merge_datasets(self, mock_logger, mock_runners, mock_partitione }, 'cli_args': MagicMock(merge_ds=True, mode='infer') }) - + # 模拟_update_tasks_cfg方法 with patch.object(self.infer_worker, '_update_tasks_cfg'): # 执行测试 self.infer_worker.do_work(cfg) - + # 验证结果 logs_called = [call for call in mock_logger.info.call_args_list] assert call("Merging datasets with the same model and inferencer...") in logs_called @@ -247,10 +247,10 @@ def test_do_work_perf_mode(self, mock_logger, mock_runners, mock_partitioners): mock_partitioners.build.return_value = mock_partitioner mock_tasks = [MagicMock()] mock_partitioner.return_value = mock_tasks - + mock_runner = MagicMock() mock_runners.build.return_value = mock_runner - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'infer': { @@ -259,14 +259,14 @@ def test_do_work_perf_mode(self, mock_logger, mock_runners, mock_partitioners): }, 'cli_args': MagicMock(merge_ds=False, mode='perf') }) - + # 模拟_update_tasks_cfg方法 with patch.object(self.infer_worker, '_update_tasks_cfg'): # 执行测试 with patch.object(self.infer_worker, '_merge_datasets') as mock_merge: mock_merge.return_value = mock_tasks self.infer_worker.do_work(cfg) - + # 验证_merge_datasets被调用 mock_merge.assert_called_once_with(mock_tasks) @@ -285,10 +285,10 @@ def test_merge_datasets(self): 'models': [{'abbr': 'model2'}], 'datasets': [[{'type': 'dataset_type', 'infer_cfg': {'inferencer': 'inferencer_type'}}]] } - + # 执行测试 result = self.infer_worker._merge_datasets([task1, task2, task3]) - + # 验证结果 assert len(result) == 2 # 应该合并为2个任务 # 第一个任务应该包含合并后的数据集 @@ -302,13 +302,13 @@ def test_update_tasks_cfg_with_attack(self): task = MagicMock() task.datasets = [[MagicMock(abbr='test_dataset')]] tasks = [task] - + cfg = MagicMock() cfg.attack = MagicMock() - + # 执行测试 self.infer_worker._update_tasks_cfg(tasks, cfg) - + # 验证结果 assert cfg.attack.dataset == 'test_dataset' assert task.attack == cfg.attack @@ -318,12 +318,12 @@ def test_update_tasks_cfg_without_attack(self): # 创建测试数据 task = MagicMock() tasks = [task] - + cfg = MagicMock() # 删除attack属性 if hasattr(cfg, 'attack'): delattr(cfg, 'attack') - + # 执行测试 - 不应抛出异常 self.infer_worker._update_tasks_cfg(tasks, cfg) @@ -342,24 +342,24 @@ def test_update_cfg(self, mock_get_config_type): """测试update_cfg方法""" # 设置mock返回值 mock_get_config_type.side_effect = ['MockNaivePartitioner', 'MockOpenICLEvalTask', 'MockLocalRunner'] - + # 创建测试配置 - 使用MockConfigDict cli_args = MagicMock() cli_args.dump_eval_details = True cli_args.dump_extract_rate = True cli_args.debug = True - + cfg = MockConfigDict({ 'models': [{'abbr': 'test_model'}], 'datasets': [{'abbr': 'test_dataset'}], 'work_dir': '/test/workdir', 'cli_args': cli_args }) - + # 执行测试 with patch('os.path.join', return_value='/test/workdir/results/'): result = self.eval_worker.update_cfg(cfg) - + # 验证结果 assert result == cfg assert cfg['eval']['partitioner']['type'] == 'MockNaivePartitioner' @@ -371,7 +371,7 @@ def test_update_cfg(self, mock_get_config_type): assert cfg['eval']['runner']['task']['dump_details'] == True assert cfg['eval']['runner']['task']['cal_extract_rate'] == True assert cfg['eval']['partitioner']['out_dir'] == '/test/workdir/results/' - + # 注意:fill_model_path_if_datasets_need是在_fill_dataset_configs中调用的,不是在Eval.update_cfg中 # 所以这里不应该验证它被调用 @@ -385,23 +385,25 @@ def test_do_work_normal_tasks(self, mock_logger, mock_runners, mock_partitioners mock_partitioners.build.return_value = mock_partitioner mock_tasks = [MagicMock()] mock_partitioner.return_value = mock_tasks - + mock_runner = MagicMock() mock_runners.build.return_value = mock_runner - + # 创建测试配置 - 使用MockConfigDict + # 添加datasets字段以支持cfg.datasets访问 cfg = MockConfigDict({ 'eval': { 'partitioner': {}, 'runner': {} - } + }, + 'datasets': [] }) - + # 模拟_update_tasks_cfg方法 with patch.object(self.eval_worker, '_update_tasks_cfg'): # 执行测试 self.eval_worker.do_work(cfg) - + # 验证结果 mock_partitioners.build.assert_called_once_with(cfg['eval']['partitioner']) mock_partitioner.assert_called_once_with(cfg) @@ -421,23 +423,25 @@ def test_do_work_nested_tasks(self, mock_logger, mock_runners, mock_partitioners mock_task_part2 = [MagicMock()] mock_tasks = [mock_task_part1, mock_task_part2] mock_partitioner.return_value = mock_tasks - + mock_runner = MagicMock() mock_runners.build.return_value = mock_runner - + # 创建测试配置 - 使用MockConfigDict + # 添加datasets字段以支持cfg.datasets访问 cfg = MockConfigDict({ 'eval': { 'partitioner': {}, 'runner': {} - } + }, + 'datasets': [] }) - + # 模拟_update_tasks_cfg方法 with patch.object(self.eval_worker, '_update_tasks_cfg'): # 执行测试 self.eval_worker.do_work(cfg) - + # 验证结果 - runner应该被调用两次,分别处理每个任务部分 assert mock_runner.call_count == 2 mock_runner.assert_any_call(mock_task_part1) @@ -461,13 +465,13 @@ def test_update_cfg_no_summarizer(self, mock_get_config_type): """测试update_cfg方法,没有summarizer配置的情况""" # 设置mock返回值 mock_get_config_type.return_value = 'MockDefaultSummarizer' - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({}) - + # 执行测试 result = self.acc_viz_worker.update_cfg(cfg) - + # 验证结果 assert result == cfg assert cfg['summarizer']['type'] == 'MockDefaultSummarizer' @@ -478,17 +482,17 @@ def test_update_cfg_with_attr(self, mock_get_config_type): """测试update_cfg方法,summarizer有attr属性的情况""" # 设置mock返回值 mock_get_config_type.return_value = 'MockDefaultSummarizer' - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'summarizer': { 'attr': 'accuracy' } }) - + # 执行测试 self.acc_viz_worker.update_cfg(cfg) - + # 验证结果 assert 'attr' not in cfg['summarizer'] @@ -499,15 +503,17 @@ def test_do_work_normal(self, mock_build_from_cfg, mock_logger): # 设置mock对象 mock_summarizer = MagicMock() mock_build_from_cfg.return_value = mock_summarizer - + # 创建测试配置 - 使用MockConfigDict + # 添加datasets字段以支持cfg.datasets访问 cfg = MockConfigDict({ - 'summarizer': {} + 'summarizer': {}, + 'datasets': [] }) - + # 执行测试 self.acc_viz_worker.do_work(cfg) - + # 验证结果 mock_build_from_cfg.assert_called_once_with({'config': cfg}) mock_summarizer.summarize.assert_called_once_with(time_str='20240101_120000') @@ -523,7 +529,7 @@ def test_do_work_subjective(self, mock_build_from_cfg, mock_logger): mock_summarizer3 = MagicMock() # 使用列表而不是生成器,避免StopIteration错误 mock_build_from_cfg.side_effect = [mock_summarizer1, mock_summarizer1, mock_summarizer3] - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'summarizer': { @@ -535,10 +541,10 @@ def test_do_work_subjective(self, mock_build_from_cfg, mock_logger): {'abbr': 'dataset2_1', 'summarizer': {'type': 'summarizer_type2'}} ] }) - + # 执行测试 self.acc_viz_worker.do_work(cfg) - + # 验证结果 - 应该构建多个摘要器 assert mock_build_from_cfg.call_count == 3 # 验证主摘要器被调用时传入了主观分数 @@ -558,7 +564,7 @@ def test_update_cfg_complete(self, mock_get_config_type): """测试update_cfg方法,完整配置的情况""" # 设置mock返回值 mock_get_config_type.side_effect = ['MockDefaultPerfSummarizer', 'MockDefaultPerfMetricCalculator'] - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'summarizer': { @@ -568,10 +574,10 @@ def test_update_cfg_complete(self, mock_get_config_type): 'prompt_db': 'db_path' } }) - + # 执行测试 result = self.perf_viz_worker.update_cfg(cfg) - + # 验证结果 assert result == cfg assert cfg['summarizer']['type'] == 'MockDefaultPerfSummarizer' @@ -586,13 +592,13 @@ def test_update_cfg_minimal(self, mock_get_config_type): """测试update_cfg方法,最小配置的情况""" # 设置mock返回值 mock_get_config_type.side_effect = ['MockDefaultPerfSummarizer', 'MockDefaultPerfMetricCalculator'] - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({}) - + # 执行测试 self.perf_viz_worker.update_cfg(cfg) - + # 验证结果 assert cfg['summarizer']['type'] == 'MockDefaultPerfSummarizer' assert cfg['summarizer']['calculator']['type'] == 'MockDefaultPerfMetricCalculator' @@ -604,15 +610,15 @@ def test_do_work(self, mock_build_from_cfg, mock_logger): # 设置mock对象 mock_summarizer = MagicMock() mock_build_from_cfg.return_value = mock_summarizer - + # 创建测试配置 - 使用MockConfigDict cfg = MockConfigDict({ 'summarizer': {} }) - + # 执行测试 self.perf_viz_worker.do_work(cfg) - + # 验证结果 mock_build_from_cfg.assert_called_once_with({'config': cfg}) mock_summarizer.summarize.assert_called_once() @@ -625,25 +631,25 @@ def test_init(self): mock_cfg = MagicMock() mock_workflow = [MagicMock(), MagicMock()] executor = WorkFlowExecutor(mock_cfg, mock_workflow) - + assert executor.cfg == mock_cfg assert executor.workflow == mock_workflow def test_execute(self): """测试execute方法""" mock_cfg = MagicMock() - + # 创建两个mock worker mock_worker1 = MagicMock() mock_worker2 = MagicMock() mock_workflow = [mock_worker1, mock_worker2] - + # 创建执行器 executor = WorkFlowExecutor(mock_cfg, mock_workflow) - + # 执行测试 executor.execute() - + # 验证结果 - 每个worker的do_work方法都应该被调用 mock_worker1.do_work.assert_called_once_with(mock_cfg) mock_worker2.do_work.assert_called_once_with(mock_cfg) @@ -658,20 +664,20 @@ def test_work_flow_dict(): assert 'viz' in WORK_FLOW assert 'perf' in WORK_FLOW assert 'perf_viz' in WORK_FLOW - + # 验证工作流内容正确 assert Infer in WORK_FLOW['all'] assert Eval in WORK_FLOW['all'] assert AccViz in WORK_FLOW['all'] - + assert Infer in WORK_FLOW['infer'] - + assert Eval in WORK_FLOW['eval'] assert AccViz in WORK_FLOW['eval'] - + assert AccViz in WORK_FLOW['viz'] - + assert Infer in WORK_FLOW['perf'] assert PerfViz in WORK_FLOW['perf'] - + assert PerfViz in WORK_FLOW['perf_viz'] \ No newline at end of file diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_bfcl_v3_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_bfcl_v3_output_handler.py index dad9153e..b55637fb 100644 --- a/tests/UT/openicl/icl_inferencer/output_handler/test_bfcl_v3_output_handler.py +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_bfcl_v3_output_handler.py @@ -209,7 +209,7 @@ def test_get_result_with_function_call_output(self): output.tool_calls = [{"function": "test", "arguments": {}}] output.inference_log = [] - result = handler.get_result(conn, "test_input", output, "test_gold") + result = handler.get_result(conn, "data_abbr", "test_input", output, "test_gold") self.assertEqual(result["success"], True) self.assertEqual(result["uuid"], "test_uuid_integration") @@ -230,7 +230,7 @@ def test_get_result_with_failed_function_call_output(self): output.inference_log = [] output.error_info = "Integration test error" - result = handler.get_result(conn, "test_input", output, "test_gold") + result = handler.get_result(conn, "data_abbr", "test_input", output, "test_gold") self.assertEqual(result["success"], False) self.assertIn("error_info", result) @@ -263,7 +263,7 @@ def test_get_result_perf_mode(self): "throughput": 100 }) - result = handler.get_result(conn, "test_input", output, "test_gold") + result = handler.get_result(conn, "data_abbr", "test_input", output, "test_gold") # In perf_mode, get_result should call get_metrics self.assertIn("latency", result) @@ -274,4 +274,3 @@ def test_get_result_perf_mode(self): if __name__ == '__main__': unittest.main() - diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_gen_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_gen_inferencer_output_handler.py index 77678565..b50b5173 100644 --- a/tests/UT/openicl/icl_inferencer/output_handler/test_gen_inferencer_output_handler.py +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_gen_inferencer_output_handler.py @@ -33,7 +33,7 @@ def test_get_result_perf_mode(self): handler._extract_and_write_arrays = mock.Mock(return_value={"latency": 0.1, "throughput": 100}) - result = handler.get_result(conn, "input", output, "gold") + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertIn("latency", result) handler._extract_and_write_arrays.assert_called_once() @@ -51,10 +51,10 @@ def test_get_result_accuracy_mode_string_output(self): # Test with string output - should work now with UUID generation string_output = "predicted_text" - result = handler.get_result(conn, "input", string_output, "gold") + result = handler.get_result(conn, "data_abbr", "input", string_output, "gold") self.assertEqual(result["success"], True) self.assertIn("uuid", result) - # UUID should be generated (8 characters) + # UUID should be generated (8 characters from uuid.uuid4().hex[:8]) self.assertEqual(len(result["uuid"]), 8) self.assertEqual(result["prediction"], "predicted_text") self.assertEqual(result["origin_prompt"], "input") @@ -72,7 +72,16 @@ def test_get_result_accuracy_mode_output_object(self): output.uuid = "test_uuid" output.get_prediction = mock.Mock(return_value="predicted_text") - result = handler.get_result(conn, "input", output, "gold") + # Mock get_prediction_result to return expected result + handler.get_prediction_result = mock.Mock(return_value={ + "success": True, + "uuid": "test_uuid", + "prediction": "predicted_text", + "origin_prompt": "input", + "gold": "gold" + }) + + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertEqual(result["success"], True) self.assertEqual(result["uuid"], "test_uuid") self.assertEqual(result["prediction"], "predicted_text") @@ -92,7 +101,16 @@ def test_get_result_with_failure(self): output.error_info = "Test error" output.get_prediction = mock.Mock(return_value="") - result = handler.get_result(conn, "input", output, "gold") + # Mock get_prediction_result to return failed result + handler.get_prediction_result = mock.Mock(return_value={ + "success": False, + "uuid": "test_uuid", + "prediction": "", + "origin_prompt": "input", + "gold": "gold" + }) + + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertEqual(result["success"], False) self.assertIn("error_info", result) self.assertEqual(result["error_info"], "Test error") @@ -111,7 +129,16 @@ def test_get_result_with_failure_no_error_info(self): # No error_info attribute output.get_prediction = mock.Mock(return_value="") - result = handler.get_result(conn, "input", output, "gold") + # Mock get_prediction_result to return failed result + handler.get_prediction_result = mock.Mock(return_value={ + "success": False, + "uuid": "test_uuid", + "prediction": "", + "origin_prompt": "input", + "gold": "gold" + }) + + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertEqual(result["success"], False) self.assertFalse(handler.all_success) @@ -127,7 +154,15 @@ def test_get_result_without_gold(self): output.uuid = "test_uuid" output.get_prediction = mock.Mock(return_value="predicted_text") - result = handler.get_result(conn, "input", output, None) + # Mock get_prediction_result to return result without gold + handler.get_prediction_result = mock.Mock(return_value={ + "success": True, + "uuid": "test_uuid", + "prediction": "predicted_text", + "origin_prompt": "input" + }) + + result = handler.get_result(conn, "data_abbr", "input", output, None) self.assertNotIn("gold", result) conn.close() @@ -138,7 +173,7 @@ def test_get_result_string_output_without_gold(self): conn = sqlite3.connect(":memory:") string_output = "predicted_text" - result = handler.get_result(conn, "input", string_output, None) + result = handler.get_result(conn, "data_abbr", "input", string_output, None) self.assertEqual(result["success"], True) self.assertIn("uuid", result) @@ -155,8 +190,8 @@ def test_get_result_string_output_uuid_uniqueness(self): conn = sqlite3.connect(":memory:") string_output = "predicted_text" - result1 = handler.get_result(conn, "input1", string_output, "gold1") - result2 = handler.get_result(conn, "input2", string_output, "gold2") + result1 = handler.get_result(conn, "data_abbr", "input1", string_output, "gold1") + result2 = handler.get_result(conn, "data_abbr", "input2", string_output, "gold2") # UUIDs should be different for different calls self.assertNotEqual(result1["uuid"], result2["uuid"]) @@ -178,7 +213,7 @@ def test_get_result_perf_mode_with_output_object(self): handler._extract_and_write_arrays = mock.Mock(return_value={"latency": 0.1, "throughput": 100}) - result = handler.get_result(conn, "input", output, "gold") + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertIn("latency", result) handler._extract_and_write_arrays.assert_called_once() @@ -199,7 +234,7 @@ def test_get_result_perf_mode_with_string_output(self): # and now properly generates UUID for string output string_output = "predicted_text" - result = handler.get_result(conn, "input", string_output, "gold") + result = handler.get_result(conn, "data_abbr", "input", string_output, "gold") self.assertEqual(result["success"], True) self.assertIn("uuid", result) # UUID should be generated (8 characters) @@ -224,7 +259,16 @@ def test_get_result_accuracy_mode_failure_no_error_info(self): if hasattr(output, "error_info"): delattr(output, "error_info") - result = handler.get_result(conn, "input", output, "gold") + # Mock get_prediction_result to return failed result + handler.get_prediction_result = mock.Mock(return_value={ + "success": False, + "uuid": "test_uuid", + "prediction": "", + "origin_prompt": "input", + "gold": "gold" + }) + + result = handler.get_result(conn, "data_abbr", "input", output, "gold") self.assertEqual(result["success"], False) self.assertFalse(handler.all_success) # Should not have error_info when it doesn't exist @@ -235,4 +279,3 @@ def test_get_result_accuracy_mode_failure_no_error_info(self): if __name__ == '__main__': unittest.main() - diff --git a/tests/UT/openicl/icl_inferencer/output_handler/test_ppl_inferencer_output_handler.py b/tests/UT/openicl/icl_inferencer/output_handler/test_ppl_inferencer_output_handler.py index f4525d59..316ae935 100644 --- a/tests/UT/openicl/icl_inferencer/output_handler/test_ppl_inferencer_output_handler.py +++ b/tests/UT/openicl/icl_inferencer/output_handler/test_ppl_inferencer_output_handler.py @@ -170,7 +170,7 @@ def test_get_result_inherited_behavior(self): output.origin_prompt_logprobs = [] output.get_prediction = mock.Mock(return_value="A") - result = handler.get_result(conn, "input", output, gold="A") + result = handler.get_result(conn, "data_abbr", "input", output, gold="A") self.assertTrue(result["success"]) self.assertEqual(result["uuid"], "test_uuid") @@ -195,7 +195,7 @@ def test_get_result_with_failure(self): output.error_info = "Test error" output.get_prediction = mock.Mock(return_value=None) - result = handler.get_result(conn, "input", output, gold="A") + result = handler.get_result(conn, "data_abbr", "input", output, gold="A") self.assertFalse(result["success"]) self.assertIn("error_info", result) @@ -216,7 +216,7 @@ def test_get_result_perf_mode(self): handler._extract_and_write_arrays = mock.Mock(return_value={"latency": 0.1, "throughput": 100}) - result = handler.get_result(conn, "input", output, gold="A") + result = handler.get_result(conn, "data_abbr", "input", output, gold="A") self.assertIn("latency", result) self.assertIn("throughput", result) @@ -227,4 +227,3 @@ def test_get_result_perf_mode(self): if __name__ == '__main__': unittest.main() - From 01c16a593b0ef2544e82d7e4cef96ae032d21828 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 09:08:39 +0800 Subject: [PATCH 51/59] delete unused dataset config --- .../configs/datasets/gedit/gedit_gen.py | 44 ------------------- 1 file changed, 44 deletions(-) delete mode 100644 ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py deleted file mode 100644 index 57509dee..00000000 --- a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py +++ /dev/null @@ -1,44 +0,0 @@ -from ais_bench.benchmark.openicl.icl_prompt_template.icl_prompt_template_mm import MMPromptTemplate -from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever -from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer -from ais_bench.benchmark.datasets.g_edit import GEditDataset, GEditEvaluator - - -gedit_reader_cfg = dict( - input_columns=['question', 'image'], - output_column='task_type' -) - - -gedit_infer_cfg = dict( - prompt_template=dict( - type=MMPromptTemplate, - template=dict( - round=[ - dict(role="HUMAN", prompt_mm={ - "text": {"type": "text", "text": "{question}"}, - "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, - }) - ] - ) - ), - retriever=dict(type=ZeroRetriever), - inferencer=dict(type=LMMGenInferencer) -) - -gedit_eval_cfg = dict( - evaluator=dict(type=GEditEvaluator) -) - -gedit_datasets = [ - dict( - abbr='gedit', - type=GEditDataset, - path='ais_bench/datasets/GEdit-Bench', # 数据集路径,使用相对路径时相对于源码根路径,支持绝对路径 - split_count=1, - split_index=0, - reader_cfg=gedit_reader_cfg, - infer_cfg=gedit_infer_cfg, - eval_cfg=gedit_eval_cfg - ) -] From 0a5bc1fe0d0688fc0fa4750589a05b33bdeb6c44 Mon Sep 17 00:00:00 2001 From: Hanye <1037452625@qq.com> Date: Wed, 4 Mar 2026 14:17:46 +0800 Subject: [PATCH 52/59] Update ais_bench/benchmark/datasets/base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ais_bench/benchmark/datasets/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index d014ff63..ec12612a 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from pickle import DICT from typing import List, Dict, Optional, Union, Type from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm From d0b9df16d8d324eb45fd52b0329037a3d28ccaf4 Mon Sep 17 00:00:00 2001 From: Hanye <1037452625@qq.com> Date: Wed, 4 Mar 2026 14:18:57 +0800 Subject: [PATCH 53/59] Update ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../gen_inferencer_output_handler.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 705340c8..42619017 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -41,9 +41,21 @@ def get_prediction_result( Returns: dict: Prediction result """ - for item in input[0]['prompt']: - if item.get('image_url') and len(item['image_url']['url']) > 256: - item['image_url']['url'] = item['image_url']['url'][:256] + " ..." + if ( + isinstance(input, list) + and len(input) > 0 + and isinstance(input[0], dict) + and isinstance(input[0].get("prompt"), list) + ): + for item in input[0]["prompt"]: + if not isinstance(item, dict): + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): + continue + url = image_url.get("url") + if isinstance(url, str) and len(url) > 256: + image_url["url"] = url[:256] + " ..." result_data = { "success": ( output.success if isinstance(output, Output) else True From fefb51a3a19eb9f28d5d9adf190cd8dfe00e500d Mon Sep 17 00:00:00 2001 From: Hanye <1037452625@qq.com> Date: Wed, 4 Mar 2026 14:19:19 +0800 Subject: [PATCH 54/59] Update ais_bench/benchmark/datasets/utils/llm_judge.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ais_bench/benchmark/datasets/utils/llm_judge.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index 07c8e0df..78c3a190 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -28,9 +28,12 @@ def _load_from_predictions(self, prediction_path: str): Returns: Dataset: The merged dataset with predictions. """ - if os.path.exists(prediction_path): - preds = load_jsonl(prediction_path) - preds.sort(key=lambda x: x.get('id',0)) + if not os.path.exists(prediction_path): + logger.warning(f"Prediction file does not exist: {prediction_path}") + return [] + + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id', 0)) return preds From 8b91ac60d5acb8aeaf76ad8606b39f81d10b105e Mon Sep 17 00:00:00 2001 From: Hanye <1037452625@qq.com> Date: Wed, 4 Mar 2026 14:20:00 +0800 Subject: [PATCH 55/59] Update ais_bench/benchmark/utils/file/file.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ais_bench/benchmark/utils/file/file.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ais_bench/benchmark/utils/file/file.py b/ais_bench/benchmark/utils/file/file.py index 47f048dc..f15a26c4 100644 --- a/ais_bench/benchmark/utils/file/file.py +++ b/ais_bench/benchmark/utils/file/file.py @@ -241,9 +241,9 @@ def load_jsonl(path: str) -> List[dict]: """ preds = [] with open(path, "rb") as f: - mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - for line in iter(mm.readline, b""): - preds.append(orjson.loads(line)) + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + for line in iter(mm.readline, b""): + preds.append(orjson.loads(line)) return preds def dump_jsonl(data: List[dict], path: str): From a8454797fd991114bd522f9eb93e2216eab3ea68 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 15:20:00 +0800 Subject: [PATCH 56/59] fix --- ais_bench/benchmark/datasets/base.py | 92 +++++++------------ ais_bench/benchmark/datasets/g_edit.py | 3 +- .../benchmark/datasets/utils/lmm_judge.py | 3 +- .../output_handler/base_handler.py | 1 + .../gen_inferencer_output_handler.py | 6 +- .../lmm_gen_inferencer_output_handler.py | 20 +++- 6 files changed, 57 insertions(+), 68 deletions(-) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index ec12612a..3a98ce4c 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -14,6 +14,9 @@ disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination +JDG_DATASET_LOAD_BATCH_SIZE = 10 + + class BaseDataset: def __init__(self, @@ -138,75 +141,50 @@ def load(self, predictions_path: str, **kwargs): dataset_content = self.dataset_instance.dataset["test"] - # 加载被测模型的推理结果(排序后) predictions: list = self._load_from_predictions(predictions_path) - # 为数据集添加 model_answer 列 - batch_size = 10 # 批处理大小,可以根据实际情况调整 + if isinstance(dataset_content, Dataset): + datasets_to_process = [dataset_content] + elif isinstance(dataset_content, DatasetDict): + datasets_to_process = [dataset_content[key] for key in dataset_content] + else: + raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") + + return self._process_predictions(datasets_to_process, predictions) + + def _process_predictions(self, datasets_to_process: List[Dataset], predictions: list) -> Dataset: dataset_batches = [] current_batch = [] - if isinstance(dataset_content, Dataset): - with ThreadPoolExecutor() as executor: - futures = [] + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for dataset in datasets_to_process: for item in predictions: - future = executor.submit(self._process_single_item, dataset_content, item) + future = executor.submit(self._process_single_item, dataset, item) futures.append(future) - with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: - for i, future in enumerate(as_completed(futures)): - result = future.result() - current_batch.append(result) - - # 当批次达到指定大小时,转换为Dataset并添加到批次列表 - if len(current_batch) >= batch_size: - dataset_batches.append(Dataset.from_list(current_batch)) - current_batch = [] - - pbar.update(1) - self.update_task_state( - { - "total_count": len(futures), - "progress_description": "Processing predictions", - "finish_count": i + 1, - } - ) - pbar.refresh() - elif isinstance(dataset_content, DatasetDict): - with ThreadPoolExecutor() as executor: - futures = [] - for key in dataset_content: - for item in predictions: - future = executor.submit(self._process_single_item, dataset_content[key], item) - futures.append(future) - - with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: - for i, future in enumerate(as_completed(futures)): - result = future.result() - current_batch.append(result) - - # 当批次达到指定大小时,转换为Dataset并添加到批次列表 - if len(current_batch) >= batch_size: - dataset_batches.append(Dataset.from_list(current_batch)) - current_batch = [] - - pbar.update(1) - self.update_task_state( - { - "total_count": len(futures), - "progress_description": "Processing predictions", - "finish_count": i + 1, - } - ) - pbar.refresh() - else: - raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") + with tqdm(total=len(futures), desc="Processing predictions", unit="item") as pbar: + for i, future in enumerate(as_completed(futures)): + result = future.result() + current_batch.append(result) + + if len(current_batch) >= JDG_DATASET_LOAD_BATCH_SIZE: + dataset_batches.append(Dataset.from_list(current_batch)) + current_batch = [] + + pbar.update(1) + self.update_task_state( + { + "total_count": len(futures), + "progress_description": "Processing predictions", + "finish_count": i + 1, + } + ) + pbar.refresh() - # 处理最后一个不完整的批次 if current_batch: dataset_batches.append(Dataset.from_list(current_batch)) - # 合并所有批次的Dataset if dataset_batches: if len(dataset_batches) == 1: return dataset_batches[0] diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index d23e63ec..0302994e 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -70,10 +70,9 @@ def process_example_to_dataset(example): data_dict = {key: [example[key]] for key in example.keys()} return Dataset.from_dict(data_dict) - max_workers = 4 # Adjust based on system resources processed_datasets = [None] * len(dataset) - with ThreadPoolExecutor(max_workers=max_workers) as executor: + with ThreadPoolExecutor(max_workers=8) as executor: # 提交所有任务 with tqdm(total=len(dataset), desc=f"Convert GEdit dataset to base64, split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar: futures = {} diff --git a/ais_bench/benchmark/datasets/utils/lmm_judge.py b/ais_bench/benchmark/datasets/utils/lmm_judge.py index f35ecf53..5380187a 100644 --- a/ais_bench/benchmark/datasets/utils/lmm_judge.py +++ b/ais_bench/benchmark/datasets/utils/lmm_judge.py @@ -72,8 +72,7 @@ def process_image(index, pred_item): return pred_item # 使用并行处理加速图片处理 - max_workers = min(8, os.cpu_count()) # 根据CPU核心数调整 - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: # 使用tqdm显示进度 processed_preds = list(tqdm( executor.map(lambda x: process_image(x[0], x[1]), enumerate(preds)), diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py index 9d2a450b..3f305541 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py @@ -22,6 +22,7 @@ DB_REF_KEY = "__db_ref__" DB_DATA_DIR = "db_data" +BASE64_MAX_DISPLAY_LEN = 256 class BaseInferencerOutputHandler: diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 42619017..5beeffcb 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -3,7 +3,7 @@ import sqlite3 import uuid -from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler, BASE64_MAX_DISPLAY_LEN from ais_bench.benchmark.models.output import Output from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError @@ -54,8 +54,8 @@ def get_prediction_result( if not isinstance(image_url, dict): continue url = image_url.get("url") - if isinstance(url, str) and len(url) > 256: - image_url["url"] = url[:256] + " ..." + if isinstance(url, str) and len(url) > BASE64_MAX_DISPLAY_LEN: + image_url["url"] = url[:BASE64_MAX_DISPLAY_LEN] + " ..." result_data = { "success": ( output.success if isinstance(output, Output) else True diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py index cdc4a21b..db31f349 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -3,7 +3,7 @@ import uuid from pathlib import Path -from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler, BASE64_MAX_DISPLAY_LEN from ais_bench.benchmark.models.output import LMMOutput from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError @@ -47,9 +47,21 @@ def get_prediction_result( save_dir = Path(self.output_path) / f"{data_abbr}_out_file" if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) - for item in input[0]['prompt']: - if item.get('image_url') and len(item['image_url']['url']) > 256: - item['image_url']['url'] = item['image_url']['url'][:256] + " ..." + if ( + isinstance(input, list) + and len(input) > 0 + and isinstance(input[0], dict) + and isinstance(input[0].get("prompt"), list) + ): + for item in input[0]["prompt"]: + if not isinstance(item, dict): + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): + continue + url = image_url.get("url") + if isinstance(url, str) and len(url) > BASE64_MAX_DISPLAY_LEN: + image_url["url"] = url[:BASE64_MAX_DISPLAY_LEN] + " ..." result_data = { "success": ( output.success if isinstance(output, LMMOutput) else True From fb442eb8589db00aee7c9236f8b7bcef1542a66a Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 15:29:44 +0800 Subject: [PATCH 57/59] fix review --- .../configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py index 75081615..4a156775 100644 --- a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen_0_shot_llmjudge.py @@ -97,8 +97,8 @@ use_timestamp=False, retry=2, api_key="", - host_ip="192.168.9.123", - host_port=5103, + host_ip="localhost", + host_port=8080, url="", max_out_len=512, batch_size=16, From f886815785c9f5903608ab41b1d5fe0af11f38ae Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 15:36:56 +0800 Subject: [PATCH 58/59] fix review --- tests/pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pytest.ini b/tests/pytest.ini index c009743e..9d919f5d 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -46,6 +46,7 @@ omit = */datasets/omnidocbench/* */ais_bench/benchmark/datasets/humanevalx/* */ais_bench/benchmark/datasets/livecodebench/testing_util.py + */ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py [coverage:report] # 覆盖率报告配置 From d80b4280345afaa5dd4571c2055f142da73227ac Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 15:51:44 +0800 Subject: [PATCH 59/59] fix review --- ais_bench/benchmark/datasets/g_edit.py | 2 +- .../local_models/qwen_image_edit_mindie_sd.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py index 0302994e..18300705 100644 --- a/ais_bench/benchmark/datasets/g_edit.py +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -13,7 +13,7 @@ from ais_bench.benchmark.datasets.base import BaseDataset from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START -GEDIT_COUNT = 1 +GEDIT_COUNT = 1212 # total 1212 cases, could change for quick test class GEditEvaluator(BaseEvaluator): def score(self, predictions, references): diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py index 505d6346..60eeb2c3 100644 --- a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -14,7 +14,8 @@ from ais_bench.benchmark.registry import MODELS from ais_bench.benchmark.utils.prompt import PromptList from ais_bench.benchmark.utils.logging import AISLogger -from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError +from ais_bench.benchmark.utils.logging.error_codes import MODEL_CODES from ais_bench.benchmark.models.local_models.huggingface_above_v4_33 import (_convert_chat_messages, _get_meta_template, ) @@ -131,6 +132,13 @@ def __init__(self, self.seed = infer_kwargs.get('seed', DEFAULT_SEED) self.num_images_per_prompt = infer_kwargs.get('num_images_per_prompt', DEFAULT_NUM_IMAGES_PER_PROMPT) self.quant_desc_path = infer_kwargs.get('quant_desc_path', DEFAULT_QUANT_DESC_PATH) + self.logger.info( + f"load model: {self.path}; torch_dtype: {self.torch_dtype}; " + f"device: {self.device}; device_id: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}; " + f"num_inference_steps: {self.num_inference_steps}; true_cfg_scale: {self.true_cfg_scale}; " + f"guidance_scale: {self.guidance_scale}; seed: {self.seed}; num_images_per_prompt: {self.num_images_per_prompt}; " + f"quant_desc_path: {self.quant_desc_path}" + ) # 加载模型 self._load_model() @@ -241,7 +249,7 @@ def _generate(self, input) -> List[Image]: # 如果没有图像输入,使用默认图像 if not images: - raise ValueError("QwenImageEditModel requires image input") + raise AISBenchRuntimeError(MODEL_CODES.UNKNOWN_ERROR, "QwenImageEditModel requires image input, but can't get image info from input.") # 执行推理 results = [] @@ -271,7 +279,7 @@ def _generate(self, input) -> List[Image]: torch.npu.synchronize() end_time = time.time() infer_time = end_time - start_time - self.logger.info(f"推理完成,耗时: {infer_time:.2f}秒") + self.logger.info(f"Current image finish generated, cost: {infer_time:.2f} second.") return output