From 4738d10470a97e250548fa3b8a35d3eef8efd673 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Rana Date: Sat, 11 Feb 2023 18:24:47 +0530 Subject: [PATCH] Add previous trajectories to vis_result function and check if vis/nba/ directory exists --- test_nba.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test_nba.py b/test_nba.py index dca1b3a..5912b46 100644 --- a/test_nba.py +++ b/test_nba.py @@ -89,6 +89,8 @@ def draw_result(future,past,mode='pre'): court = plt.imread("datasets/nba/court.png") plt.imshow(court, zorder=0, extent=[Constant.X_MIN, Constant.X_MAX - Constant.DIFF, Constant.Y_MAX, Constant.Y_MIN],alpha=0.5) + if not os.path.isdir("vis/nba/"): + os.makedirs("vis/nba/") if mode == 'pre': plt.savefig('vis/nba/'+str(idx)+'pre.png') else: @@ -103,6 +105,7 @@ def vis_result(test_loader, args): for data in test_loader: future_traj = np.array(data['future_traj']) * args.traj_scale # B,N,T,2 + previous_3D = np.array(data['past_traj']) * args.traj_scale with torch.no_grad(): prediction = model.inference(data) prediction = prediction * args.traj_scale @@ -117,7 +120,7 @@ def vis_result(test_loader, args): best_guess = prediction[indices,np.arange(batch*actor_num)] best_guess = np.reshape(best_guess, (batch,actor_num, args.future_length, 2)) gt = np.reshape(future_traj,(batch,actor_num,args.future_length, 2)) - previous_3D = np.reshape(previous_3D,(batch,actor_num,args.future_length, 2)) + previous_3D = np.reshape(previous_3D,(batch,actor_num,args.past_length, 2)) draw_result(best_guess,previous_3D) draw_result(gt,previous_3D,mode='gt') @@ -272,6 +275,3 @@ def test_model_all(test_loader, args): vis_result(test_loader, args) test_model_all(test_loader, args) - - -