diff --git a/app.py b/app.py
index c22a471..85a7a17 100644
--- a/app.py
+++ b/app.py
@@ -17,6 +17,8 @@ import openpose_gen as opg
 from comfy_socket import get_images
 from postprocessing import expo_shuffle_image_steps, expo_add_to_background_image, expo_postprocess_main
 import skeleton_lib as skel
+import predict as pred
+
 sys.path.append('./')
 
 app = Flask(__name__)
@@ -79,6 +81,10 @@ def get_predicted_coordinates(coordinates: list, width: int, height: int) -> lis
 
     # when testing, can visualize with the method expo_save_bodypose in openpose_gen.py
 
+    predicted = pred.predict_pose_keypoints(np.array(coordinates).reshape(1, 18, 3))
+    predicted = np.concatenate([predicted, np.ones((18, 1))], axis=1)  # append a confidence of 1 to each predicted (x, y) keypoint
+    return predicted.flatten().tolist()
+
     # for now, just mirror the coordinates and add some random deviation
     predicted_coordinates = mirror_coordinates(coordinates, width)
     for i in range(0, len(predicted_coordinates), 3):
@@ -248,4 +254,4 @@ def gen_group_pic_prompt(openpose_image_path, base_image, pid, comfyUI_address):
 
 if __name__ == '__main__':
     expo_postprocess()
-    # app.run(debug=True)
\ No newline at end of file
+    # app.run(debug=True)
diff --git a/models/loss8.517782751325285.pt b/models/loss8.517782751325285.pt
new file mode 100644
index 0000000..d846885
Binary files /dev/null and b/models/loss8.517782751325285.pt differ
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..d91cda1
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,140 @@
+# import cv2
+from glob import glob
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import torch
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+use_amp = True
+model_path = './models/loss8.517782751325285.pt'
+
+# define helper functions
+
+def load_dataset_from_npy(dir_path):
+    npy_files = glob(os.path.join(os.path.realpath(dir_path), '*.npy'))
+    return npy_files
+
+def find_bbox(keypoints):
+    keypoints_copy = keypoints
+    to_delete = []
+    cnt = 0
+    for kp in keypoints_copy:
+        # mark undetected keypoints (encoded as (0, 0) or (1, 0)) for removal
+        pos = (kp[0], kp[1])
+        if (pos == (0, 0)) or (pos == (1, 0)):
+            to_delete.append(cnt)
+        cnt += 1
+    keypoints_copy = np.delete(keypoints_copy, to_delete, 0)
+    return [min(keypoints_copy[:, 0]), max(keypoints_copy[:, 0]), min(keypoints_copy[:, 1]), max(keypoints_copy[:, 1])]
+
+def get_dist_grid(dist_ref, grid_dim=[30, 1], offsets=[0, 0]):
+    dist_grid = torch.zeros([grid_dim[0], grid_dim[1], 2]).to(device)
+    offsetX = torch.tensor([offsets[0], 0.0]).float().to(device)
+    offsetY = torch.tensor([0.0, offsets[1]]).float().to(device)
+    for i in range(grid_dim[0]):
+        for j in range(grid_dim[1]):
+            dist_grid[i, j, :] = dist_ref + \
+                offsetX * (i - int((grid_dim[0]) / 2)) + \
+                offsetY * (j - int((grid_dim[1]) / 2))
+
+    return dist_grid
+
+def bbox_center(bbox):
+    return ((bbox[0] + bbox[1])/2, (bbox[2] + bbox[3])/2)
+
+def bbox_dists(bbox1, bbox2):
+    return (np.array(bbox_center(bbox1)) - np.array(bbox_center(bbox2)))
+
+def openpose17_colors():
+    return ['#ff0000', '#ff5500', '#ffaa00', '#ffff00', '#aaff00', '#55ff00', '#00ff00', '#00ff55', '#00ffaa', '#00ffff', '#00aaff', '#0055ff', '#0000ff', '#5500ff', '#aa00ff', '#ff00ff', '#ff00aa', '#ff0055']
+
+def keypoints25to18(keypoints):
+    return np.delete(keypoints, [8, 19, 20, 21, 22, 23, 24], 0)
+
+def get_keypoints_linkage():
+    return np.array([[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11], [11, 12], [12, 13], [0, 14], [14, 16], [0, 15], [15, 17]])
+
+@torch.compile
+def batch_pose_confidence_mat(batch_pose_keypoints_with_confidence):
+    keypoints_conf = batch_pose_keypoints_with_confidence[:, :, 2]
+    confidence_mat = torch.zeros([keypoints_conf.shape[0], get_keypoints_linkage().shape[0]]).to(device)
+
+    for i in range(get_keypoints_linkage().shape[0]):
+        for j in range(keypoints_conf.shape[0]):
+            if keypoints_conf[j, get_keypoints_linkage()[i, 0]] == 0 or keypoints_conf[j, get_keypoints_linkage()[i, 1]] == 0:
+                confidence_mat[j, i] = 0
+            else:
+                confidence_mat[j, i] = (keypoints_conf[j, get_keypoints_linkage()[i, 0]] + keypoints_conf[j, get_keypoints_linkage()[i, 1]]) / 2
+
+    return confidence_mat
+
+@torch.compile
+def pose_diff(output_pose_keypoints, target_pose_keypoints, confidence_mat):
+    link_open = get_keypoints_linkage()[:, 0]
+    link_end = get_keypoints_linkage()[:, 1]
+
+    p1 = (output_pose_keypoints[:, link_open, :2] - target_pose_keypoints[:, link_open, :2]).reshape(output_pose_keypoints.shape[0], get_keypoints_linkage().shape[0], 2)
+    p2 = (output_pose_keypoints[:, link_end, :2] - target_pose_keypoints[:, link_end, :2]).reshape(output_pose_keypoints.shape[0], get_keypoints_linkage().shape[0], 2)
+
+    return torch.sum(torch.sum(torch.pow(p1, 2) + torch.pow(p2, 2), axis=2) * confidence_mat) / get_keypoints_linkage().shape[0] / output_pose_keypoints.shape[0]
+
+@torch.compile
+def pose_loss(outputs, batch_target_pose_keypoints_with_confidence):
+    err = pose_diff(outputs, batch_target_pose_keypoints_with_confidence, batch_pose_confidence_mat(batch_target_pose_keypoints_with_confidence))
+    return torch.abs(torch.sum(err))
+
+@torch.compile
+def zscore_normalization(data):
+    return torch.std(data), torch.mean(data), (data - torch.mean(data)) / torch.std(data)
+
+model = torch.nn.Sequential(
+    torch.nn.Linear(36, 256),
+    torch.nn.Tanh(),
+    torch.nn.Dropout(0.1),
+    torch.nn.Linear(256, 512),
+    torch.nn.Tanh(),
+    torch.nn.Dropout(0.1),
+    torch.nn.Linear(512, 256),
+    torch.nn.Tanh(),
+    torch.nn.Dropout(0.1),
+    torch.nn.Linear(256, 36)
+).to(device)
+
+loss_fn = pose_loss  # torch.nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
+
+checkpoint = torch.load(model_path, map_location=device)  # map to CPU when CUDA is unavailable
+model.load_state_dict(checkpoint['model'])
+optimizer.load_state_dict(checkpoint['optimizer'])
+scaler.load_state_dict(checkpoint['scaler'])
+
+sample_data = torch.Tensor([[[-0.9695, -1.6531, 2.2570],
+                             [-0.9758, -1.5557, 2.5996],
+                             [-1.0910, -1.5669, 2.2916],
+                             [-1.1820, -1.3080, 2.4095],
+                             [-1.0606, -1.2970, 2.6237],
+                             [-0.8728, -1.5446, 2.4116],
+                             [-0.7996, -1.2856, 2.2992],
+                             [-0.6417, -1.3074, 2.2848],
+                             [-1.1578, -1.0483, 2.3292],
+                             [-1.2732, -0.6165, 2.3635],
+                             [-1.3583, -0.2720, 2.3981],
+                             [-1.0120, -1.0378, 2.3269],
+                             [-0.8237, -0.7680, 2.5688],
+                             [-0.7751, -0.3148, 2.4276],
+                             [-0.9878, -1.7177, 2.1040],
+                             [-0.9453, -1.7068, 1.8512],
+                             [-1.0184, -1.7280, 1.7790],
+                             [-0.9146, -1.6959, 0.9578]]]).to(device)
+model.eval()
+
+
+def predict_pose_keypoints(data):
+    std, mean, data = zscore_normalization(torch.as_tensor(data, dtype=torch.float32).to(device))  # accept a (1, 18, 3) numpy array or tensor
+    data = data[:, :, :2].reshape(1, 36).to(device)
+    with torch.cuda.amp.autocast(enabled=use_amp):
+        outputs = model(data)
+    outputs = (outputs * std + mean).reshape(18, 2).cpu().detach().numpy()
+    return outputs
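
Usage note (illustrative, not part of the patch): a minimal sketch of how the new predict_pose_keypoints entry point can be exercised outside Flask, mirroring the hookup in get_predicted_coordinates. It assumes the checkpoint models/loss8.517782751325285.pt added by this PR is available locally; the keypoint values below are made-up placeholders, not real OpenPose output.

    import numpy as np
    import predict as pred  # importing the module also loads the checkpoint

    # 18 OpenPose keypoints as flattened (x, y, confidence) triplets, the same
    # 54-value list format that get_predicted_coordinates() receives
    coordinates = [240.0, 120.0, 0.9] * 18  # placeholder values
    keypoints = np.array(coordinates).reshape(1, 18, 3)

    predicted = pred.predict_pose_keypoints(keypoints)                 # (18, 2) array of x, y
    predicted = np.concatenate([predicted, np.ones((18, 1))], axis=1)  # confidence of 1 per keypoint
    flat = predicted.flatten().tolist()                                # back to the 54-value list
    assert len(flat) == 54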