258 lines
10 KiB
Python
258 lines
10 KiB
Python
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import random
|
|
import uuid
|
|
import cv2
|
|
from flask import Flask, request, jsonify
|
|
import sys
|
|
import os
|
|
from PIL import Image
|
|
import io
|
|
|
|
import numpy as np
|
|
import websocket
|
|
import openpose_gen as opg
|
|
from comfy_socket import get_images
|
|
from postprocessing import expo_shuffle_image_steps, expo_add_to_background_image, expo_postprocess_main
|
|
import skeleton_lib as skel
|
|
import predict as pred
|
|
|
|
# Make modules in the working directory importable at runtime.
# NOTE(review): this runs after the project imports above, so it cannot help
# resolve those; it only affects imports performed later.
sys.path.append('./')

app = Flask(__name__)

# Runtime configuration. Read through a context manager so the file handle is
# closed right after parsing, with an explicit encoding (JSON is UTF-8).
with open('info.json', encoding='utf-8') as _info_file:
    info = json.load(_info_file)

# ComfyUI server address, interpolated into "ws://<address>/ws?clientId=...".
comfyui_address = info['comfyui_address']
# Root directory for images returned by ComfyUI (left_fencer/right_fencer subdirs).
expo_raw_sd_dir = info['expo_raw_sd_dir']
# Root directory for rendered OpenPose skeleton images (left_fencer/right_fencer subdirs).
expo_openpose_dir = info['expo_openpose_dir']

# True while /expo_postprocess runs; new pose submissions are rejected with 503.
on_postprocessing = False
# When True, expo_clear_images() keeps the generated files.
on_testing = False
|
|
|
|
@app.route('/expo_fencing_pose', methods=['POST'])
def expo_fencing_pose():
    """Render both fencers' skeletons for one step and generate their images.

    Expects a JSON object with ``coordinates`` (the left fencer's flat
    ``[x, y, c, ...]`` keypoint list), ``canvas_size`` (``[width, height]``),
    ``batch`` and ``step``. The right fencer's keypoints are derived with
    ``get_predicted_coordinates``. Both skeletons are saved under
    ``expo_openpose_dir``, and each is sent to ComfyUI with results written
    under ``expo_raw_sd_dir``.

    Returns:
        201 on success, 422 if a field is missing, 415 if the body is not
        JSON, 503 while postprocessing is running.
    """
    if on_postprocessing:
        return jsonify({"status": "error", "message": "Postprocessing in progress"}), 503

    if not request.is_json:
        return jsonify({"status": "error", "message": "Request must be JSON"}), 415

    data = request.get_json()
    # A JSON array/scalar body has no fields; treat it as "missing data".
    if not isinstance(data, dict):
        data = {}

    # .get() so a missing field reaches the 422 below instead of raising
    # KeyError (which Flask would turn into a 500).
    coordinates = data.get('coordinates')
    canvas_size = data.get('canvas_size')
    batch = data.get('batch')
    step = data.get('step')

    if coordinates is None or canvas_size is None or 'batch' not in data or 'step' not in data:
        return jsonify({"status": "error", "message": "Missing data"}), 422

    width, height = canvas_size[0], canvas_size[1]

    # Pass a copy: `coordinates` is reused below for the left fencer and must
    # not be altered by the prediction step.
    right_fencer_coordinates = get_predicted_coordinates(list(coordinates), width, height)

    left_fencer_dir = os.path.join(expo_openpose_dir, 'left_fencer')
    os.makedirs(left_fencer_dir, exist_ok=True)
    right_fencer_dir = os.path.join(expo_openpose_dir, 'right_fencer')
    os.makedirs(right_fencer_dir, exist_ok=True)

    left_openpose_image_path = opg.expo_save_bodypose(width, height, coordinates, batch, step, left_fencer_dir, skel.coco_limbSeq, skel.coco_colors)
    right_openpose_image_path = opg.expo_save_bodypose(width, height, right_fencer_coordinates, batch, step, right_fencer_dir, skel.coco_limbSeq, skel.coco_colors)

    left_fencer_raw_image_dir = os.path.join(expo_raw_sd_dir, 'left_fencer')
    os.makedirs(left_fencer_raw_image_dir, exist_ok=True)
    right_fencer_raw_image_dir = os.path.join(expo_raw_sd_dir, 'right_fencer')
    os.makedirs(right_fencer_raw_image_dir, exist_ok=True)

    expo_fencer_prompt(left_openpose_image_path, left_fencer_raw_image_dir, batch, step)
    expo_fencer_prompt(right_openpose_image_path, right_fencer_raw_image_dir, batch, step)

    return jsonify({"status": "success", "message": "Data received"}), 201
