This repository was archived by the owner on Oct 31, 2023. It is now read-only.
Error in Model Loading Code #18
Description
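The `str.replace` call at contriever.py:126 removes every occurrence of "encoder." from each `state_dict` key, not just the leading one. A key that contains "encoder." twice therefore loses both. Because `load_state_dict` is called with `strict=False`, the resulting mismatched keys are skipped silently instead of raising an error.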
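A minimal illustration, assuming a BERT-style checkpoint saved from the in-batch model, where the Hugging Face model is wrapped under an attribute named `encoder`. BERT's own transformer stack is also named `encoder`, so the layer keys start with `encoder.encoder.`. The key names below are illustrative, not copied from an actual checkpoint.

```python
# Illustrative keys, assuming the in-batch wrapper stores the BERT model as `encoder`.
# BERT's own layer stack is also named `encoder`, hence the doubled prefix.
key_layer = "encoder.encoder.layer.0.attention.self.query.weight"
key_emb = "encoder.embeddings.word_embeddings.weight"

# str.replace without a count removes EVERY occurrence of the substring,
# so the layer key loses both prefixes (BERT expects "encoder.layer.0...").
print(key_layer.replace("encoder.", ""))

# The embeddings key contains the prefix only once, so the bug does not affect it.
print(key_emb.replace("encoder.", ""))
```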
Lines 103 to 138 in 39fb220
```python
def load_retriever(model_path, pooling="average", random_init=False):
    # try: check if model exists locally
    path = os.path.join(model_path, "checkpoint.pth")
    if os.path.exists(path):
        pretrained_dict = torch.load(path, map_location="cpu")
        opt = pretrained_dict["opt"]
        if hasattr(opt, "retriever_model_id"):
            retriever_model_id = opt.retriever_model_id
        else:
            # retriever_model_id = "bert-base-uncased"
            retriever_model_id = "bert-base-multilingual-cased"
        tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
        cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
        if "xlm" in retriever_model_id:
            model_class = XLMRetriever
        else:
            model_class = Contriever
        retriever = model_class(cfg)
        pretrained_dict = pretrained_dict["model"]
        if any("encoder_q." in key for key in pretrained_dict.keys()):  # test if model is defined with moco class
            pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
        elif any("encoder." in key for key in pretrained_dict.keys()):  # test if model is defined with inbatch class
            pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
        retriever.load_state_dict(pretrained_dict, strict=False)
    else:
        retriever_model_id = model_path
        if "xlm" in retriever_model_id:
            model_class = XLMRetriever
        else:
            model_class = Contriever
        cfg = utils.load_hf(transformers.AutoConfig, model_path)
        tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
        retriever = utils.load_hf(model_class, model_path)
    return retriever, tokenizer, retriever_model_id
```
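With `strict=False`, nothing in the loader above reports keys that fail to match. In PyTorch, `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, so inspecting that return value would surface the problem. The stdlib sketch below mimics the same comparison on plain key lists; `compare_keys` is a hypothetical helper and the key lists are illustrative.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Hypothetical helper mirroring the missing/unexpected split reported by
    torch.nn.Module.load_state_dict(..., strict=False)."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)  # expected by the model, never loaded
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the checkpoint, ignored
    return missing, unexpected


# Illustrative BERT-style parameter names the retriever expects.
model_keys = [
    "embeddings.word_embeddings.weight",
    "encoder.layer.0.attention.self.query.weight",
]
# Keys produced by the current remapping: the layer key lost both "encoder." prefixes.
remapped = [
    "embeddings.word_embeddings.weight",
    "layer.0.attention.self.query.weight",
]

missing, unexpected = compare_keys(model_keys, remapped)
print("missing:", missing)
print("unexpected:", unexpected)
```

In the real loader, the equivalent check would be to keep the result of `retriever.load_state_dict(pretrained_dict, strict=False)` and inspect its `missing_keys` field.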
Therefore, line 126 should be changed as follows, so that only the first occurrence of "encoder." is removed:
```python
pretrained_dict = {k.replace("encoder.", "", 1): v for k, v in pretrained_dict.items() if "encoder." in k}
```
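A sketch of the proposed change applied to a toy state dict, with string placeholders standing in for tensors. Passing `1` as the third argument (`count`) to `str.replace` removes only the first occurrence, which for these keys is the leading wrapper prefix. As an alternative that is not part of the issue, `str.removeprefix` (Python 3.9+) strips the prefix only when the key actually starts with it.

```python
# Toy checkpoint: values are placeholders instead of tensors (illustrative only).
checkpoint = {
    "encoder.embeddings.word_embeddings.weight": "W_emb",
    "encoder.encoder.layer.0.attention.self.query.weight": "W_q",
}

# Proposed fix from the issue: strip only the first "encoder.".
fixed = {k.replace("encoder.", "", 1): v for k, v in checkpoint.items() if "encoder." in k}

# Alternative (an assumption, not proposed in the issue): strip the prefix only
# at the start of the key, and keep only keys that start with it.
by_prefix = {k.removeprefix("encoder."): v for k, v in checkpoint.items() if k.startswith("encoder.")}

for key in fixed:
    print(key)
```

For keys that begin with the wrapper prefix, both versions give the same result; they differ only for keys that contain "encoder." somewhere other than the start.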