From f284062ccc26da3189a220f42f0c50fcd376829e Mon Sep 17 00:00:00 2001 From: zhudezhong <1945045703@qq.com> Date: Thu, 15 Jan 2026 21:50:29 +0800 Subject: [PATCH] fix: use MacOS M1 series to train evaluator and inference --- .gitignore | 68 +++++++++++++++++++++ README.md | 14 +++++ requirements-macos.txt | 123 ++++++++++++++++++++++++++++++++++++++ scripts/CRAG_Inference.py | 94 ++++++++++++++++++++++++++--- 4 files changed, 292 insertions(+), 7 deletions(-) create mode 100644 .gitignore create mode 100644 requirements-macos.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a477e1c --- /dev/null +++ b/.gitignore @@ -0,0 +1,68 @@ +# --- Python --- +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# --- Virtual environments --- +.venv/ +venv/ +ENV/ +env/ + +# --- Packaging / build --- +build/ +dist/ +*.egg-info/ +.eggs/ +pip-wheel-metadata/ + +# --- Test / coverage --- +.pytest_cache/ +.coverage +coverage.xml +.tox/ +.nox/ +htmlcov/ + +# --- Jupyter --- +.ipynb_checkpoints/ + +# --- Logs --- +*.log + +# --- IDEs / editors --- +.vscode/ +.idea/ + +# --- OS files --- +.DS_Store +Thumbs.db + +# --- uv / caches --- +.uv/ + +# NOTE: uv stores caches under ~/.cache/uv by default (outside the repo), +# but if you configure a project-local cache, ignore it here. +.cache/ + +# --- ML artifacts / large outputs (customize as needed) --- +**/checkpoints/ +**/runs/ +**/wandb/ +**/mlruns/ +**/tensorboard/ +**/*.ckpt +**/*.pt +**/*.pth +**/*.bin +**/*.safetensors +**/*.onnx + +# --- Local data / outputs (keep raw datasets under version control only if intended) --- +data/**/output/ +data/**/outputs/ +data/**/predictions/ + + diff --git a/README.md b/README.md index ac82034..4b596c3 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,20 @@ conda create -n CRAG python=3.11 pip install -r requirements.txt ``` +### macOS note +`requirements.txt` contains several Linux/CUDA-only packages (e.g. `flash-attn`, `deepspeed`, `nvidia-*`). +On macOS, install the provided CPU-friendly set instead: +``` +pip install -r requirements-macos.txt +``` +Additionally, on Apple Silicon you can use PyTorch's `mps` backend (Metal) for evaluator training and inference: +``` +device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") +``` +This enables GPU acceleration on macOS when `mps` is available. + +Note: `scripts/CRAG_Inference.py` can use vLLM if installed, but on macOS we do not rely on it by default because vLLM support is still experimental and may require building from source (Xcode toolchain). The script will automatically fall back to a pure-Transformers generator if `vllm` is not available (or if you pass `--generator_backend transformers`). + ## Download - Download the **eval_data** created by [Self-RAG (Asai et al., 2023)](https://github.com/AkariAsai/self-rag) on PopQA, PubQA, Bio and Arc_challenge with retrieved results - Download the **LLaMA-2** fine-tuned by [Self-RAG (Asai et al., 2023)](https://huggingface.co/selfrag/selfrag_llama2_7b). diff --git a/requirements-macos.txt b/requirements-macos.txt new file mode 100644 index 0000000..766536f --- /dev/null +++ b/requirements-macos.txt @@ -0,0 +1,123 @@ +absl-py==2.0.0 +accelerate==0.22.0 +aiohttp==3.9.1 +aioprometheus==23.3.0 +aiosignal==1.3.1 +anyio==3.7.1 +attrs==23.1.0 +blis==0.7.11 +cachetools==5.3.2 +catalogue==2.0.10 +certifi==2023.11.17 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpathlib==0.16.0 +colorama==0.4.6 +confection==0.1.4 +cymem==2.0.8 +dataclasses==0.6 +datasets==2.15.0 +dill==0.3.7 +distro==1.8.0 +einops==0.7.0 +evaluate==0.4.1 +fastapi==0.105.0 +filelock==3.13.1 +frozenlist==1.4.1 +fsspec==2023.10.0 +google-auth==2.25.2 +google-auth-oauthlib==1.2.0 +grpcio==1.60.0 +h11==0.14.0 +hjson==3.1.0 +httpcore==1.0.2 +httptools==0.6.1 +httpx==0.25.2 +huggingface-hub==0.19.4 +idna==3.6 +Jinja2==3.1.2 +joblib==1.3.2 +jsonlines==4.0.0 +jsonschema==4.20.0 +jsonschema-specifications==2023.11.2 +langcodes==3.3.0 +lxml==4.9.3 +Markdown==3.5.1 +MarkupSafe==2.1.3 +mpmath==1.3.0 +msgpack==1.0.7 +multidict==6.0.4 +multiprocess==0.70.15 +murmurhash==1.0.10 +networkx==3.2.1 +ninja==1.11.1.1 +nltk==3.8.1 +numpy==1.26.2 +oauthlib==3.2.2 +openai==1.4.0 +orjson==3.9.10 +packaging==23.2 +pandas==2.1.4 +peft==0.7.1 +portalocker==2.8.2 +preshed==3.0.9 +protobuf==4.23.4 +psutil==5.9.6 +py-cpuinfo==9.0.0 +pyarrow==14.0.1 +pyarrow-hotfix==0.6 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pydantic==1.10.13 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +pytz==2023.3.post1 +PyYAML==6.0.1 +quantile-python==1.1 +ray==2.8.1 +referencing==0.32.0 +regex==2023.10.3 +requests==2.31.0 +requests-oauthlib==1.3.1 +responses==0.18.0 +rouge-score==0.1.2 +rpds-py==0.13.2 +rsa==4.9 +sacrebleu==2.4.0 +safetensors==0.4.1 +scikit-learn==1.4.0 +scipy==1.12.0 +sentencepiece==0.1.99 +six==1.16.0 +smart-open==6.4.0 +sniffio==1.3.0 +spacy==3.7.2 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.4.8 +starlette==0.27.0 +sympy==1.12 +tabulate==0.9.0 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +thinc==8.2.2 +threadpoolctl==3.2.0 +tiktoken==0.5.2 +tokenizers==0.15.0 +torch==2.1.2 +tqdm==4.66.1 +transformers==4.36.1 +typer==0.9.0 +typing_extensions==4.9.0 +tzdata==2023.3 +urllib3==2.1.0 +uvicorn==0.24.0.post1 +uvloop==0.19.0 +wasabi==1.1.2 +watchfiles==0.21.0 +weasel==0.3.4 +websockets==12.0 +Werkzeug==3.0.1 +xxhash==3.4.1 +yarl==1.9.4 + diff --git a/scripts/CRAG_Inference.py b/scripts/CRAG_Inference.py index 42ecfda..a70b469 100644 --- a/scripts/CRAG_Inference.py +++ b/scripts/CRAG_Inference.py @@ -15,7 +15,6 @@ from torch.optim import AdamW from transformers import get_scheduler -from vllm import LLM, SamplingParams from transformers import T5Tokenizer, T5ForSequenceClassification from transformers import AutoTokenizer, AutoModelForCausalLM @@ -195,6 +194,57 @@ def process_flag(scores, n_docs, threshold1, threshold2): tmp_flag = [] return identification_flag +def _select_generator_backend(requested_backend: str): + """ + Decide which generator backend to use. + - If requested_backend is 'auto', prefer vllm when available, otherwise use transformers. + - If explicitly 'vllm', require vllm to be importable. + - If explicitly 'transformers', never import vllm. + """ + requested_backend = (requested_backend or "auto").lower() + if requested_backend not in {"auto", "vllm", "transformers"}: + raise ValueError(f"Unknown generator backend: {requested_backend}") + + if requested_backend == "transformers": + return "transformers", None + + try: + from vllm import LLM, SamplingParams # type: ignore + return "vllm", (LLM, SamplingParams) + except Exception as e: + if requested_backend == "vllm": + raise ImportError( + "generator_backend='vllm' requested but vllm is not available. " + "Install vllm or use --generator_backend transformers." + ) from e + return "transformers", None + +def _infer_generator_device(args_device: str | None): + """ + Choose a reasonable torch device string for generator inference. + vLLM has its own device handling; for Transformers fallback we use torch. + """ + if args_device and args_device != "auto": + return args_device + if torch.cuda.is_available(): + return "cuda" + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + return "cpu" + +def _transformers_generate_one(prompt: str, tokenizer, model, device: str, max_new_tokens: int = 100): + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + with torch.no_grad(): + outputs = model.generate( + **inputs, + do_sample=False, + max_new_tokens=max_new_tokens, + ) + decoded = tokenizer.decode(outputs[0], skip_special_tokens=False) + # Remove the prompt prefix if present. + return decoded[len(prompt):] if decoded.startswith(prompt) else decoded + def main(): parser = argparse.ArgumentParser() parser.add_argument('--generator_path', type=str) @@ -207,6 +257,13 @@ def main(): parser.add_argument('--task', type=str) parser.add_argument('--method', type=str, default="default", choices=['rag', 'crag', 'no_retrieval']) parser.add_argument('--device', type=str, default="cuda") + parser.add_argument( + '--generator_backend', + type=str, + default="auto", + choices=["auto", "vllm", "transformers"], + help="Generator backend. 'auto' prefers vLLM if available, otherwise uses Transformers.", + ) parser.add_argument('--download_dir', type=str, help="specify vllm model download dir", default=".cache") parser.add_argument("--ndocs", type=int, default=-1, @@ -220,8 +277,23 @@ def main(): args = parser.parse_args() args.lower_threshold = -args.lower_threshold - generator = LLM(model=args.generator_path, dtype="half") - sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False) + backend, vllm_syms = _select_generator_backend(args.generator_backend) + sampling_params = None + generator = None + hf_tokenizer = None + hf_model = None + generator_device = _infer_generator_device(args.device) + + if backend == "vllm": + LLM, SamplingParams = vllm_syms + generator = LLM(model=args.generator_path, dtype="half") + sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False) + else: + hf_tokenizer = AutoTokenizer.from_pretrained(args.generator_path) + torch_dtype = torch.float16 if generator_device in {"cuda", "mps"} else torch.float32 + hf_model = AutoModelForCausalLM.from_pretrained(args.generator_path, torch_dtype=torch_dtype) + hf_model.to(generator_device) + hf_model.eval() tokenizer = T5Tokenizer.from_pretrained(args.evaluator_path) model = T5ForSequenceClassification.from_pretrained(args.evaluator_path, num_labels=1) @@ -263,14 +335,22 @@ def main(): if args.method != 'no_retrieval': for i, (q, p) in tqdm(enumerate(zip(queries, paragraphs))): prompt = format_prompt(i, args.task, q, p, modelname) - pred = generator.generate([prompt], sampling_params) - preds.append(postprocess_answer_option_conditioned(pred[0].outputs[0].text)) + if backend == "vllm": + pred = generator.generate([prompt], sampling_params) + text = pred[0].outputs[0].text + else: + text = _transformers_generate_one(prompt, hf_tokenizer, hf_model, generator_device, max_new_tokens=100) + preds.append(postprocess_answer_option_conditioned(text)) else: for i, q in tqdm(enumerate(queries)): p = None prompt = format_prompt(i, args.task, q, p, modelname) - pred = generator.generate([prompt], sampling_params) - preds.append(postprocess_answer_option_conditioned(pred[0].outputs[0].text)) + if backend == "vllm": + pred = generator.generate([prompt], sampling_params) + text = pred[0].outputs[0].text + else: + text = _transformers_generate_one(prompt, hf_tokenizer, hf_model, generator_device, max_new_tokens=100) + preds.append(postprocess_answer_option_conditioned(text)) with open(args.output_file, 'w') as f: f.write('\n'.join(preds))