|
|
|
|
|
|
def get_predicted_coordinates(coordinates: list, width: int, height: int, *, use_model: bool = True) -> list:
    """Predict the right fencer's keypoints from the left fencer's keypoints.

    Args:
        coordinates: Flat list ``[x1, y1, c1, x2, y2, c2, ...]`` of 18
            keypoints (0..17) in canvas pixel units (not normalized), where
            ``c`` is the confidence score. It is not modified.
        width: Canvas width in pixels; used only by the mirror fallback.
        height: Canvas height in pixels; currently unused.
        use_model: When True (default), predict with
            ``pred.predict_pose_keypoints``. When False, mirror the pose
            horizontally and add a random deviation of up to +/-10 px to every
            x and y.

    Returns:
        A flat list in the same ``[x, y, c, ...]`` layout. In the model path
        every confidence value is set to 1.

    Raises:
        ValueError: in the model path, if ``coordinates`` cannot be reshaped
            to 18 keypoints of 3 values.
    """
    # Keypoint order / limb colours follow skel.coco_limbSeq and
    # skel.coco_colors (skeleton_lib.py). opg.expo_save_bodypose can be used to
    # visualise the output while testing.
    if use_model:
        predicted = pred.predict_pose_keypoints(np.array(coordinates).reshape(1, 18, 3))
        # The last axis is (x, y, c). The previous code wrote index 3, which
        # is out of range for a 3-wide axis; set the confidence channel.
        # NOTE(review): assumes the model returns the same keypoint layout it
        # is given - confirm against predict.py.
        predicted[..., -1] = 1
        return predicted.flatten().tolist()

    # Fallback: mirror a copy, because mirror_coordinates works in place and
    # the caller still needs the original pose.
    predicted_coordinates = mirror_coordinates(list(coordinates), width)
    for i in range(0, len(predicted_coordinates), 3):
        predicted_coordinates[i] += random.randint(-10, 10)
        predicted_coordinates[i + 1] += random.randint(-10, 10)
    return predicted_coordinates
|
|
|
|
def mirror_coordinates(coordinates: list, width: int) -> list:
    """Reflect every x value across the vertical axis of a ``width``-wide canvas.

    ``coordinates`` is a flat ``[x, y, c, ...]`` list. Only every third value
    (the x positions) changes. The list is updated in place and the same
    object is returned.
    """
    coordinates[0::3] = [width - x for x in coordinates[0::3]]
    return coordinates
|
|
|
|
def expo_fencer_prompt(openpose_image_path, save_dir, batch, step):
    """Generate a fencer image from an OpenPose skeleton image via ComfyUI.

    Loads the ``fencer_03`` workflow, uploads the skeleton image and
    ``ref_black.png``, and randomises the seed. It then runs the workflow over
    a websocket (closed afterwards, even on error) and saves the output to
    ``<save_dir>/<batch>_<step>.png``.
    """
    with open("./prompts/fencer_03.json", "r", encoding="utf-8") as f:
        prompt = json.load(f)

    openpose_image_name = opg.upload_image(openpose_image_path)
    opg.upload_image("./images/ref_black.png", "ref_black.png")

    print(openpose_image_name)

    # Node ids are specific to fencer_03.json: "3" takes the seed, "29" the
    # reference image, "17" the pose image.
    prompt["3"]["inputs"]["seed"] = random.randint(0, 10000000000)
    prompt["29"]["inputs"]['image'] = "ref_black.png"
    prompt["17"]["inputs"]['image'] = openpose_image_name

    client_id = hashlib.sha256(str(random.getrandbits(256)).encode('utf-8')).hexdigest()
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(comfyui_address, client_id))
    try:
        images = get_images(ws, prompt, client_id)
    finally:
        # Release the connection whether or not generation succeeded.
        ws.close()

    # NOTE(review): every output image is written to the same file name, so
    # only the last one is kept. The name is left unchanged because
    # wait_for_files_to_match compares relative file names across directories.
    image_path = os.path.join(save_dir, f"{batch}_{step}.png")
    for node_id in images:
        for image_data in images[node_id]:
            image = Image.open(io.BytesIO(image_data))
            image.save(image_path)
|
|
|
|
def expo_clear_images():
    """Delete every file under the OpenPose and raw output directories.

    Sub-directories are kept; only files are removed. Does nothing while
    ``on_testing`` is set.
    """
    if not on_testing:
        for base_dir in (expo_openpose_dir, expo_raw_sd_dir):
            for dirpath, _subdirs, filenames in os.walk(base_dir):
                for name in filenames:
                    os.remove(os.path.join(dirpath, name))
