loic boulet 7 miesięcy temu
rodzic
commit
6319663248
1 zmienionych plików z 196 dodań i 37 usunięć
  1. 196 37
      src/external_image_generator.rs

+ 196 - 37
src/external_image_generator.rs

@@ -7,26 +7,26 @@ use log::{info, error, debug};
 pub enum ImageGenerationError {
     #[error("Network error: {0}")]
     NetworkError(#[from] reqwest::Error),
-    #[error("Invalid API response")]
-    InvalidResponse,
+    #[error("Invalid API response: {0}")]
+    InvalidResponse(String),
+    #[error("Missing data in response")]
+    MissingData,
+    #[error("API execution failed: {0}")]
+    ExecutionFailed(String),
 }
 
 pub struct ExternalImageGenerator {
     api_key: String,
     api_url: String,
+    #[allow(dead_code)]  // Ajout de cette annotation pour supprimer le warning
     model: String,
 }
 
-const SIZE: &str = "1024x1024";
-const NUM_IMAGES: i32 = 1;
-// response_format: "b64_json" or "url"
-const RESPONSE_FORMAT: &str = "url";
-
 impl ExternalImageGenerator {
     pub fn new(api_key: String, api_url: String, model: String) -> Self {
         Self { api_key, api_url, model }
     }
-
+    
     pub fn generate_image(&self, style: &str, prompt: &str) -> Result<String, ImageGenerationError> {
         debug!("Creating HTTP client for image generation");
         let client = Client::new();
@@ -35,43 +35,202 @@ impl ExternalImageGenerator {
         debug!("Formatted prompt: {}", formatted_prompt);
 
         let payload = json!({
-            "model": self.model,
-            "prompt": formatted_prompt,
-            "n": NUM_IMAGES,
-            "size": SIZE,
-            "response_format": RESPONSE_FORMAT,
+            "input": {
+                "workflow": {
+                    "6": {
+                        "inputs": {
+                            "text": formatted_prompt,
+                            "clip": ["11", 0]
+                        },
+                        "class_type": "CLIPTextEncode",
+                        "_meta": {
+                            "title": "CLIP Text Encode (Positive Prompt)"
+                        }
+                    },
+                    "8": {
+                        "inputs": {
+                            "samples": ["13", 0],
+                            "vae": ["10", 0]
+                        },
+                        "class_type": "VAEDecode",
+                        "_meta": {
+                            "title": "VAE Decode"
+                        }
+                    },
+                    "9": {
+                        "inputs": {
+                            "filename_prefix": "ComfyUI",
+                            "images": ["8", 0]
+                        },
+                        "class_type": "SaveImage",
+                        "_meta": {
+                            "title": "Save Image"
+                        }
+                    },
+                    "10": {
+                        "inputs": {
+                            "vae_name": "ae.safetensors"
+                        },
+                        "class_type": "VAELoader",
+                        "_meta": {
+                            "title": "Load VAE"
+                        }
+                    },
+                    "11": {
+                        "inputs": {
+                            "clip_name1": "t5xxl_fp8_e4m3fn.safetensors",
+                            "clip_name2": "clip_l.safetensors",
+                            "type": "flux"
+                        },
+                        "class_type": "DualCLIPLoader",
+                        "_meta": {
+                            "title": "DualCLIPLoader"
+                        }
+                    },
+                    "12": {
+                        "inputs": {
+                            "unet_name": "flux1-dev.safetensors",
+                            "weight_dtype": "default"
+                        },
+                        "class_type": "UNETLoader",
+                        "_meta": {
+                            "title": "Load Diffusion Model"
+                        }
+                    },
+                    "13": {
+                        "inputs": {
+                            "noise": ["25", 0],
+                            "guider": ["22", 0],
+                            "sampler": ["16", 0],
+                            "sigmas": ["17", 0],
+                            "latent_image": ["27", 0]
+                        },
+                        "class_type": "SamplerCustomAdvanced",
+                        "_meta": {
+                            "title": "SamplerCustomAdvanced"
+                        }
+                    },
+                    "16": {
+                        "inputs": {
+                            "sampler_name": "euler"
+                        },
+                        "class_type": "KSamplerSelect",
+                        "_meta": {
+                            "title": "KSamplerSelect"
+                        }
+                    },
+                    "17": {
+                        "inputs": {
+                            "scheduler": "simple",
+                            "steps": 20,
+                            "denoise": 1,
+                            "model": ["30", 0]
+                        },
+                        "class_type": "BasicScheduler",
+                        "_meta": {
+                            "title": "BasicScheduler"
+                        }
+                    },
+                    "22": {
+                        "inputs": {
+                            "model": ["30", 0],
+                            "conditioning": ["26", 0]
+                        },
+                        "class_type": "BasicGuider",
+                        "_meta": {
+                            "title": "BasicGuider"
+                        }
+                    },
+                    "25": {
+                        "inputs": {
+                            "noise_seed": 219670278747233i64  // Using i64 instead of i32
+                        },
+                        "class_type": "RandomNoise",
+                        "_meta": {
+                            "title": "RandomNoise"
+                        }
+                    },
+                    "26": {
+                        "inputs": {
+                            "guidance": 3.5,
+                            "conditioning": ["6", 0]
+                        },
+                        "class_type": "FluxGuidance",
+                        "_meta": {
+                            "title": "FluxGuidance"
+                        }
+                    },
+                    "27": {
+                        "inputs": {
+                            "width": 1024,
+                            "height": 1024,
+                            "batch_size": 1
+                        },
+                        "class_type": "EmptySD3LatentImage",
+                        "_meta": {
+                            "title": "EmptySD3LatentImage"
+                        }
+                    },
+                    "30": {
+                        "inputs": {
+                            "max_shift": 1.15,
+                            "base_shift": 0.5,
+                            "width": 1024,
+                            "height": 1024,
+                            "model": ["12", 0]
+                        },
+                        "class_type": "ModelSamplingFlux",
+                        "_meta": {
+                            "title": "ModelSamplingFlux"
+                        }
+                    }
+                }
+            }
         });
 
         info!("Sending image generation request to external API");
-        let response = match client.post(&self.api_url)
+        let response = client.post(&self.api_url)
             .header("Content-Type", "application/json")
-            .bearer_auth(self.api_key.clone())
+            .bearer_auth(&self.api_key)
             .json(&payload)
             .timeout(std::time::Duration::from_secs(60))
-            .send() {
-                Ok(resp) => resp,
-                Err(e) => {
-                    error!("Network error during image generation: {}", e);
-                    return Err(ImageGenerationError::NetworkError(e));
-                }
-            };
+            .send()
+            .map_err(ImageGenerationError::NetworkError)?;
 
-        if response.status().is_success() {
-            debug!("Successfully received response from image generation API");
-            match response.json::<serde_json::Value>() {
-                Ok(json) => {
-                    let image_url = json["data"][0][RESPONSE_FORMAT].as_str().unwrap();
-                    info!("Image generated successfully: {}", image_url);
-                    Ok(image_url.to_string())
-                },
-                Err(e) => {
-                    error!("Failed to parse API response: {}", e);
-                    Err(ImageGenerationError::NetworkError(e))
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response.text()
+                .unwrap_or_else(|_| "Failed to get error message".to_string());
+            error!("API returned error status: {}, error: {}", status, error_text);
+            return Err(ImageGenerationError::InvalidResponse(error_text));
+        }
+
+        debug!("Successfully received response from image generation API");
+        let json_response = response.json::<serde_json::Value>()
+            .map_err(|e| ImageGenerationError::NetworkError(e))?;
+        
+        // Check if the status is COMPLETED
+        match json_response["status"].as_str() {
+            Some("COMPLETED") => {
+                // Check if the output status is success
+                match json_response["output"]["status"].as_str() {
+                    Some("success") => {
+                        // Get the base64 message
+                        json_response["output"]["message"]
+                            .as_str()
+                            .map(String::from)
+                            .ok_or(ImageGenerationError::MissingData)
+                    },
+                    Some(status) => Err(ImageGenerationError::ExecutionFailed(
+                        format!("Output status was not successful: {}", status)
+                    )),
+                    None => Err(ImageGenerationError::MissingData)
                 }
-            }
-        } else {
-            error!("API returned error status: {}, error: {:?}", response.status(), response.text());
-            Err(ImageGenerationError::InvalidResponse)
+            },
+            Some(status) => Err(ImageGenerationError::ExecutionFailed(
+                format!("Job status was not COMPLETED: {}", status)
+            )),
+            None => Err(ImageGenerationError::MissingData)
         }
     }
 }