Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion ais_bench/benchmark/cli/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,22 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict):


class JudgeInfer(BaseWorker):
def __init__(self, args) -> None:
super().__init__(args)
self.judge_model_type = None

def update_cfg(self, cfg: ConfigDict) -> None:
for dataset_cfg in cfg["datasets"]:
judge_infer_cfg = dataset_cfg.get("judge_infer_cfg")
if judge_infer_cfg:
self.judge_model_type = judge_infer_cfg["judge_model"]["attr"]
Comment on lines +122 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loop iterates through cfg["datasets"] and updates self.judge_model_type if judge_infer_cfg is found. If multiple datasets have judge_infer_cfg with potentially different judge_model.attr values, self.judge_model_type will be set to the value from the last dataset in the list that contains judge_infer_cfg. This self.judge_model_type is then used globally for get_task_type(). If the system expects all judge_infer_cfgs to have a consistent judge_model.attr or if different judge_model.attr values across datasets should lead to different behaviors, this current implementation might lead to unexpected results. Consider adding a check to ensure consistency across all judge_infer_cfgs if this is a requirement, or explicitly document that only the last encountered judge_model.attr is used.


if self.judge_model_type is None:
logger.debug("Skip Judge Infer")
return cfg

def get_task_type() -> str:
if cfg["datasets"][0]["judge_infer_cfg"]["judge_model"]["attr"] == "service":
if self.judge_model_type == "service":
return get_config_type(OpenICLApiInferTask)
else:
return get_config_type(OpenICLInferTask)
Expand All @@ -141,6 +154,10 @@ def get_task_type() -> str:
return cfg

def do_work(self, cfg: ConfigDict):
if self.judge_model_type is None:
logger.debug("Skip Judge Infer")
return

partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner)
logger.info("Starting inference tasks...")
self._cfg_pre_process(cfg)
Expand Down
Loading