|
|
|
|
@app.route('/expo_postprocess', methods=['POST'])
async def expo_postprocess():
    """Run the expo postprocessing pipeline once the generated images are in.

    First waits (``wait_for_files_to_match`` default timeout) until
    ``expo_raw_sd_dir`` holds the same relative file names as
    ``expo_openpose_dir``. Then it runs ``expo_postprocess_main`` and clears
    both directories. ``on_postprocessing`` is set for the duration, so pose
    submissions and further postprocess requests get 503. It is reset on every
    exit path, including errors.

    Returns:
        200 on success, 404 if either directory holds no files,
        503 if postprocessing is already running.
    """
    global on_postprocessing
    if on_postprocessing:
        return jsonify({"status": "error", "message": "Postprocessing in progress"}), 503

    on_postprocessing = True
    print("Postprocessing")
    try:
        # Wait until the directories have the same files or timeout.
        if not await wait_for_files_to_match(expo_openpose_dir, expo_raw_sd_dir):
            print("Timeout reached, proceeding with postprocessing")

        # Count files recursively. The left_fencer/right_fencer sub-directories
        # survive expo_clear_images(), so os.listdir() would never report the
        # directory as empty. A missing directory also yields no files.
        if not get_all_files(expo_openpose_dir):
            return jsonify({"status": "error", "message": "No images to process in expo_openpose_dir"}), 404
        if not get_all_files(expo_raw_sd_dir):
            return jsonify({"status": "error", "message": "No images to process in expo_raw_sd_dir"}), 404

        await asyncio.to_thread(expo_postprocess_main)
        await asyncio.to_thread(expo_clear_images)
    finally:
        # Without this, an exception above would leave the flag set and every
        # later request would be rejected with 503 until restart.
        on_postprocessing = False

    print("Postprocessing completed")
    return jsonify({"status": "success", "message": "Postprocessing completed"}), 200
