diff --git a/src/icatcher/cli.py b/src/icatcher/cli.py index 59bb15a..248024a 100644 --- a/src/icatcher/cli.py +++ b/src/icatcher/cli.py @@ -192,11 +192,19 @@ def load_models(opt, download_only=False): face_detector_model_file = file_paths[ file_names.index("Resnet50_Final.pth") ] - face_detector_model = RetinaFace( + if opt.device.startswith("mps"): + face_detector_model = RetinaFace( + gpu_id=opt.gpu_id, + model_path=face_detector_model_file, + network="resnet50", + device="mps", + ) + else: + face_detector_model = RetinaFace( gpu_id=opt.gpu_id, model_path=face_detector_model_file, network="resnet50", - ) + ) elif opt.fd_model == "opencv_dnn": face_detector_model_file = file_paths[ file_names.index("face_model.caffemodel") @@ -215,6 +223,10 @@ def load_models(opt, download_only=False): state_dict = torch.load( str(path_to_gaze_model), map_location=torch.device(opt.device) ) + elif opt.device.startswith("mps"): + state_dict = torch.load( + str(path_to_gaze_model), map_location=torch.device("mps") + ) else: state_dict = torch.load(str(path_to_gaze_model)) try: diff --git a/src/icatcher/options.py b/src/icatcher/options.py index bcd040f..4739778 100644 --- a/src/icatcher/options.py +++ b/src/icatcher/options.py @@ -272,13 +272,17 @@ def parse_arguments(my_string=None): args.device = "cpu" else: import os - - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) - args.device = "cuda:{}".format(0) import torch - if not torch.cuda.is_available(): - raise ValueError("GPU is not available. Was torch compiled with CUDA?") + if torch.cuda.is_available(): + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) + args.device = f"cuda:{args.gpu_id}" + else: + if torch.backends.mps.is_available(): + args.device = f"mps:{args.gpu_id}" + else: + raise ValueError("GPU is not available. Was torch compiled with CUDA or MPS?") + # figure out how many cpus can be used use_cpu = True if args.gpu_id == -1 else False if use_cpu: diff --git a/tests/test_gaze_model.py b/tests/test_gaze_model.py index 40895d8..62ce5cb 100644 --- a/tests/test_gaze_model.py +++ b/tests/test_gaze_model.py @@ -48,3 +48,21 @@ def test_predict_from_video(args_string): """ args = parse_arguments(args_string) predict_from_video(args) + +@pytest.mark.skipif( + not torch.backends.mps.is_available() or torch.cuda.is_available(), + reason="Requires MPS for running, without a CUDA GPU, to test the functionality of the MPS pipeline." +) +@pytest.mark.parametrize( + "args_string", + [ + "tests/test_data/test_short.mp4 --model icatcher+_lookit_regnet.pth --gpu_id=0", + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --gpu_id=0", + ], +) +def test_predict_from_video_with_mps(args_string): + """ + Ensures that the entire prediction pipeline is run to completion with both gaze models using MPS. + """ + args = parse_arguments(args_string) + predict_from_video(args) \ No newline at end of file