111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
|
import json
|
||
|
import concurrent
|
||
|
import pandas as pd
|
||
|
from ultralytics import YOLO
|
||
|
import os
|
||
|
import skeleton_lib as skel
|
||
|
import torch
|
||
|
|
||
|
def point_in_box(point, box):
    """Return True if ``point`` lies inside ``box``, edges included.

    Args:
        point: An ``(x, y)`` pair.
        box: Corner coordinates ``(x1, y1, x2, y2)``; the check assumes
            ``x1 <= x2`` and ``y1 <= y2``.
    """
    px, py = point
    x_min, y_min, x_max, y_max = box
    within_x = x_min <= px <= x_max
    return within_x and y_min <= py <= y_max
|
||
|
|
||
|
def load_lerped_keypoints(lerped_keypoints_path):
    """Load a clip's lerped keypoint data from a JSON file.

    Args:
        lerped_keypoints_path: Path to the JSON file.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON is UTF-8 by specification (RFC 8259); don't depend on the
    # platform's locale-dependent default encoding.
    with open(lerped_keypoints_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
||
|
|
||
|
def get_valid_skeletons(data, data_i, boxes, keypoints, min_slots=2):
    """Match one frame's YOLO detections to that frame's lerped reference points.

    For each reference point in ``data[data_i]`` (each a mapping with ``'x'``
    and ``'y'`` entries), the first detection box containing the point is
    selected. Its keypoints are converted into a ``skel.Skeleton``. Reference
    points that fall in no box keep an empty ``skel.Skeleton``.

    Args:
        data: Lerped keypoint data, indexed by frame.
        data_i: Index of the frame within ``data``.
        boxes: Detection boxes for the frame (``boxes.xyxy`` is read).
        keypoints: Pose keypoints for the frame (``.xy`` and ``.conf`` are read).
        min_slots: Minimum length of the returned list. Defaults to 2, the
            previously fixed size.

    Returns:
        A list of ``skel.Skeleton`` with ``max(min_slots, len(data[data_i]))``
        entries. Entry ``k`` corresponds to ``data[data_i][k]``.
    """
    reference_points = data[data_i]
    # A fixed size of 2 raised IndexError for frames with more reference
    # points; size to fit while keeping at least ``min_slots`` entries.
    slot_count = max(min_slots, len(reference_points))
    valid_skeletons = [skel.Skeleton([]) for _ in range(slot_count)]

    # Convert the boxes once per frame rather than once per reference point.
    box_list = boxes.xyxy.tolist()
    for avg_i, avg in enumerate(reference_points):
        point = (avg['x'], avg['y'])
        for i, box in enumerate(box_list):
            if point_in_box(point, box):
                valid_skeletons[avg_i] = _detection_to_skeleton(keypoints, i)
                break
    return valid_skeletons


def _detection_to_skeleton(keypoints, i):
    """Build a ``skel.Skeleton`` from the keypoints of detection ``i``."""
    skeleton = skel.Skeleton([])
    # Bulk-convert each tensor once instead of calling .item() per keypoint.
    coords = keypoints.xy[i].tolist()
    confidences = keypoints.conf[i].tolist()
    for (x, y), conf in zip(coords, confidences):
        skeleton.keypoints.append(skel.Keypoint(x, y, conf))
    return skeleton
|
||
|
|
||
|
def get_yoloed_frames(results, lerped_keypoints_path):
    """Convert per-frame YOLO results into matched skeleton lists.

    Result ``k`` is matched against entry ``k`` of the lerped keypoint data
    via ``get_valid_skeletons``.

    Args:
        results: Iterable of per-frame results, each exposing ``.boxes`` and
            ``.keypoints``.
        lerped_keypoints_path: Path to the clip's lerped keypoint JSON.

    Returns:
        A list with one skeleton list per result, in iteration order.
    """
    lerped = load_lerped_keypoints(lerped_keypoints_path)
    return [
        get_valid_skeletons(lerped, frame_idx, res.boxes, res.keypoints)
        for frame_idx, res in enumerate(results)
    ]
|
||
|
|
||
|
def process_clip(row, model, *, skip_existing=False):
    """Run pose estimation on one clip and write the matched skeletons to JSON.

    Frames are read from ``video_frames/<ClipName>`` and matched against
    ``./lerped_keypoints/<ClipName>.json``. The result is written to
    ``./new_yolo_keypoints/<ClipName>.json``.

    Args:
        row: Row-like mapping with a ``'ClipName'`` entry.
        model: YOLO pose model (called with the input path).
        skip_existing: If True and the output file already exists, return
            without running the model.
    """
    clip_name = row['ClipName']
    input_video_path = f"video_frames/{clip_name}"
    lerped_keypoints_path = f"./lerped_keypoints/{clip_name}.json"
    output_keypoints_path = f"./new_yolo_keypoints/{clip_name}.json"

    # Only the output directory must exist; the lerped keypoints are only read.
    os.makedirs(os.path.dirname(output_keypoints_path), exist_ok=True)

    if skip_existing and os.path.exists(output_keypoints_path):
        return

    # stream=True yields results lazily instead of holding every frame's
    # result in memory at once.
    results = model(input_video_path, stream=True)
    frames = get_yoloed_frames(results, lerped_keypoints_path)

    # Write to a temporary file and rename it into place, so an interrupted
    # dump never leaves a truncated JSON that skip_existing would then keep.
    tmp_path = output_keypoints_path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(frames, f, cls=skel.Encoder, indent=4)
        os.replace(tmp_path, output_keypoints_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
||
|
|
||
|
def process_rows_on_gpu(rows, model, device, max_attempts=5, model_path="yolo11x-pose.pt"):
    """Process every clip in ``rows`` with ``model``, retrying failed clips.

    After a failed attempt the model is discarded and reloaded from
    ``model_path`` onto ``device``. The reloaded model is also used for the
    remaining rows. A clip that still fails after ``max_attempts`` attempts
    is reported and skipped.

    Args:
        rows: DataFrame of clip rows; each row is passed to ``process_clip``.
        model: Loaded YOLO pose model on ``device``.
        device: Torch device used when reloading the model.
        max_attempts: Attempts per clip before giving up. Defaults to 5.
        model_path: Weights loaded when the model is rebuilt after a failure.
    """
    for _, row in rows.iterrows():
        for _ in range(max_attempts):
            try:
                process_clip(row, model)
            except Exception as e:
                print(f"Error processing clip: {e}")
                # Drop the old model before loading a fresh one to limit peak memory.
                del model
                model = YOLO(model_path).to(device)
            else:
                break
        else:
            # Every attempt failed; make the skipped clip visible instead of dropping it silently.
            print(f"Giving up on clip {row['ClipName']} after {max_attempts} attempts")
|
||
|
|
||
|
def gen_yolo_skeletons(descriptor):
    """Generate YOLO skeleton JSON for every clip, spread across the CUDA GPUs.

    The rows of ``descriptor`` are split into contiguous chunks whose sizes
    differ by at most one, one chunk per GPU. Each chunk is processed by
    ``process_rows_on_gpu`` in its own thread, with its own model on that GPU.
    When no CUDA device is available, ``gen_yolo_skeletons_single`` is used
    instead.

    Args:
        descriptor: DataFrame with one row per clip.
    """
    # `import concurrent` alone does not load the `concurrent.futures` submodule.
    import concurrent.futures

    n_rows = len(descriptor)
    if n_rows == 0:
        return

    # Never start more GPU workers than there are rows to hand out.
    num_gpus = min(torch.cuda.device_count(), n_rows)
    if num_gpus == 0:
        # No CUDA devices (previously a ZeroDivisionError): use the single-device path.
        gen_yolo_skeletons_single(descriptor)
        return

    devices = [torch.device(f'cuda:{i}') for i in range(num_gpus)]
    models = [YOLO("yolo11x-pose.pt").to(device) for device in devices]

    # Chunk k covers rows [bounds[k], bounds[k + 1]). This spreads the
    # remainder evenly instead of piling it all onto the last GPU.
    bounds = [n_rows * k // num_gpus for k in range(num_gpus + 1)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_gpus) as executor:
        futures = [
            executor.submit(
                process_rows_on_gpu,
                descriptor.iloc[bounds[k]:bounds[k + 1]],
                models[k],
                devices[k],
            )
            for k in range(num_gpus)
        ]

        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error processing rows on GPU: {e}")
|
||
|
|
||
|
def gen_yolo_skeletons_single(descriptor):
    """Process every clip in ``descriptor`` on one device.

    Uses ``cuda:0`` when CUDA is available and the CPU otherwise.

    Args:
        descriptor: DataFrame with one row per clip.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    process_rows_on_gpu(descriptor, YOLO("yolo11x-pose.pt").to(device), device)
|
||
|
|
||
|
def main(descriptor_path='./ClipDescriptorKaggle_processed.csv'):
    """Generate YOLO skeletons for every clip listed in the descriptor CSV.

    Args:
        descriptor_path: CSV with one row per clip. It must contain the
            ``'ClipName'`` column read by ``process_clip``.
    """
    # Model loading happens inside gen_yolo_skeletons (one model per GPU), so
    # no model is created here.
    descriptor = pd.read_csv(descriptor_path)
    gen_yolo_skeletons(descriptor)
|
||
|
|
||
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|