pip install flask websocket-client requests
import json
import uuid
import websocket
import urllib.request
import urllib.parse
import random
import threading
import time
import requests
import io
from flask import Flask, request, jsonify, Response

app = Flask(__name__)

# ================= 基础配置 =================
COMFY_ADDR = "127.0.0.1:8188"
CLIENT_ID = str(uuid.uuid4())
gpu_lock = threading.Lock()

def queue_prompt(prompt_workflow):
    p = {"prompt": prompt_workflow, "client_id": CLIENT_ID}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request(f"http://{COMFY_ADDR}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())

def upload_image_to_comfyui(file_bytes, filename):
    """直接发送字节流到 ComfyUI"""
    url = f"http://{COMFY_ADDR}/upload/image"
    files = {'image': (filename, file_bytes)}
    res = requests.post(url, files=files)
    return res.json()['name']

# ================= 前端页面路由 =================
@app.route('/')
def index():
    return '''
    <!DOCTYPE html>
    <html lang="zh-CN">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
        <title>Z-Image AI 创作中心</title>
        <style>
            body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; padding: 15px; margin: 0; display: flex; justify-content: center; }
            .container { background: white; padding: 25px; border-radius: 12px; box-shadow: 0 4px 15px rgba(0,0,0,0.05); width: 100%; max-width: 600px; box-sizing: border-box; }
            h2 { text-align: center; margin-top: 0; color: #333; }
            .tabs { display: flex; margin-bottom: 20px; border-bottom: 2px solid #eee; }
            .tab { flex: 1; text-align: center; padding: 10px; cursor: pointer; font-weight: bold; color: #888; transition: 0.3s; }
            .tab.active { color: #007bff; border-bottom: 2px solid #007bff; }
            .input-group { margin-bottom: 15px; text-align: left; }
            .input-group label { display: block; margin-bottom: 5px; font-size: 14px; color: #555; }
            input[type="text"], input[type="file"] { width: 100%; padding: 12px; font-size: 15px; border: 1px solid #ccc; border-radius: 8px; box-sizing: border-box; }
            button { width: 100%; padding: 14px; font-size: 16px; background-color: #007bff; color: white; border: none; border-radius: 8px; cursor: pointer; transition: 0.3s; font-weight: bold; }
            #progressContainer { display: none; margin-top: 15px; }
            .progress-bar-bg { width: 100%; background-color: #e9ecef; border-radius: 8px; overflow: hidden; height: 12px; }
            .progress-bar-fill { height: 100%; background-color: #28a745; width: 0%; transition: width 0.2s; }
            #status { margin: 15px 0 5px 0; font-weight: bold; color: #555; text-align: center; font-size: 14px;}
            #timeInfo { font-size: 12px; color: #999; text-align: center; margin-bottom: 10px; display: none; }
            #resultImg { width: 100%; border-radius: 8px; margin-top: 10px; display: none; box-shadow: 0 2px 8px rgba(0,0,0,0.2); }
        </style>
    </head>
    <body>
        <div class="container">
            <h2>✨ AI 创作中心</h2>
            <div class="tabs">
                <div class="tab active" onclick="switchMode('t2i')" id="tab-t2i">文生图</div>
                <div class="tab" onclick="switchMode('i2i')" id="tab-i2i">图改图</div>
            </div>
            <form id="genForm">
                <div class="input-group" id="file-upload-group" style="display: none;">
                    <label>上传参考图:</label>
                    <input type="file" id="uploadImage" accept="image/*">
                </div>
                <div class="input-group">
                    <label id="prompt-label">画面描述:</label>
                    <input type="text" id="prompt" value="a beautiful girl, cinematic lighting, 8k">
                </div>
                <button type="button" id="genBtn" onclick="generate()">立即生成</button>
            </form>
            <p id="status"></p>
            <div id="timeInfo"></div>
            <div id="progressContainer">
                <div class="progress-bar-bg"><div class="progress-bar-fill" id="progressFill"></div></div>
            </div>
            <img id="resultImg" src="">
        </div>

        <script>
            let currentMode = 't2i';
            function switchMode(mode) {
                currentMode = mode;
                document.getElementById('tab-t2i').className = mode === 't2i' ? 'tab active' : 'tab';
                document.getElementById('tab-i2i').className = mode === 'i2i' ? 'tab active' : 'tab';
                document.getElementById('file-upload-group').style.display = mode === 'i2i' ? 'block' : 'none';
                document.getElementById('prompt').value = mode === 'i2i' ? '猫换成狗' : 'a beautiful girl, cinematic lighting, 8k';
            }

            async function generate() {
                const btn = document.getElementById('genBtn');
                const status = document.getElementById('status');
                const timeInfo = document.getElementById('timeInfo');
                const img = document.getElementById('resultImg');
                const progressFill = document.getElementById('progressFill');
                const progressContainer = document.getElementById('progressContainer');
                
                btn.disabled = true;
                status.innerText = "🚀 正在连接 AI 服务器...";
                timeInfo.style.display = "none";
                img.style.display = "none";
                progressContainer.style.display = "block";
                progressFill.style.width = "0%";

                const formData = new FormData();
                formData.append('mode', currentMode);
                formData.append('prompt', document.getElementById('prompt').value);
                if(currentMode === 'i2i') formData.append('image', document.getElementById('uploadImage').files[0]);

                try {
                    const response = await fetch('/generate', { method: 'POST', body: formData });
                    const reader = response.body.getReader();
                    const decoder = new TextDecoder("utf-8");
                    let buffer = "";
                    let startTime = 0;

                    while (true) {
                        const { done, value } = await reader.read();
                        if (done) break;
                        buffer += decoder.decode(value, { stream: true });
                        let lines = buffer.split('\\n');
                        buffer = lines.pop(); 
                        
                        for (let line of lines) {
                            if (!line.trim()) continue;
                            const data = JSON.parse(line);
                            if (data.status === "progress") {
                                if (startTime === 0) startTime = Date.now();
                                let percent = Math.round((data.value / data.max) * 100);
                                progressFill.style.width = percent + "%";
                                status.innerText = `🔄 正在绘制: ${percent}%`;
                            } else if (data.status === "success") {
                                status.innerText = "🎉 生成完毕";
                                timeInfo.innerText = `模型计算耗时: ${data.elapsed_time}s`;
                                timeInfo.style.display = "block";
                                img.src = data.image_url + "?t=" + new Date().getTime(); 
                                img.style.display = "block";
                                progressContainer.style.display = "none";
                            }
                        }
                    }
                } catch (e) { status.innerText = "❌ 出错啦"; }
                finally { btn.disabled = false; }
            }
        </script>
    </body>
    </html>
# ====3个 '''

@app.route('/api/image/<filename>')
def get_image(filename):
    url = f"http://{COMFY_ADDR}/view?filename={urllib.parse.quote(filename)}"
    try:
        with urllib.request.urlopen(url) as response:
            return Response(response.read(), mimetype='image/png')
    except: return "Image Not Found", 404

# ================= 后端逻辑优化 =================
@app.route('/generate', methods=['POST'])
def generate():
    # 【修复核心】在主函数中预先提取所有 request 数据
    mode = request.form.get("mode", "t2i")
    prompt = request.form.get("prompt", "")
    image_bytes = None
    image_filename = None

    if mode == 'i2i' and 'image' in request.files:
        file = request.files['image']
        image_bytes = file.read() # 将文件读入内存
        image_filename = file.filename

    def stream_generation(m, p, img_b, img_n):
        with gpu_lock:
            try:
                start_clock = time.time()
                # --- 工作流配置 ---
                if m == 'i2i':
                    comfy_name = upload_image_to_comfyui(img_b, img_n)
                    with open("FireRed-Image-Edit.json", "r", encoding="utf-8") as f:
                        wf = json.load(f)
                    wf["41"]["inputs"]["image"] = comfy_name
                    wf["68"]["inputs"]["prompt"] = p
                    wf["65"]["inputs"]["seed"] = random.randint(10**10, 10**15)
                    node_id = "9"
                else:
                    with open("Z-Image-GGUF.json", "r", encoding="utf-8") as f:
                        wf = json.load(f)
                    wf["8"]["inputs"]["text"] = p
                    wf["10"]["inputs"]["seed"] = random.randint(10**10, 10**15)
                    node_id = "18"

                # --- 任务推送 ---
                ws = websocket.WebSocket()
                ws.connect(f"ws://{COMFY_ADDR}/ws?clientId={CLIENT_ID}")
                prompt_id = queue_prompt(wf)['prompt_id']

                while True:
                    out = ws.recv()
                    if isinstance(out, str):
                        msg = json.loads(out)
                        if msg['type'] == 'progress':
                            yield json.dumps({"status": "progress", "value": msg['data']['value'], "max": msg['data']['max']}) + "\n"
                        elif msg['type'] == 'executing' and msg['data']['node'] is None and msg['data']['prompt_id'] == prompt_id:
                            break
                ws.close()

                # --- 耗时计算与结果 ---
                elapsed = round(time.time() - start_clock, 2)
                with urllib.request.urlopen(f"http://{COMFY_ADDR}/history/{prompt_id}") as r:
                    hist = json.loads(r.read())
                
                fname = hist[prompt_id]['outputs'][node_id]['images'][0]['filename']
                yield json.dumps({"status": "success", "image_url": f"/api/image/{fname}", "elapsed_time": elapsed}) + "\n"
                
            except Exception as e:
                yield json.dumps({"status": "error", "message": str(e)}) + "\n"

    # 将提取出的数据传给生成器
    return Response(stream_generation(mode, prompt, image_bytes, image_filename), mimetype='application/x-ndjson')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
❤️ 转载文章请注明出处,谢谢!❤️