# viz: 2D keypoint and 3D pose visualization
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from baseline_3d_pose import utils
from pathlib import Path
def _ls(directory):
    """Return the direct children of *directory* as a list of paths."""
    return [*directory.iterdir()]


# Attach the helper to Path so any path object can call ``some_path.ls()``.
Path.ls = _ls
# Relative directories used by the script: serialized data, keypoint JSON
# files, and images.
data_path, json_path, imgs_path = map(Path, ('data', 'json', 'imgs'))
def _load_pt(filename):
    """Deserialize one file from ``data_path`` with ``torch.load``."""
    return torch.load(data_path / filename)


# Statistics objects first, then the 3D and 2D train/test sets.
stat_3d = _load_pt('stat_3d.pt')
stat_2d = _load_pt('stat_2d.pt')

train_set_3d = _load_pt('train_3d.pt')
test_set_3d = _load_pt('test_3d.pt')
train_set_2d = _load_pt('train_2d.pt')
test_set_2d = _load_pt('test_2d.pt')

# Passed to utils.cam_to_world_centered together with a dataset key;
# presumably per-camera parameters -- TODO confirm.
rcams = _load_pt('rcams.pt')
# The four entries read from each statistics mapping, in the order they are
# unpacked below.
_stat_fields = ('mean', 'std', 'dim_use', 'dim_ignore')

# 2D statistics: handed to utils.normalize_kp / utils.unnormalize_data
# later in the script.
mean_2d, std_2d, dim_use_2d, dim_ignore_2d = (
    stat_2d[field] for field in _stat_fields
)

# 3D statistics, same layout as the 2D ones.
mean_3d, std_3d, dim_use_3d, dim_ignore_3d = (
    stat_3d[field] for field in _stat_fields
)
# Total number of rows per split: the leading dimension (``shape[0]``) of
# every entry is summed.  The 3D sets are indexed with the keys of the
# matching 2D set, so a key missing from the 3D data raises KeyError
# instead of being skipped silently.
test_2d = sum(test_set_2d[key].shape[0] for key in test_set_2d)
test_3d = sum(test_set_3d[key].shape[0] for key in test_set_2d)
print(test_2d, test_3d)

train_2d = sum(train_set_2d[key].shape[0] for key in train_set_2d)
train_3d = sum(train_set_3d[key].shape[0] for key in train_set_2d)
print(train_2d, train_3d)

# Keys of the 2D training set in iteration order, so that a sequence can be
# picked by position.
train_key_list = list(train_set_2d)
# Read the keypoints stored in json/keypoints.json.
# NOTE(review): ``* 3`` scales the values if get_kp_from_json returns an
# array but repeats the sequence if it returns a list -- confirm the intent.
kp = utils.get_kp_from_json(json_path / 'keypoints.json') * 3

# Convert with utils.coco_to_skel, then normalize and un-normalize again
# with the 2D statistics.
kps = utils.coco_to_skel(kp)
kps_norm = utils.normalize_kp(kps, mean_2d, std_2d, dim_use_2d)
kps_unnorm = utils.unnormalize_data(kps_norm, mean_2d, std_2d, dim_ignore_2d)
# Bare expression: has no effect when run as a script.
kps_unnorm

# One square figure holding a single axes for the 2D pose.
fig = plt.figure(figsize=(6, 6))
gs = GridSpec(1, 1, wspace=0.05, hspace=0.05)
ax = fig.add_subplot(gs[0])

utils.show_2d_pose(kps_unnorm, ax)
# Flip the y axis after drawing; presumably to match image coordinates
# (origin at the top-left) -- TODO confirm.
ax.invert_yaxis()
plt.show()
# Hard-coded selection: the 185th key of the training set (index 184) and
# row 200 of that entry.
key = train_key_list[184]
idx = 200

# Wide figure with two side-by-side axes: 2D pose on the left, 3D on the right.
fig2 = plt.figure(figsize=(16, 6))
gs2 = GridSpec(1, 2)
ax1 = fig2.add_subplot(gs2[0])
ax2 = fig2.add_subplot(gs2[1], projection='3d')

# Un-normalize the selected rows; element 0 of each result is kept
# (presumably the output carries a leading batch dimension -- TODO confirm).
ts_2d = utils.unnormalize_data(
    train_set_2d[key][idx], mean_2d, std_2d, dim_ignore_2d
)[0]
ts_3d = utils.cam_to_world_centered(
    utils.unnormalize_data(
        train_set_3d[key][idx], mean_3d, std_3d, dim_ignore_3d
    )[0],
    key,
    rcams,
)

utils.show_2d_pose(ts_2d, ax1)
ax1.invert_yaxis()
utils.show_3d_pose(ts_3d, ax2)
plt.show()