68 changes: 68 additions & 0 deletions .gitignore
@@ -0,0 +1,68 @@
# --- Python ---
__pycache__/
*.py[cod]
*$py.class
*.so
.Python

# --- Virtual environments ---
.venv/
venv/
ENV/
env/

# --- Packaging / build ---
build/
dist/
*.egg-info/
.eggs/
pip-wheel-metadata/

# --- Test / coverage ---
.pytest_cache/
.coverage
coverage.xml
.tox/
.nox/
htmlcov/

# --- Jupyter ---
.ipynb_checkpoints/

# --- Logs ---
*.log

# --- IDEs / editors ---
.vscode/
.idea/

# --- OS files ---
.DS_Store
Thumbs.db

# --- uv / caches ---
.uv/

# NOTE: uv stores caches under ~/.cache/uv by default (outside the repo),
# but if you configure a project-local cache, ignore it here.
.cache/

# --- ML artifacts / large outputs (customize as needed) ---
**/checkpoints/
**/runs/
**/wandb/
**/mlruns/
**/tensorboard/
**/*.ckpt
**/*.pt
**/*.pth
**/*.bin
**/*.safetensors
**/*.onnx

# --- Local data / outputs (keep raw datasets under version control only if intended) ---
data/**/output/
data/**/outputs/
data/**/predictions/


14 changes: 14 additions & 0 deletions README.md
@@ -27,6 +27,20 @@ conda create -n CRAG python=3.11
pip install -r requirements.txt
```

### macOS note
`requirements.txt` contains several Linux/CUDA-only packages (e.g. `flash-attn`, `deepspeed`, `nvidia-*`).
On macOS, install the provided CPU-friendly set instead:
```
pip install -r requirements-macos.txt
```
Additionally, on Apple Silicon you can use PyTorch's `mps` backend (Metal) for evaluator training and inference:
```
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
```
This enables GPU acceleration on macOS when `mps` is available.
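As an optional sanity check (not part of the original setup steps), you can confirm that the installed PyTorch build exposes a usable MPS backend:
```
python -c "import torch; print(torch.backends.mps.is_available())"
```
If this prints `False`, the code will simply run on the CPU.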

Note: `scripts/CRAG_Inference.py` can use vLLM when it is installed, but on macOS it is not relied on by default: vLLM support on macOS is still experimental and may require building from source (Xcode toolchain). If `vllm` cannot be imported (or if you pass `--generator_backend transformers`), the script automatically falls back to a pure-Transformers generator.
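For reference, a hypothetical macOS invocation might look like the sketch below; the angle-bracket values are placeholders, and any remaining arguments are the script's existing ones, passed exactly as on Linux:
```
python scripts/CRAG_Inference.py \
  --generator_backend transformers \
  --device mps \
  --generator_path <path-to-generator-model> \
  --evaluator_path <path-to-evaluator-model> \
  --task <task-name> \
  --output_file <output-file>
```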

## Download
- Download the **eval_data** created by [Self-RAG (Asai et al., 2023)](https://github.com/AkariAsai/self-rag) on PopQA, PubQA, Bio and Arc_challenge with retrieved results
- Download the **LLaMA-2** fine-tuned by [Self-RAG (Asai et al., 2023)](https://huggingface.co/selfrag/selfrag_llama2_7b).
123 changes: 123 additions & 0 deletions requirements-macos.txt
@@ -0,0 +1,123 @@
absl-py==2.0.0
accelerate==0.22.0
aiohttp==3.9.1
aioprometheus==23.3.0
aiosignal==1.3.1
anyio==3.7.1
attrs==23.1.0
blis==0.7.11
cachetools==5.3.2
catalogue==2.0.10
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
cloudpathlib==0.16.0
colorama==0.4.6
confection==0.1.4
cymem==2.0.8
dataclasses==0.6
datasets==2.15.0
dill==0.3.7
distro==1.8.0
einops==0.7.0
evaluate==0.4.1
fastapi==0.105.0
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.10.0
google-auth==2.25.2
google-auth-oauthlib==1.2.0
grpcio==1.60.0
h11==0.14.0
hjson==3.1.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.25.2
huggingface-hub==0.19.4
idna==3.6
Jinja2==3.1.2
joblib==1.3.2
jsonlines==4.0.0
jsonschema==4.20.0
jsonschema-specifications==2023.11.2
langcodes==3.3.0
lxml==4.9.3
Markdown==3.5.1
MarkupSafe==2.1.3
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.15
murmurhash==1.0.10
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
numpy==1.26.2
oauthlib==3.2.2
openai==1.4.0
orjson==3.9.10
packaging==23.2
pandas==2.1.4
peft==0.7.1
portalocker==2.8.2
preshed==3.0.9
protobuf==4.23.4
psutil==5.9.6
py-cpuinfo==9.0.0
pyarrow==14.0.1
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pydantic==1.10.13
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3.post1
PyYAML==6.0.1
quantile-python==1.1
ray==2.8.1
referencing==0.32.0
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
responses==0.18.0
rouge-score==0.1.2
rpds-py==0.13.2
rsa==4.9
sacrebleu==2.4.0
safetensors==0.4.1
scikit-learn==1.4.0
scipy==1.12.0
sentencepiece==0.1.99
six==1.16.0
smart-open==6.4.0
sniffio==1.3.0
spacy==3.7.2
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.4.8
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
thinc==8.2.2
threadpoolctl==3.2.0
tiktoken==0.5.2
tokenizers==0.15.0
torch==2.1.2
tqdm==4.66.1
transformers==4.36.1
typer==0.9.0
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
uvicorn==0.24.0.post1
uvloop==0.19.0
wasabi==1.1.2
watchfiles==0.21.0
weasel==0.3.4
websockets==12.0
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.4

94 changes: 87 additions & 7 deletions scripts/CRAG_Inference.py
@@ -15,7 +15,6 @@
from torch.optim import AdamW
from transformers import get_scheduler

from vllm import LLM, SamplingParams
from transformers import T5Tokenizer, T5ForSequenceClassification
from transformers import AutoTokenizer, AutoModelForCausalLM

@@ -195,6 +194,57 @@ def process_flag(scores, n_docs, threshold1, threshold2):
tmp_flag = []
return identification_flag

def _select_generator_backend(requested_backend: str):
"""
Decide which generator backend to use.
- If requested_backend is 'auto', prefer vllm when available, otherwise use transformers.
- If explicitly 'vllm', require vllm to be importable.
- If explicitly 'transformers', never import vllm.
"""
requested_backend = (requested_backend or "auto").lower()
if requested_backend not in {"auto", "vllm", "transformers"}:
raise ValueError(f"Unknown generator backend: {requested_backend}")

if requested_backend == "transformers":
return "transformers", None

try:
from vllm import LLM, SamplingParams # type: ignore
return "vllm", (LLM, SamplingParams)
except Exception as e:
if requested_backend == "vllm":
raise ImportError(
"generator_backend='vllm' requested but vllm is not available. "
"Install vllm or use --generator_backend transformers."
) from e
return "transformers", None

def _infer_generator_device(args_device: str | None):
"""
Choose a reasonable torch device string for generator inference.
vLLM has its own device handling; for Transformers fallback we use torch.
"""
if args_device and args_device != "auto":
return args_device
if torch.cuda.is_available():
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"

def _transformers_generate_one(prompt: str, tokenizer, model, device: str, max_new_tokens: int = 100):
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )
    # Decode only the newly generated tokens: stripping the prompt by string
    # prefix is unreliable because the tokenizer may prepend special tokens
    # (e.g. BOS), so the decoded text does not necessarily start with the prompt.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--generator_path', type=str)
@@ -207,6 +257,13 @@ def main():
parser.add_argument('--task', type=str)
parser.add_argument('--method', type=str, default="default", choices=['rag', 'crag', 'no_retrieval'])
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument(
'--generator_backend',
type=str,
default="auto",
choices=["auto", "vllm", "transformers"],
help="Generator backend. 'auto' prefers vLLM if available, otherwise uses Transformers.",
)
parser.add_argument('--download_dir', type=str, help="specify vllm model download dir",
default=".cache")
parser.add_argument("--ndocs", type=int, default=-1,
@@ -220,8 +277,23 @@ def main():
args = parser.parse_args()
args.lower_threshold = -args.lower_threshold

generator = LLM(model=args.generator_path, dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
backend, vllm_syms = _select_generator_backend(args.generator_backend)
sampling_params = None
generator = None
hf_tokenizer = None
hf_model = None
generator_device = _infer_generator_device(args.device)

if backend == "vllm":
LLM, SamplingParams = vllm_syms
generator = LLM(model=args.generator_path, dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
else:
hf_tokenizer = AutoTokenizer.from_pretrained(args.generator_path)
torch_dtype = torch.float16 if generator_device in {"cuda", "mps"} else torch.float32
hf_model = AutoModelForCausalLM.from_pretrained(args.generator_path, torch_dtype=torch_dtype)
hf_model.to(generator_device)
hf_model.eval()

tokenizer = T5Tokenizer.from_pretrained(args.evaluator_path)
model = T5ForSequenceClassification.from_pretrained(args.evaluator_path, num_labels=1)
@@ -263,14 +335,22 @@ def main():
if args.method != 'no_retrieval':
for i, (q, p) in tqdm(enumerate(zip(queries, paragraphs))):
prompt = format_prompt(i, args.task, q, p, modelname)
pred = generator.generate([prompt], sampling_params)
preds.append(postprocess_answer_option_conditioned(pred[0].outputs[0].text))
if backend == "vllm":
pred = generator.generate([prompt], sampling_params)
text = pred[0].outputs[0].text
else:
text = _transformers_generate_one(prompt, hf_tokenizer, hf_model, generator_device, max_new_tokens=100)
preds.append(postprocess_answer_option_conditioned(text))
else:
for i, q in tqdm(enumerate(queries)):
p = None
prompt = format_prompt(i, args.task, q, p, modelname)
pred = generator.generate([prompt], sampling_params)
preds.append(postprocess_answer_option_conditioned(pred[0].outputs[0].text))
if backend == "vllm":
pred = generator.generate([prompt], sampling_params)
text = pred[0].outputs[0].text
else:
text = _transformers_generate_one(prompt, hf_tokenizer, hf_model, generator_device, max_new_tokens=100)
preds.append(postprocess_answer_option_conditioned(text))

with open(args.output_file, 'w') as f:
f.write('\n'.join(preds))