|
|
|
|
|
|
async def wait_for_files_to_match(dir1: str, dir2: str, timeout: int = 180, interval: int = 1) -> bool:
    """Poll until ``dir1`` and ``dir2`` contain the same relative file paths.

    Both trees are listed with ``get_all_files`` and compared every
    ``interval`` seconds.

    Returns:
        True as soon as the two sets are equal, or False once ``timeout``
        seconds have passed without a match.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        if get_all_files(dir1) == get_all_files(dir2):
            return True
        await asyncio.sleep(interval)
    return False
|
|
|
|
|
|
def get_all_files(directory):
    """Return the set of file paths under ``directory``, relative to it.

    The tree is walked recursively and directories themselves are not
    included. A missing directory yields an empty set, because ``os.walk``
    reports nothing for it.
    """
    return {
        os.path.relpath(os.path.join(dirpath, name), directory)
        for dirpath, _dirnames, filenames in os.walk(directory)
        for name in filenames
    }
|
|
|
|
@app.route('/gen_image', methods=['POST'])
def gen_image():
    """Render a single-person OpenPose image from posted keypoints.

    Expects a JSON object with ``coordinates``, ``canvas_size``
    (``[width, height]``) and ``pid``. The skeleton is saved via
    ``opg.save_bodypose``. ComfyUI generation is currently disabled.

    Returns:
        201 on success, 422 if a field is missing, 415 if the body is not JSON.
    """
    if not request.is_json:
        return jsonify({"status": "error", "message": "Request must be JSON"}), 415

    data = request.get_json()
    # A JSON array/scalar body has no fields; treat it as "missing data".
    if not isinstance(data, dict):
        data = {}

    # .get() so a missing field yields the 422 below instead of a KeyError/500.
    coordinates = data.get('coordinates')
    canvas_size = data.get('canvas_size')
    pid = data.get('pid')

    if not coordinates or not canvas_size or pid is None:
        return jsonify({"status": "error", "message": "Missing data"}), 422

    opg.save_bodypose(canvas_size[0], canvas_size[1], coordinates, pid)
    # NOTE: generation is disabled. To re-enable, keep the path returned by
    # opg.save_bodypose and call gen_fencer_prompt(path, pid, comfyui_address).

    return jsonify({"status": "success", "message": "Data received"}), 201
|
|
|
|
|
|
@app.route('/gen_group_pic', methods=['POST'])
def gen_group_pic():
    """Render a multi-person OpenPose image and queue a group-picture job.

    Expects a JSON object with ``coordinates_list`` (a list of objects, each
    holding its keypoints under ``coordinates``), ``canvas_size``
    (``[width, height]``), ``pid`` and ``base_image`` (base64-encoded).

    Returns:
        201 on success, 422 if a field is missing or ``base_image`` is not
        valid base64, 415 if the body is not JSON.
    """
    if not request.is_json:
        return jsonify({"status": "error", "message": "Request must be JSON"}), 415

    data = request.get_json()
    # A JSON array/scalar body has no fields; treat it as "missing data".
    if not isinstance(data, dict):
        data = {}

    # .get() so a missing field yields a 422 instead of a KeyError/500.
    coordinates_list = data.get('coordinates_list')
    canvas_size = data.get('canvas_size')
    pid = data.get('pid')
    encoded_base_image = data.get('base_image')

    try:
        base_image = base64.b64decode(encoded_base_image) if encoded_base_image else b""
    except (ValueError, TypeError):
        # Malformed base64 raises binascii.Error (a ValueError subclass); a
        # non-string value raises TypeError. Both are client errors.
        return jsonify({"status": "error", "message": "Invalid base_image"}), 422

    if not coordinates_list or not canvas_size or not base_image or not pid:
        return jsonify({"status": "error", "message": "Missing data"}), 422

    # Unwrap each person's keypoints from its {"coordinates": [...]} object.
    coordinates_list = [entry['coordinates'] for entry in coordinates_list]

    openpose_image_path = opg.save_bodypose_mulit(canvas_size[0], canvas_size[1], coordinates_list, pid)
    gen_group_pic_prompt(openpose_image_path, base_image, pid, comfyui_address)

    return jsonify({"status": "success", "message": "Data received"}), 201
|
|
|
|
def gen_fencer_prompt(openpose_image_path, pid, comfyUI_address):
    """Queue a single-fencer generation job on ComfyUI.

    Uploads the skeleton image through ``opg.upload_image_circular_queue``
    (with 20 and ``pid``) and uploads ``ref_black.png``. It then randomises the
    seed and hands the ``fencerAPI`` workflow to ``opg.queue_prompt``.
    """
    with open("./prompts/fencerAPI.json", "r", encoding="utf-8") as f:
        prompt = json.load(f)

    openpose_image_name = opg.upload_image_circular_queue(openpose_image_path, 20, pid, comfyUI_address)
    # NOTE(review): this upload does not receive comfyUI_address - confirm
    # opg.upload_image targets the same server.
    opg.upload_image("./images/ref_black.png", "ref_black.png")

    prompt["3"]["inputs"]["seed"] = random.randint(0, 10000000000)
    # Reference the image by the name it was uploaded under, as
    # expo_fencer_prompt does. ComfyUI image-loading nodes resolve uploaded
    # file names, not paths on this machine.
    prompt["29"]["inputs"]['image'] = "ref_black.png"
    prompt["17"]["inputs"]['image'] = openpose_image_name

    opg.queue_prompt(prompt, comfyUI_address)
|
|
|
|
def gen_group_pic_prompt(openpose_image_path, base_image, pid, comfyUI_address):
    """Queue a group-picture generation job on ComfyUI.

    Uploads the multi-person OpenPose image and the base image, then hands the
    ``group_pic`` workflow, with a random seed, to ``opg.queue_prompt``.

    Args:
        openpose_image_path: Path of the rendered skeleton image.
        base_image: Decoded base-image bytes from the request.
        pid: Identifier forwarded to ``opg.upload_image_circular_queue``.
        comfyUI_address: ComfyUI server address.
    """
    # Explicit UTF-8 (as in expo_fencer_prompt) so decoding of the prompt file
    # does not depend on the platform's default encoding.
    with open("./prompts/group_pic.json", "r", encoding="utf-8") as f:
        prompt = json.load(f)

    # NOTE(review): the helper receives a path for the pose image but raw
    # bytes for the base image - confirm it accepts both.
    openpose_image_name = opg.upload_image_circular_queue(openpose_image_path, 30, pid, comfyUI_address)
    base_image_name = opg.upload_image_circular_queue(base_image, 30, pid, comfyUI_address)

    # Node ids are specific to group_pic.json: "3" seed, "10" pose image,
    # "14" base image.
    prompt["3"]["inputs"]["seed"] = random.randint(0, 10000000000)
    prompt["10"]["inputs"]['image'] = openpose_image_name
    prompt["14"]["inputs"]['image'] = base_image_name

    opg.queue_prompt(prompt, comfyUI_address)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # expo_postprocess is an async view: calling it bare only creates a
    # coroutine that is never awaited, so nothing ran. Drive it with asyncio
    # inside an application context (jsonify requires one).
    with app.app_context():
        asyncio.run(expo_postprocess())
    # To serve the HTTP API instead, replace the lines above with:
    # app.run(debug=True)
|