viz
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from baseline_3d_pose import utils
from pathlib import Path

# Notebook convenience: `some_path.ls()` returns a list of a directory's entries.
Path.ls = lambda x: list(x.iterdir())
data_path = Path('data')
json_path = Path('json')
imgs_path = Path('imgs')


def _load_trusted(path):
    """Load an object saved with ``torch.save`` from a trusted local file.

    Since torch 2.6, ``torch.load`` defaults to ``weights_only=True``. That mode
    refuses pickled non-tensor objects such as NumPy arrays. The statistics and
    camera files loaded below presumably contain such objects (TODO confirm), so
    full unpickling is requested explicitly whenever this torch version supports
    the parameter.

    WARNING: ``weights_only=False`` can execute arbitrary code embedded in the
    file. Only use this on data you produced yourself.
    """
    import inspect

    kwargs = {}
    # torch < 1.13 has no ``weights_only`` parameter; its default already
    # unpickles everything, so pass nothing in that case.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


# Normalisation statistics for the 2D inputs and 3D targets.
stat_3d = _load_trusted(data_path/'stat_3d.pt')
stat_2d = _load_trusted(data_path/'stat_2d.pt')
# Train/test sets: dict-like objects indexed by key in later cells.
train_set_3d = _load_trusted(data_path/'train_3d.pt')
test_set_3d = _load_trusted(data_path/'test_3d.pt')
train_set_2d = _load_trusted(data_path/'train_2d.pt')
test_set_2d = _load_trusted(data_path/'test_2d.pt')
# Camera parameters, consumed by utils.cam_to_world_centered.
rcams = _load_trusted(data_path/'rcams.pt')
# Unpack the statistics. dim_use / dim_ignore are presumably the indices of the
# coordinates kept / dropped by the model (TODO confirm against utils).
mean_2d = stat_2d['mean']
std_2d = stat_2d['std']
dim_use_2d = stat_2d['dim_use']
dim_ignore_2d = stat_2d['dim_ignore']
mean_3d = stat_3d['mean']
std_3d = stat_3d['std']
dim_use_3d = stat_3d['dim_use']
dim_ignore_3d = stat_3d['dim_ignore']
# Total number of test samples in the 2D inputs and in the 3D targets, summed
# over the first axis of every entry. Both sums walk the 2D set's keys, so a key
# missing from the 3D set raises KeyError. The two totals should match.
test_2d = sum(test_set_2d[seq].shape[0] for seq in test_set_2d)
test_3d = sum(test_set_3d[seq].shape[0] for seq in test_set_2d)
print(test_2d, test_3d)
550644 550644
# Total number of training samples in the 2D inputs and in the 3D targets,
# summed over the first axis of every entry. Both sums walk the 2D set's keys,
# so a key missing from the 3D set raises KeyError. The two totals should match.
train_2d = sum(train_set_2d[seq].shape[0] for seq in train_set_2d)
train_3d = sum(train_set_3d[seq].shape[0] for seq in train_set_2d)
print(train_2d, train_3d)
1559752 1559752
# Keys of the training set in dict insertion order; a later cell picks one by
# position in this list.
train_key_list = list(train_set_2d.keys())
# Load detected keypoints from JSON.
# NOTE(review): if get_kp_from_json returns a NumPy array, `* 3` presumably
# rescales the coordinates (e.g. detections made on a downscaled image). If it
# returns a Python list, `* 3` would instead repeat the list three times.
# Confirm the helper's return type.
kp = utils.get_kp_from_json(json_path/'keypoints.json') * 3
# Convert the detected joints into the skeleton layout used by the dataset
# statistics (presumably COCO -> dataset joint order, per the helper name).
kps = utils.coco_to_skel(kp)
# Normalise with the 2D training statistics, then undo it. The round trip
# yields the full-length 2D vector (64 values in the output below) that the
# plotting helpers take.
kps_norm = utils.normalize_kp(kps, mean_2d, std_2d, dim_use_2d)
kps_unnorm = utils.unnormalize_data(kps_norm, mean_2d, std_2d, dim_ignore_2d)
# Bare expression: displays the array when run as a notebook cell.
kps_unnorm
array([[364.50000434, 319.50000243, 327.00000423, 324.00000123,
        311.99999545, 432.00000259, 341.99999829, 576.        ,
        529.33571189, 587.23551966, 529.38636921, 586.31801763,
        401.9999964 , 314.99999774, 434.99999742, 428.99999725,
        461.99999909, 537.00000037, 530.10512031, 587.89823506,
        530.49446311, 585.47494165, 532.0838752 , 419.7215386 ,
        348.75000434, 234.74999848, 332.99999774, 150.00000597,
        534.55416815, 304.241439  , 325.50000431,  95.99999051,
        534.11018559, 317.90342311, 387.00000229, 147.00000489,
        431.99999845, 233.99999602, 429.00000269, 317.9999989 ,
        533.49380106, 391.72324561, 533.62895965, 384.3068672 ,
        533.64435106, 394.59252866, 533.64435106, 394.59252866,
        534.11018559, 317.90342311, 273.00000771, 152.99999432,
        234.00000806, 200.99999607, 173.99999409, 221.99999499,
        532.72786934, 380.61615719, 532.89670133, 373.03286202,
        532.78290105, 381.24448162, 532.78290105, 381.24448162]])
# Draw the round-tripped keypoints on a single square 2D panel.
plt.figure(figsize=(6, 6))
gs = GridSpec(nrows=1, ncols=1, wspace=0.05, hspace=0.05)
ax = plt.subplot(gs[0, 0])

utils.show_2d_pose(kps_unnorm, ax)
# Flip the y axis: the keypoints are presumably pixel coordinates with y
# growing downward, so this shows the pose upright.
ax.invert_yaxis()

plt.show()
# Side-by-side view of one training sample: its 2D input on the left and its
# 3D target on the right.
key = train_key_list[184]  # fixed entry, chosen by position in the key list
plt.figure(figsize=(16, 6))
gs2 = GridSpec(nrows=1, ncols=2)
ax1, ax2 = plt.subplot(gs2[0, 0]), plt.subplot(gs2[0, 1], projection='3d')

idx = 200  # sample index within the chosen entry

# unnormalize_data appears to return a 2-D batch (see the `array([[...]])`
# output above); [0] keeps its single row.
ts_2d = utils.unnormalize_data(
    train_set_2d[key][idx], mean_2d, std_2d, dim_ignore_2d
)[0]

# Undo the 3D normalisation, then map the result with the cameras for `key`.
ts_3d = utils.cam_to_world_centered(
    utils.unnormalize_data(
        train_set_3d[key][idx], mean_3d, std_3d, dim_ignore_3d
    )[0],
    key,
    rcams,
)

utils.show_2d_pose(ts_2d, ax1)
ax1.invert_yaxis()  # draw the 2D pose upright (y presumably grows downward)

utils.show_3d_pose(ts_3d, ax2)

plt.show()