diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 4c28f174..ce1dd8bb 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -114,9 +114,22 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): class JudgeInfer(BaseWorker): + def __init__(self, args) -> None: + super().__init__(args) + self.judge_model_type = None + def update_cfg(self, cfg: ConfigDict) -> None: + for dataset_cfg in cfg["datasets"]: + judge_infer_cfg = dataset_cfg.get("judge_infer_cfg") + if judge_infer_cfg: + self.judge_model_type = judge_infer_cfg["judge_model"]["attr"] + + if self.judge_model_type is None: + logger.debug("Skip Judge Infer") + return cfg + def get_task_type() -> str: - if cfg["datasets"][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service": + if self.judge_model_type == "service": return get_config_type(OpenICLApiInferTask) else: return get_config_type(OpenICLInferTask) @@ -141,6 +154,10 @@ def get_task_type() -> str: return cfg def do_work(self, cfg: ConfigDict): + if self.judge_model_type is None: + logger.debug("Skip Judge Infer") + return + partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner) logger.info("Starting inference tasks...") self._cfg_pre_process(cfg)