diff --git a/src/evaluate.py b/src/evaluate.py index dcecef5..f5293cb 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -119,9 +119,10 @@ def parse_eval_configs(): # model.print_network() print('\n\n' + '-*=' * 30 + '\n\n') assert os.path.isfile(configs.pretrained_path), "No file at {}".format(configs.pretrained_path) - model.load_state_dict(torch.load(configs.pretrained_path)) configs.device = torch.device('cpu' if configs.no_cuda else 'cuda:{}'.format(configs.gpu_idx)) + model.load_state_dict(torch.load(configs.pretrained_path, map_location=configs.device)) + model = model.to(device=configs.device) model.eval() diff --git a/src/test.py b/src/test.py index 5f838fe..a04eed8 100644 --- a/src/test.py +++ b/src/test.py @@ -95,9 +95,10 @@ def parse_test_configs(): model.print_network() print('\n\n' + '-*=' * 30 + '\n\n') assert os.path.isfile(configs.pretrained_path), "No file at {}".format(configs.pretrained_path) - model.load_state_dict(torch.load(configs.pretrained_path)) configs.device = torch.device('cpu' if configs.no_cuda else 'cuda:{}'.format(configs.gpu_idx)) + model.load_state_dict(torch.load(configs.pretrained_path, map_location=configs.device)) + model = model.to(device=configs.device) out_cap = None