From c09046eb7a1d0177e11fbf0eb20cf3f78d594c00 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 16:18:30 +0800 Subject: [PATCH 1/2] add cli models openicl tasks --- ais_bench/benchmark/cli/argument_parser.py | 2 +- ais_bench/benchmark/cli/config_manager.py | 3 +- ais_bench/benchmark/cli/workers.py | 176 ++++++++- .../benchmark/models/local_models/__init__.py | 0 .../benchmark/models/local_models/base.py | 22 +- .../local_models/qwen_image_edit_mindie_sd.py | 343 ++++++++++++++++++ ais_bench/benchmark/models/output.py | 64 +++- .../icl_inferencer/icl_lmm_gen_inferencer.py | 67 ++++ .../icl_inferencer/output_handler/__init__.py | 0 .../output_handler/base_handler.py | 17 +- .../output_handler/bfcl_v3_output_handler.py | 10 +- .../gen_inferencer_output_handler.py | 19 +- .../lmm_gen_inferencer_output_handler.py | 84 +++++ .../ppl_inferencer_output_handler.py | 20 +- .../icl_prompt_template_mm.py | 3 +- .../benchmark/tasks/openicl_api_infer.py | 3 +- ais_bench/benchmark/tasks/openicl_eval.py | 8 +- ais_bench/benchmark/tasks/openicl_infer.py | 2 +- 18 files changed, 807 insertions(+), 36 deletions(-) create mode 100644 ais_bench/benchmark/models/local_models/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py diff --git a/ais_bench/benchmark/cli/argument_parser.py b/ais_bench/benchmark/cli/argument_parser.py index 38d79851..38a02e7b 100644 --- a/ais_bench/benchmark/cli/argument_parser.py +++ b/ais_bench/benchmark/cli/argument_parser.py @@ -61,7 +61,7 @@ def _base_parser(self): help='Running mode. Choose "perf" for performance evaluation, "infer" to run inference only, ' '"eval" to evaluate existing inference results, or "viz" to visualize the results. ' 'The default mode is "all", which runs all steps.', - choices=['all', 'infer', 'eval', 'viz', 'perf', 'perf_viz'], + choices=['all', 'infer', 'eval', 'viz', 'perf', 'perf_viz', 'judge', 'infer_judge'], default='all', type=str ) diff --git a/ais_bench/benchmark/cli/config_manager.py b/ais_bench/benchmark/cli/config_manager.py index f7c3ab5f..dde76527 100644 --- a/ais_bench/benchmark/cli/config_manager.py +++ b/ais_bench/benchmark/cli/config_manager.py @@ -1,4 +1,3 @@ - import os import os.path as osp import tabulate @@ -104,7 +103,7 @@ def load_config(self, workflow): self._update_cfg_of_workflow(workflow) self._dump_and_reload_config() return self.cfg - + def _fill_dataset_configs(self): for dataset_cfg in self.cfg["datasets"]: fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index ca997164..4c28f174 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,5 +1,7 @@ +import os import os.path as osp import copy +import shutil from abc import ABC, abstractmethod from collections import defaultdict @@ -8,12 +10,15 @@ from ais_bench.benchmark.registry import PARTITIONERS, RUNNERS, build_from_cfg from ais_bench.benchmark.utils.config.run import get_config_type from ais_bench.benchmark.utils.logging.logger import AISLogger +from ais_bench.benchmark.utils.logging.exceptions import PredictionInvalidException +from ais_bench.benchmark.utils.logging.error_codes import TMAN_CODES from ais_bench.benchmark.partitioners import NaivePartitioner from ais_bench.benchmark.runners import LocalRunner from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need +from ais_bench.benchmark.utils.file.file import load_jsonl, dump_jsonl logger = AISLogger() @@ -108,6 +113,133 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task.attack = cfg.attack +class JudgeInfer(BaseWorker): + def update_cfg(self, cfg: ConfigDict) -> None: + def get_task_type() -> str: + if cfg["datasets"][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service": + return get_config_type(OpenICLApiInferTask) + else: + return get_config_type(OpenICLInferTask) + + new_cfg = dict( + judge_infer=dict( + partitioner=dict(type=get_config_type(NaivePartitioner)), + runner=dict( + max_num_workers=self.args.max_num_workers, + max_workers_per_gpu=self.args.max_workers_per_gpu, + debug=self.args.debug, + task=dict(type=get_task_type()), + type=get_config_type(LocalRunner), + ), + ), + ) + + cfg.merge_from_dict(new_cfg) + if cfg.cli_args.debug: + cfg.judge_infer.runner.debug = True + cfg.judge_infer.partitioner["out_dir"] = osp.join(cfg["work_dir"], "predictions/") + return cfg + + def do_work(self, cfg: ConfigDict): + partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner) + logger.info("Starting inference tasks...") + self._cfg_pre_process(cfg) + tasks = partitioner(cfg) + + # delete the tasks without judge_infer_cfg + new_tasks = [] + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + new_tasks.append(task) + tasks = new_tasks + if len(tasks) == 0: + return + + # update tasks cfg before run + self._update_tasks_cfg(tasks, cfg) + + if ( + cfg.get("cli_args", {}).get("merge_ds", False) + or cfg.get("cli_args", {}).get("mode") == "perf" # performance mode will enable merge datasets by default + ): + logger.info("Merging datasets with the same model and inferencer...") + tasks = self._merge_datasets(tasks) + + runner = RUNNERS.build(cfg.judge_infer.runner) + runner(tasks) + self._result_post_process(tasks, cfg) + logger.info("Inference tasks completed.") + + def _merge_datasets(self, tasks): + # merge datasets with the same model, dataset type and inferencer + task_groups = defaultdict(list) + for task in tasks: + key = ( + task["models"][0]["abbr"] # same model + + "_" + + str(task['datasets'][0][0]['type']) # same dataset type + + "_" + + str(task["datasets"][0][0]["infer_cfg"]["inferencer"]) # same inferencer with the same args + ) + task_groups[key].append(task) + new_tasks = [] + for key, task_group in task_groups.items(): + new_task = copy.deepcopy(task_group[0]) + if len(task_group) > 1: + for t in task_group[1:]: + new_task["datasets"][0].extend(t["datasets"][0]) + new_tasks.append(new_task) + return new_tasks + + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + self.org_dataset_abbrs = {} + def change_judge_dataset_abbr(item): + if item.get("judge_infer_cfg"): + org_dataset_abbr = item["abbr"] + new_dataset_abbr = f'{item["abbr"]}-{item["judge_infer_cfg"]["judge_model"]["abbr"]}' + item["abbr"] = new_dataset_abbr + self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + if cfg.get('model_dataset_combinations', None) is not None: + for item in cfg.model_dataset_combinations: + for dataset in item["datasets"]: + change_judge_dataset_abbr(dataset) + for dataset in cfg.datasets: + change_judge_dataset_abbr(dataset) + return cfg + + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): + # update parameters to correct sub cfg + if hasattr(cfg, "attack"): + for task in tasks: + cfg.attack.dataset = task.datasets[0][0].abbr + task.attack = cfg.attack + + # update judge cfgs to model cfgs and data + for task in tasks: + task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{self.org_dataset_abbrs[task["datasets"][0][0]["abbr"]]}.jsonl') + if not osp.exists(task["datasets"][0][0]["predictions_path"]): + raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") + model_abbr = task["models"][0]["abbr"] + task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") + task["models"][0]["abbr"] = model_abbr + task["datasets"][0][0]["type"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_dataset_type") + task["datasets"][0][0]["reader_cfg"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_reader_cfg") + task["datasets"][0][0]["infer_cfg"] = task["datasets"][0][0].pop("judge_infer_cfg") + + def _result_post_process(self, tasks, cfg: ConfigDict): + # Reconstruct the judge infer predictions to normal predictions format + for task in tasks: + model_org_prediction_path = task["datasets"][0][0]["predictions_path"] + model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)} + judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + judge_preds: list = load_jsonl(judge_org_prediction_path) + for i, pred in enumerate(judge_preds): + uuid = pred["gold"] + judge_preds[i]["id"] = model_preds[uuid]["id"] + os.remove(judge_org_prediction_path) + dump_jsonl(judge_preds, judge_org_prediction_path) + + class Eval(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: new_cfg = dict( @@ -136,9 +268,11 @@ def update_cfg(self, cfg: ConfigDict) -> None: def do_work(self, cfg: ConfigDict): partitioner = PARTITIONERS.build(cfg.eval.partitioner) logger.info("Starting evaluation tasks...") + self._cfg_pre_process(cfg) + tasks = partitioner(cfg) - # update tasks cfg before run + # Update tasks cfg before run self._update_tasks_cfg(tasks, cfg) runner = RUNNERS.build(cfg.eval.runner) @@ -150,9 +284,28 @@ def do_work(self, cfg: ConfigDict): runner(tasks) logger.info("Evaluation tasks completed.") + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + self.org_dataset_abbrs = {} + def change_eval_dataset_abbr(item): + if item.get("judge_infer_cfg"): + org_dataset_abbr = item["abbr"] + new_dataset_abbr = f'{item["abbr"]}-{item["judge_infer_cfg"]["judge_model"]["abbr"]}' + item["abbr"] = new_dataset_abbr + self.org_dataset_abbrs[new_dataset_abbr] = org_dataset_abbr + if cfg.get('model_dataset_combinations', None) is not None: + for item in cfg.model_dataset_combinations: + for dataset in item["datasets"]: + change_eval_dataset_abbr(dataset) + for dataset in cfg.datasets: + change_eval_dataset_abbr(dataset) + return cfg + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): - # update parameters to correct sub cfg - pass + # Replace default model config to judge model config + self.judge_result_paths = {} + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + task["datasets"][0][0].pop("judge_infer_cfg") class AccViz(BaseWorker): @@ -171,6 +324,7 @@ def update_cfg(self, cfg: ConfigDict) -> None: def do_work(self, cfg: ConfigDict) -> int: logger.info("Summarizing evaluation results...") summarizer_cfg = cfg.get("summarizer", {}) + cfg = self._cfg_pre_process(cfg) # For subjective summarizer if summarizer_cfg.get("function", None): @@ -203,6 +357,13 @@ def do_work(self, cfg: ConfigDict) -> int: summarizer = build_from_cfg(summarizer_cfg) summarizer.summarize(time_str=self.args.cfg_time_str) + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + for i, dataset in enumerate(cfg.datasets): + if dataset.get("judge_infer_cfg"): + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' + cfg.datasets[i].pop("judge_infer_cfg") + return cfg + class PerfViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: @@ -233,9 +394,11 @@ def do_work(self, cfg: ConfigDict) -> int: WORK_FLOW = dict( - all=[Infer, Eval, AccViz], + all=[Infer, JudgeInfer, Eval, AccViz], infer=[Infer], - eval=[Eval, AccViz], + judge=[JudgeInfer], + infer_judge=[Infer, JudgeInfer], + eval=[JudgeInfer, Eval, AccViz], viz=[AccViz], perf=[Infer, PerfViz], perf_viz=[PerfViz], @@ -249,4 +412,5 @@ def __init__(self, cfg, workflow) -> None: def execute(self) -> None: for worker in self.workflow: - worker.do_work(self.cfg) + cfg = copy.deepcopy(self.cfg) + worker.do_work(cfg) diff --git a/ais_bench/benchmark/models/local_models/__init__.py b/ais_bench/benchmark/models/local_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/base.py b/ais_bench/benchmark/models/local_models/base.py index b2766aab..74d43283 100644 --- a/ais_bench/benchmark/models/local_models/base.py +++ b/ais_bench/benchmark/models/local_models/base.py @@ -57,7 +57,7 @@ def __init__(self, self.is_synthetic = False @abstractmethod - def _generate(self, input, max_out_len: int) -> List[str]: + def generate(self, inputs, max_out_len: int) -> List[str]: """Generate result given a input. Args: @@ -133,17 +133,6 @@ def parse_template(self, prompt_template: PromptType, mode: str) -> str: """ return self.template_parser.parse_template(prompt_template, mode) - def generate_from_template(self, templates: List[PromptType], **kwargs): - """Generate completion from a list of templates. - - Args: - templates (List[PromptType]): A list of templates. - max_out_len (int): The maximum length of the output. - """ - inputs = self.parse_template(templates, mode='gen') - max_out_lens = kwargs.get("max_out_lens", [None] * len(templates)) - return self.generate(inputs, max_out_lens, **kwargs) - def get_token_len_from_template( self, templates: Union[PromptType, List[PromptType]], @@ -204,6 +193,15 @@ def sync_inputs(self, inputs: str) -> str: def to(self, device): self.model.to(device) +class BaseLMModel(BaseModel): + """Base class for language models. + """ + def generate(self, inputs, outputs, **kwargs) -> List[str]: + raise AISBenchNotImplementedError( + MODEL_CODES.UNKNOWN_ERROR, + f'{self.__class__.__name__} does not supported' + ' to be called in base classes') + class LMTemplateParser: """Intermidate prompt template parser, specifically for language models. diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py new file mode 100644 index 00000000..60eeb2c3 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -0,0 +1,343 @@ +# flake8: noqa +# yapf: disable +import os +import time +from typing import Dict, List, Optional, Union + +import torch +import torch_npu +import base64 +import io +from PIL import Image + +from ais_bench.benchmark.models.local_models.base import BaseLMModel +from ais_bench.benchmark.registry import MODELS +from ais_bench.benchmark.utils.prompt import PromptList +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError +from ais_bench.benchmark.utils.logging.error_codes import MODEL_CODES +from ais_bench.benchmark.models.local_models.huggingface_above_v4_33 import (_convert_chat_messages, + _get_meta_template, + ) + +# 解决 diffuser 0.35.1 torch2.1 报错 +def custom_op( + name, + fn=None, + /, + *, + mutates_args, + device_types=None, + schema=None, + tags=None, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +def register_fake( + op, + fn=None, + /, + *, + lib=None, + _stacklevel: int = 1, + allow_override: bool = False, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +if hasattr(torch, 'library'): + torch.library.custom_op = custom_op + torch.library.register_fake = register_fake + +# 导入 qwen_image_edit 相关模块 +try: + from ais_bench.third_party.mindie_sd.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel + from ais_bench.third_party.mindie_sd.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline + from mindiesd import CacheConfig, CacheAgent +except ImportError as e: + raise ImportError(f"请确保 qwenimage_edit 模块在 Python 路径中: {e}") + +PromptType = Union[PromptList, str] + +# 模型推理相关配置常量 +DEFAULT_MODEL_PATH = "/home/yanhe/models/Qwen-Image-Edit-2509/" +DEFAULT_TORCH_DTYPE = "bfloat16" +DEFAULT_DEVICE = "npu" +DEFAULT_DEVICE_ID = 0 +DEFAULT_NUM_INFERENCE_STEPS = 1 # 40 +DEFAULT_TRUE_CFG_SCALE = 4.0 +DEFAULT_GUIDANCE_SCALE = 1.0 +DEFAULT_SEED = 0 +DEFAULT_NUM_IMAGES_PER_PROMPT = 1 +DEFAULT_QUANT_DESC_PATH = None + +# 缓存配置开关 +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) + + +@MODELS.register_module() +class QwenImageEditModel(BaseLMModel): + """Model wrapper for Qwen-Image-Edit-2509 models. + + Args: + path (str): The path to the model. + model_kwargs (dict): Additional model arguments. + sample_kwargs (dict): Additional sampling arguments. + vision_kwargs (dict): Additional vision arguments. + meta_template (Optional[Dict]): The model's meta prompt template. + """ + + def __init__(self, + path: str = DEFAULT_MODEL_PATH, + device_kwargs: dict = dict(), + infer_kwargs: dict = dict(), + meta_template: Optional[Dict] = None, + **other_kwargs): + self.logger = AISLogger() + self.path = path + self.max_out_len = other_kwargs.get('max_out_len', None) + self.template_parser = _get_meta_template(meta_template) + + # 设备配置 + self.device = device_kwargs.get('device', DEFAULT_DEVICE) + #self.device_id = device_kwargs.get('device_id', DEFAULT_DEVICE_ID) + # 在这里声明环境变量 + self.logger.debug(f"device id from kwargs: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}") + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = f"{device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}" + self.device_id = DEFAULT_DEVICE_ID + self.device_str = f"{self.device}:{DEFAULT_DEVICE_ID}" + self.logger.debug(f"device_str: {self.device_str}; device_id: {self.device_id}") + self.logger.debug(f"ASCEND_RT_VISIBLE_DEVICES: {os.getenv('ASCEND_RT_VISIBLE_DEVICES')}") + + # 模型配置 + self.torch_dtype = other_kwargs.get('torch_dtype', DEFAULT_TORCH_DTYPE) + self.torch_dtype = torch.bfloat16 if self.torch_dtype == "bfloat16" else torch.float32 + + # 推理配置 + self.num_inference_steps = infer_kwargs.get('num_inference_steps', DEFAULT_NUM_INFERENCE_STEPS) + self.true_cfg_scale = infer_kwargs.get('true_cfg_scale', DEFAULT_TRUE_CFG_SCALE) + self.guidance_scale = infer_kwargs.get('guidance_scale', DEFAULT_GUIDANCE_SCALE) + self.seed = infer_kwargs.get('seed', DEFAULT_SEED) + self.num_images_per_prompt = infer_kwargs.get('num_images_per_prompt', DEFAULT_NUM_IMAGES_PER_PROMPT) + self.quant_desc_path = infer_kwargs.get('quant_desc_path', DEFAULT_QUANT_DESC_PATH) + self.logger.info( + f"load model: {self.path}; torch_dtype: {self.torch_dtype}; " + f"device: {self.device}; device_id: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}; " + f"num_inference_steps: {self.num_inference_steps}; true_cfg_scale: {self.true_cfg_scale}; " + f"guidance_scale: {self.guidance_scale}; seed: {self.seed}; num_images_per_prompt: {self.num_images_per_prompt}; " + f"quant_desc_path: {self.quant_desc_path}" + ) + + # 加载模型 + self._load_model() + + # 缓存配置 + if COND_CACHE or UNCOND_CACHE: + # 保守cache + cache_config = CacheConfig( + method="dit_block_cache", + blocks_count=60, + steps_count=self.num_inference_steps, + step_start=10, + step_interval=3, + step_end=35, + block_start=10, + block_end=50 + ) + self.pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None + self.pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None + self.logger.info("启用缓存配置") + + def _load_model(self): + """加载模型""" + self.logger.info(f"从 {self.path} 加载模型...") + + # 设置设备 + if self.device == "npu": + torch.npu.set_device(self.device_id) + + # 加载 transformer + transformer = QwenImageTransformer2DModel.from_pretrained( + os.path.join(self.path, 'transformer'), + torch_dtype=self.torch_dtype, + device_map=None, # 禁用自动设备映射 + low_cpu_mem_usage=True # 启用CPU低内存模式 + ) + + # 量化配置 + if self.quant_desc_path: + from mindiesd import quantize + self.logger.info("Quantizing Transformer (单独量化核心组件)...") + quantize( + model=transformer, + quant_des_path=self.quant_desc_path, + use_nz=True, + ) + if self.device == "npu": + torch.npu.empty_cache() # 清理NPU显存缓存 + + # 加载 pipeline + self.pipeline = QwenImageEditPlusPipeline.from_pretrained( + self.path, + transformer=transformer, + torch_dtype=self.torch_dtype, + device_map=None, + low_cpu_mem_usage=True + ) + + # VAE优化配置(避免显存溢出) + self.pipeline.vae.use_slicing = True + self.pipeline.vae.use_tiling = True + + # 移动模型到目标设备 + self.pipeline.to(self.device_str) + self.pipeline.set_progress_bar_config(disable=None) # 显示进度条 + + def _get_meta_template(self, meta_template): + """获取元模板""" + class DummyTemplateParser: + def parse_template(self, prompt_template, mode): + return prompt_template + return DummyTemplateParser() + + def _generate(self, input) -> List[Image]: + """Generate result given a input. + + Args: + input (PromptType): A string or PromptDict. + The PromptDict should be organized in AISBench' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + # 处理输入格式 + images = [] + prompts = [] + neg_prompts = [] + print(f"in _generate") + #self.logger.info(f"输入: {input}") + if isinstance(input, str): + prompts.append(input) + neg_prompts.append("") + elif isinstance(input, list): + # 处理包含图像的输入 + for item in input[0]["prompt"]: + if item["type"] == "image_url": + base64_url = item["image_url"]["url"].split(",")[1] + img = Image.open(io.BytesIO(base64.b64decode(base64_url))).convert("RGB") + images.append(img) + elif item["type"] == "text": + prompts.append(item["text"]) + neg_prompts.append("") + else: + prompts.append("") + neg_prompts.append("") + + # 如果没有图像输入,使用默认图像 + if not images: + raise AISBenchRuntimeError(MODEL_CODES.UNKNOWN_ERROR, "QwenImageEditModel requires image input, but can't get image info from input.") + + # 执行推理 + results = [] + for prompt, neg_prompt in zip(prompts, neg_prompts): + # 准备输入参数 + print("in _generate loop") + inputs = { + "image": images, + "prompt": prompt, + "negative_prompt": neg_prompt, + "generator": torch.Generator(device=self.device_str).manual_seed(self.seed), + "true_cfg_scale": self.true_cfg_scale, + "guidance_scale": self.guidance_scale, + "num_inference_steps": self.num_inference_steps, + "num_images_per_prompt": self.num_images_per_prompt, + } + + # 执行推理并计时 + if self.device == "npu": + torch.npu.synchronize() # 昇腾设备同步 + start_time = time.time() + + with torch.inference_mode(): + output = self.pipeline(**inputs) + + if self.device == "npu": + torch.npu.synchronize() + end_time = time.time() + infer_time = end_time - start_time + self.logger.info(f"Current image finish generated, cost: {infer_time:.2f} second.") + + return output + + def encode(self, prompt: str) -> torch.Tensor: + """Encode prompt to tokens. Not necessary for most cases. + + Args: + prompt (str): Input string. + + Returns: + torch.Tensor: Encoded tokens. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `encode` method.') + + def decode(self, tokens: torch.Tensor) -> str: + """Decode tokens to text. Not necessary for most cases. + + Args: + tokens (torch.Tensor): Input tokens. + + Returns: + str: Decoded text. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `decode` method.') + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + # 对于图像编辑模型,token长度计算可能不同,这里返回一个默认值 + return len(prompt.split()) + + def generate(self, inputs, outputs, **kwargs): + """Generate completion from inputs. + + Args: + inputs: Inputs for generation. + max_out_lens: Maximum output lengths. + **kwargs: Additional keyword arguments. + + Returns: + List[str]: Generated completions. + """ + #self.logger.info(f"model {inputs=}") + if not isinstance(inputs, list): + inputs = [inputs] + + for i, input in enumerate(inputs): + result = self._generate(input) + # result is QwenImagePipelineOutput with 'images' attribute + if hasattr(result, 'images') and result.images: + outputs[i].success = True + outputs[i].content = result.images # 将图像列表赋值给 content + else: + outputs[i].success = False + outputs[i].content = [""] diff --git a/ais_bench/benchmark/models/output.py b/ais_bench/benchmark/models/output.py index f0676ae3..aaf7379b 100644 --- a/ais_bench/benchmark/models/output.py +++ b/ais_bench/benchmark/models/output.py @@ -1,5 +1,9 @@ +import os import time from abc import abstractmethod +from typing import Union + +from PIL import Image import numpy as np @@ -174,4 +178,62 @@ def update_extra_details_data_from_text_response(self, text_response: dict) -> N for item in text_response.get("choices", []): message = item.get("message", {}) self.extra_details_data["message"] = message - return # only one message is allowed \ No newline at end of file + return # only one message is allowed + + +LLM_META_DATA_TYPE = Union[Image.Image, str] + + +class LMMOutput(Output): + def __init__(self, perf_mode: bool = False) -> None: + super().__init__(perf_mode) + self.content: list[LLM_META_DATA_TYPE] = [""] + self.HANDLER_MAP = { + Image.Image: self._handle_image, + str: self._handle_text, + } + + def get_prediction(self, save_dir: str) -> dict: + """Get the final prediction by combining content and reasoning. + + Returns: + dict: Combined prediction content + """ + output = [] + for i, item in enumerate(self.content): + output.append(self.HANDLER_MAP[type(item)](save_dir, i)) + if len(output) == 1: + return output[0] + else: + return output + + def _handle_image(self, save_dir: str, index: int) -> str: + """Handle image content. + + Args: + save_dir: Directory to save image + index: Index of image in content list + + Returns: + str: Last two levels of image path + """ + image = self.content[index] + image_path = os.path.join(save_dir, f"image_{self.uuid}_{index}.png") + if os.path.exists(image_path): + os.remove(image_path) + image.save(image_path) + return os.path.join(*image_path.split(os.sep)[-2:]) + + def _handle_text(self, save_dir: str, index: int) -> str: + """Handle text content. + + Args: + save_dir: Directory to save text + index: Index of text in content list + + Returns: + str: Text content + """ + return self.content[index] + + diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..2da7ddc8 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py @@ -0,0 +1,67 @@ +import uuid +from typing import List, Optional + +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.registry import ICL_INFERENCERS +from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever +from ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import LMMGenInferencerOutputHandler + + +@ICL_INFERENCERS.register_module() +class LMMGenInferencer(GenInferencer): + def __init__( + self, + model_cfg, + stopping_criteria: List[str] = [], + batch_size: Optional[int] = 1, + mode: Optional[str] = "infer", + gen_field_replace_token: Optional[str] = "", + output_json_filepath: Optional[str] = "./icl_inference_output", + save_every: Optional[int] = 1, + **kwargs, + ) -> None: + super().__init__( + model_cfg=model_cfg, + stopping_criteria=stopping_criteria, + batch_size=batch_size, + mode=mode, + gen_field_replace_token=gen_field_replace_token, + output_json_filepath=output_json_filepath, + save_every=save_every, + **kwargs, + ) + + self.output_handler = LMMGenInferencerOutputHandler(perf_mode=self.perf_mode, + save_every=self.save_every) + def inference(self, retriever: BaseRetriever, output_json_filepath: Optional[str] = None) -> List: + self.output_handler.set_output_path(output_json_filepath) + return super().inference(retriever, output_json_filepath) + + def batch_inference( + self, + datum, + ) -> None: + """Perform batch inference on the given dataloader. + + Args: + dataloader: DataLoader containing the inference data + + Returns: + List of inference results + """ + indexs = datum.pop("index") + inputs = datum.pop("prompt") + data_abbrs = datum.pop("data_abbr") + outputs = [LMMOutput(self.perf_mode) for _ in range(len(indexs))] + for output in outputs: + output.uuid = str(uuid.uuid4()).replace("-", "") + golds = datum.pop("gold", [None] * len(inputs)) + self.model.generate(inputs, outputs, **datum) + + for index, input, output, data_abbr, gold in zip( + indexs, inputs, outputs, data_abbrs, golds + ): + self.output_handler.report_cache_info_sync( + index, input, output, data_abbr, gold + ) \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py index 42cef866..3f305541 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py @@ -22,6 +22,7 @@ DB_REF_KEY = "__db_ref__" DB_DATA_DIR = "db_data" +BASE64_MAX_DISPLAY_LEN = 256 class BaseInferencerOutputHandler: @@ -56,7 +57,14 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: self.save_every = save_every @abstractmethod - def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, Output], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + + ) -> dict: """ Get the prediction result. @@ -64,7 +72,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference - + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result """ @@ -74,6 +82,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] def get_result( self, conn: sqlite3.Connection, + data_abbr: str, input: Union[str, List[str]], output: Union[str, Output], gold: Optional[str] = None, @@ -113,7 +122,7 @@ def get_result( if gold: result_data["gold"] = gold else: - result_data = self.get_prediction_result(output, gold=gold, input=input) + result_data = self.get_prediction_result(output, gold=gold, input=input, data_abbr=data_abbr) if not result_data.get("success", True): self.all_success = False if isinstance(output, Output) and hasattr(output, "error_info"): @@ -365,7 +374,7 @@ def run_cache_consumer( try: uid = str(uuid.uuid4())[:8] - result_data = self.get_result(conn, *item[2:]) + result_data = self.get_result(conn, *item[1:]) id, data_abbr = item[0], item[1] json_data = { "data_abbr": data_abbr, diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py index 0b47eec6..8bf8fbb7 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py @@ -11,7 +11,13 @@ class BFCLV3OutputHandler(BaseInferencerOutputHandler): """ Output handler for BFCLV3 inference tasks. """ - def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: FunctionCallOutput, + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: """ Get the prediction result for BFCLV3 inference tasks. @@ -19,6 +25,8 @@ def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] output (FunctionCallOutput): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference (not used in this implementation) + data_abbr (Optional[str]): Abbreviation of the dataset (not used in this implementation) + Returns: dict: Prediction result containing success, uuid, prediction (tool_calls), and inference_log Raises: diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 111799d2..5beeffcb 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -3,7 +3,7 @@ import sqlite3 import uuid -from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler, BASE64_MAX_DISPLAY_LEN from ais_bench.benchmark.models.output import Output from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError @@ -27,6 +27,7 @@ def get_prediction_result( output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", ) -> dict: """ Get the prediction result for accuracy mode. @@ -35,10 +36,26 @@ def get_prediction_result( output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result """ + if ( + isinstance(input, list) + and len(input) > 0 + and isinstance(input[0], dict) + and isinstance(input[0].get("prompt"), list) + ): + for item in input[0]["prompt"]: + if not isinstance(item, dict): + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): + continue + url = image_url.get("url") + if isinstance(url, str) and len(url) > BASE64_MAX_DISPLAY_LEN: + image_url["url"] = url[:BASE64_MAX_DISPLAY_LEN] + " ..." result_data = { "success": ( output.success if isinstance(output, Output) else True diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..db31f349 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -0,0 +1,84 @@ +from typing import List, Optional, Union +import sqlite3 +import uuid +from pathlib import Path + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler, BASE64_MAX_DISPLAY_LEN +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES +from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError + +class LMMGenInferencerOutputHandler(BaseInferencerOutputHandler): + """ + Output handler for generation-based inference tasks. + + This handler specializes in processing generation model outputs, + supporting both performance measurement and accuracy evaluation modes. + It handles different data formats and provides appropriate result storage. + + Attributes: + all_success (bool): Flag indicating if all operations were successful + perf_mode (bool): Whether in performance measurement mode + cache_queue (queue.Queue): Queue for caching results before writing + """ + def set_output_path(self, output_path: str) -> None: + self.output_path = output_path + + def get_prediction_result( + self, + output: Union[str, LMMOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", + ) -> dict: + """ + Get the prediction result for accuracy mode. + + Args: + output (Union[str, LMMOutput]): Output result from inference + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ + try: + save_dir = Path(self.output_path) / f"{data_abbr}_out_file" + if not save_dir.exists(): + save_dir.mkdir(parents=True, exist_ok=True) + if ( + isinstance(input, list) + and len(input) > 0 + and isinstance(input[0], dict) + and isinstance(input[0].get("prompt"), list) + ): + for item in input[0]["prompt"]: + if not isinstance(item, dict): + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): + continue + url = image_url.get("url") + if isinstance(url, str) and len(url) > BASE64_MAX_DISPLAY_LEN: + image_url["url"] = url[:BASE64_MAX_DISPLAY_LEN] + " ..." + result_data = { + "success": ( + output.success if isinstance(output, LMMOutput) else True + ), + "uuid": output.uuid if isinstance(output, LMMOutput) else str(uuid.uuid4()).replace("-", ""), + "origin_prompt": input if input is not None else "", + "prediction": ( + output.get_prediction(save_dir) + if isinstance(output, LMMOutput) + else output + ), + } + if gold: + result_data["gold"] = gold + except Exception as e: + import traceback + print(f"[ERROR] LMMGenInferencerOutputHandler.get_prediction_result failed: {type(e).__name__}: {e}") + print(f"[ERROR] Traceback: {traceback.format_exc()}") + raise RuntimeError(f"Failed to get prediction result: {e}") + return result_data \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py index bf5ac30e..44da5be5 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py @@ -46,7 +46,25 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: super().__init__(save_every) self.perf_mode = perf_mode - def get_prediction_result(self, output: Union[str, PPLResponseOutput], gold: Optional[str] = None, input: Union[str, List[str]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, PPLResponseOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: + """ + Get the prediction result for performance mode. + + Args: + output (Union[str, PPLResponseOutput]): Model output + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ if not isinstance(output, PPLResponseOutput): raise AISBenchImplementationError(ICLI_CODES.UNKNOWN_ERROR, f"Output is not a PPLResponseOutput") result_data = { diff --git a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py index 83d42a4e..34562b5f 100644 --- a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py +++ b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py @@ -39,7 +39,7 @@ def check_mm_template(self): if "prompt_mm" not in data.keys(): return False return True - + def format_mm_url(self, template, entry): """ for mm_custom dataset @@ -103,6 +103,7 @@ def generate_item( template = self.format_mm_url(self.template, entry) template = self._encode_template(template, ice=False) template = template.format_mm(**entry) + for i, item in enumerate(template): if "prompt_mm" in item: template[i]["prompt_mm"] = self.get_mm_template(item) diff --git a/ais_bench/benchmark/tasks/openicl_api_infer.py b/ais_bench/benchmark/tasks/openicl_api_infer.py index 442f0632..6092aa0e 100644 --- a/ais_bench/benchmark/tasks/openicl_api_infer.py +++ b/ais_bench/benchmark/tasks/openicl_api_infer.py @@ -164,7 +164,7 @@ def _get_data_list(self) -> tuple[List, List]: data_abbr = dataset_cfg["abbr"] cur_data_cache = finish_cache_data.get(data_abbr, {}) infer_cfg = dataset_cfg["infer_cfg"] - dataset = build_dataset_from_cfg(dataset_cfg) + dataset = build_dataset_from_cfg(dataset_cfg, task_state_manager=self.task_state_manager) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset retriever_cfg["prompt_template"] = infer_cfg.get("prompt_template", None) @@ -473,6 +473,7 @@ def warm_up(self, data_list: List, task_state_manager: TaskStateManager): def run(self, task_state_manager: TaskStateManager): self.logger.info(f"Task [{task_abbr_from_cfg(self.cfg)}]") + self.task_state_manager = task_state_manager self.inferencer:BaseApiInferencer = ICL_INFERENCERS.build(self.inferencer_cfg) self.clean_failed_results() diff --git a/ais_bench/benchmark/tasks/openicl_eval.py b/ais_bench/benchmark/tasks/openicl_eval.py index 6cbb3a3c..d0f2ff60 100644 --- a/ais_bench/benchmark/tasks/openicl_eval.py +++ b/ais_bench/benchmark/tasks/openicl_eval.py @@ -71,7 +71,8 @@ def get_command(self, cfg_path, template): command = f'{python} {script_path} {cfg_path}' return template.format(task_cmd=command) - def run(self): + def run(self, task_state_manager: TaskStateManager): + self.task_state_manager = task_state_manager for dataset_cfg in self.dataset_cfgs: self.dataset_cfg = dataset_cfg # Load Dataset @@ -111,8 +112,7 @@ def _score(self): "k":k, "n":n }) - - test_set = build_dataset_from_cfg(self.dataset_cfg).test + test_set = build_dataset_from_cfg(self.dataset_cfg, task_state_manager=self.task_state_manager).test # Postprocess dataset if necessary if 'dataset_postprocessor' in self.eval_cfg: self.logger.debug(f"Dataset postprocessor: {self.eval_cfg['dataset_postprocessor']}") @@ -515,7 +515,7 @@ def parse_args(): start_time = time.perf_counter() try: evaluator: OpenICLEvalTask = OpenICLEvalTask(cfg) - evaluator.run() + evaluator.run(task_state_manager) except Exception as e: task_state_manager.update_task_state({"status": "error"}) raise e diff --git a/ais_bench/benchmark/tasks/openicl_infer.py b/ais_bench/benchmark/tasks/openicl_infer.py index 5d20d67e..8937ca5d 100644 --- a/ais_bench/benchmark/tasks/openicl_infer.py +++ b/ais_bench/benchmark/tasks/openicl_infer.py @@ -123,7 +123,7 @@ def _inference(self): retrievers = [] for dataset_cfg in self.dataset_cfgs: infer_cfg = dataset_cfg["infer_cfg"] - dataset = build_dataset_from_cfg(dataset_cfg) + dataset = build_dataset_from_cfg(dataset_cfg, task_state_manager=self.task_state_manager) retriever_cfg = infer_cfg["retriever"].copy() retriever_cfg["dataset"] = dataset retriever_cfg["prompt_template"] = infer_cfg.get("prompt_template", None) From b12baa0e839c8598a70efcf99aad5183e834c3f9 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Wed, 4 Mar 2026 17:07:37 +0800 Subject: [PATCH 2/2] review fix --- .../local_models/qwen_image_edit_mindie_sd.py | 85 +++++++++---------- .../icl_inferencer/icl_lmm_gen_inferencer.py | 11 ++- .../lmm_gen_inferencer_output_handler.py | 6 +- 3 files changed, 48 insertions(+), 54 deletions(-) diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py index 60eeb2c3..87673ca1 100644 --- a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -20,7 +20,7 @@ _get_meta_template, ) -# 解决 diffuser 0.35.1 torch2.1 报错 +# Fix diffuser 0.35.1 torch2.1 error def custom_op( name, fn=None, @@ -60,29 +60,28 @@ def decorator(func): torch.library.custom_op = custom_op torch.library.register_fake = register_fake -# 导入 qwen_image_edit 相关模块 +# Import qwen_image_edit related modules try: from ais_bench.third_party.mindie_sd.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel from ais_bench.third_party.mindie_sd.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline from mindiesd import CacheConfig, CacheAgent except ImportError as e: - raise ImportError(f"请确保 qwenimage_edit 模块在 Python 路径中: {e}") + raise ImportError(f"Please ensure qwenimage_edit module is in Python path: {e}") PromptType = Union[PromptList, str] -# 模型推理相关配置常量 -DEFAULT_MODEL_PATH = "/home/yanhe/models/Qwen-Image-Edit-2509/" +# Model inference related config constants DEFAULT_TORCH_DTYPE = "bfloat16" DEFAULT_DEVICE = "npu" DEFAULT_DEVICE_ID = 0 -DEFAULT_NUM_INFERENCE_STEPS = 1 # 40 +DEFAULT_NUM_INFERENCE_STEPS = 40 # 40 DEFAULT_TRUE_CFG_SCALE = 4.0 DEFAULT_GUIDANCE_SCALE = 1.0 DEFAULT_SEED = 0 DEFAULT_NUM_IMAGES_PER_PROMPT = 1 DEFAULT_QUANT_DESC_PATH = None -# 缓存配置开关 +# Cache config switches COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) @@ -100,7 +99,7 @@ class QwenImageEditModel(BaseLMModel): """ def __init__(self, - path: str = DEFAULT_MODEL_PATH, + path: str, device_kwargs: dict = dict(), infer_kwargs: dict = dict(), meta_template: Optional[Dict] = None, @@ -110,10 +109,9 @@ def __init__(self, self.max_out_len = other_kwargs.get('max_out_len', None) self.template_parser = _get_meta_template(meta_template) - # 设备配置 + # Device config self.device = device_kwargs.get('device', DEFAULT_DEVICE) - #self.device_id = device_kwargs.get('device_id', DEFAULT_DEVICE_ID) - # 在这里声明环境变量 + # Declare environment variable here self.logger.debug(f"device id from kwargs: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}") os.environ["ASCEND_RT_VISIBLE_DEVICES"] = f"{device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}" self.device_id = DEFAULT_DEVICE_ID @@ -121,11 +119,11 @@ def __init__(self, self.logger.debug(f"device_str: {self.device_str}; device_id: {self.device_id}") self.logger.debug(f"ASCEND_RT_VISIBLE_DEVICES: {os.getenv('ASCEND_RT_VISIBLE_DEVICES')}") - # 模型配置 + # Model config self.torch_dtype = other_kwargs.get('torch_dtype', DEFAULT_TORCH_DTYPE) self.torch_dtype = torch.bfloat16 if self.torch_dtype == "bfloat16" else torch.float32 - # 推理配置 + # Inference config self.num_inference_steps = infer_kwargs.get('num_inference_steps', DEFAULT_NUM_INFERENCE_STEPS) self.true_cfg_scale = infer_kwargs.get('true_cfg_scale', DEFAULT_TRUE_CFG_SCALE) self.guidance_scale = infer_kwargs.get('guidance_scale', DEFAULT_GUIDANCE_SCALE) @@ -140,12 +138,12 @@ def __init__(self, f"quant_desc_path: {self.quant_desc_path}" ) - # 加载模型 + # Load model self._load_model() - # 缓存配置 + # Cache config if COND_CACHE or UNCOND_CACHE: - # 保守cache + # Conservative cache cache_config = CacheConfig( method="dit_block_cache", blocks_count=60, @@ -158,37 +156,37 @@ def __init__(self, ) self.pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None self.pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None - self.logger.info("启用缓存配置") + self.logger.info("Cache configuration enabled") def _load_model(self): - """加载模型""" - self.logger.info(f"从 {self.path} 加载模型...") + """Load model""" + self.logger.info(f"Loading model from {self.path}...") - # 设置设备 + # Set device if self.device == "npu": torch.npu.set_device(self.device_id) - # 加载 transformer + # Load transformer transformer = QwenImageTransformer2DModel.from_pretrained( os.path.join(self.path, 'transformer'), torch_dtype=self.torch_dtype, - device_map=None, # 禁用自动设备映射 - low_cpu_mem_usage=True # 启用CPU低内存模式 + device_map=None, # Disable auto device mapping + low_cpu_mem_usage=True # Enable CPU low memory mode ) - # 量化配置 + # Quantization config if self.quant_desc_path: from mindiesd import quantize - self.logger.info("Quantizing Transformer (单独量化核心组件)...") + self.logger.info("Quantizing Transformer (quantizing core component separately)...") quantize( model=transformer, quant_des_path=self.quant_desc_path, use_nz=True, ) if self.device == "npu": - torch.npu.empty_cache() # 清理NPU显存缓存 + torch.npu.empty_cache() # Clear NPU memory cache - # 加载 pipeline + # Load pipeline self.pipeline = QwenImageEditPlusPipeline.from_pretrained( self.path, transformer=transformer, @@ -197,16 +195,16 @@ def _load_model(self): low_cpu_mem_usage=True ) - # VAE优化配置(避免显存溢出) + # VAE optimization config (avoid memory overflow) self.pipeline.vae.use_slicing = True self.pipeline.vae.use_tiling = True - # 移动模型到目标设备 + # Move model to target device self.pipeline.to(self.device_str) - self.pipeline.set_progress_bar_config(disable=None) # 显示进度条 + self.pipeline.set_progress_bar_config(disable=None) # Show progress bar def _get_meta_template(self, meta_template): - """获取元模板""" + """Get meta template""" class DummyTemplateParser: def parse_template(self, prompt_template, mode): return prompt_template @@ -224,17 +222,15 @@ def _generate(self, input) -> List[Image]: Returns: str: The generated string. """ - # 处理输入格式 + # Process input format images = [] prompts = [] neg_prompts = [] - print(f"in _generate") - #self.logger.info(f"输入: {input}") if isinstance(input, str): prompts.append(input) neg_prompts.append("") elif isinstance(input, list): - # 处理包含图像的输入 + # Process input containing images for item in input[0]["prompt"]: if item["type"] == "image_url": base64_url = item["image_url"]["url"].split(",")[1] @@ -247,15 +243,14 @@ def _generate(self, input) -> List[Image]: prompts.append("") neg_prompts.append("") - # 如果没有图像输入,使用默认图像 + # If no image input, use default image if not images: raise AISBenchRuntimeError(MODEL_CODES.UNKNOWN_ERROR, "QwenImageEditModel requires image input, but can't get image info from input.") - # 执行推理 + # Execute inference results = [] for prompt, neg_prompt in zip(prompts, neg_prompts): - # 准备输入参数 - print("in _generate loop") + # Prepare input parameters inputs = { "image": images, "prompt": prompt, @@ -267,19 +262,15 @@ def _generate(self, input) -> List[Image]: "num_images_per_prompt": self.num_images_per_prompt, } - # 执行推理并计时 + # Execute inference and time it if self.device == "npu": - torch.npu.synchronize() # 昇腾设备同步 - start_time = time.time() + torch.npu.synchronize() # Ascend device sync with torch.inference_mode(): output = self.pipeline(**inputs) if self.device == "npu": torch.npu.synchronize() - end_time = time.time() - infer_time = end_time - start_time - self.logger.info(f"Current image finish generated, cost: {infer_time:.2f} second.") return output @@ -314,7 +305,7 @@ def get_token_len(self, prompt: str) -> int: Returns: int: Length of the input tokens """ - # 对于图像编辑模型,token长度计算可能不同,这里返回一个默认值 + # For image editing models, token length calculation may differ, return a default value here return len(prompt.split()) def generate(self, inputs, outputs, **kwargs): @@ -337,7 +328,7 @@ def generate(self, inputs, outputs, **kwargs): # result is QwenImagePipelineOutput with 'images' attribute if hasattr(result, 'images') and result.images: outputs[i].success = True - outputs[i].content = result.images # 将图像列表赋值给 content + outputs[i].content = result.images # Assign image list to content else: outputs[i].success = False outputs[i].content = [""] diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py index 2da7ddc8..35a1c227 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py @@ -32,9 +32,14 @@ def __init__( **kwargs, ) - self.output_handler = LMMGenInferencerOutputHandler(perf_mode=self.perf_mode, - save_every=self.save_every) - def inference(self, retriever: BaseRetriever, output_json_filepath: Optional[str] = None) -> List: + self.output_handler = LMMGenInferencerOutputHandler(perf_mode=self.perf_mode, save_every=self.save_every) + + def inference( + self, + retriever: BaseRetriever, + output_json_filepath: + Optional[str] = None + ) -> List: self.output_handler.set_output_path(output_json_filepath) return super().inference(retriever, output_json_filepath) diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py index db31f349..8a6a15b5 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -6,7 +6,7 @@ from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler, BASE64_MAX_DISPLAY_LEN from ais_bench.benchmark.models.output import LMMOutput from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES -from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError +from ais_bench.benchmark.utils.logging.exceptions import AISBenchRuntimeError class LMMGenInferencerOutputHandler(BaseInferencerOutputHandler): """ @@ -78,7 +78,5 @@ def get_prediction_result( result_data["gold"] = gold except Exception as e: import traceback - print(f"[ERROR] LMMGenInferencerOutputHandler.get_prediction_result failed: {type(e).__name__}: {e}") - print(f"[ERROR] Traceback: {traceback.format_exc()}") - raise RuntimeError(f"Failed to get prediction result: {e}") + raise AISBenchRuntimeError(ICLI_CODES.UNKNOWN_ERROR, f"Failed to get prediction result: {e} \n Traceback: {traceback.format_exc()}") return result_data \ No newline at end of file