|
@@ -7,26 +7,26 @@ use log::{info, error, debug};
|
|
|
pub enum ImageGenerationError {
|
|
|
#[error("Network error: {0}")]
|
|
|
NetworkError(#[from] reqwest::Error),
|
|
|
- #[error("Invalid API response")]
|
|
|
- InvalidResponse,
|
|
|
+ #[error("Invalid API response: {0}")]
|
|
|
+ InvalidResponse(String),
|
|
|
+ #[error("Missing data in response")]
|
|
|
+ MissingData,
|
|
|
+ #[error("API execution failed: {0}")]
|
|
|
+ ExecutionFailed(String),
|
|
|
}
|
|
|
|
|
|
pub struct ExternalImageGenerator {
|
|
|
api_key: String,
|
|
|
api_url: String,
|
|
|
+ #[allow(dead_code)] // Ajout de cette annotation pour supprimer le warning
|
|
|
model: String,
|
|
|
}
|
|
|
|
|
|
-const SIZE: &str = "1024x1024";
|
|
|
-const NUM_IMAGES: i32 = 1;
|
|
|
-// response_format: "b64_json" or "url"
|
|
|
-const RESPONSE_FORMAT: &str = "url";
|
|
|
-
|
|
|
impl ExternalImageGenerator {
|
|
|
pub fn new(api_key: String, api_url: String, model: String) -> Self {
|
|
|
Self { api_key, api_url, model }
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
pub fn generate_image(&self, style: &str, prompt: &str) -> Result<String, ImageGenerationError> {
|
|
|
debug!("Creating HTTP client for image generation");
|
|
|
let client = Client::new();
|
|
@@ -35,43 +35,202 @@ impl ExternalImageGenerator {
|
|
|
debug!("Formatted prompt: {}", formatted_prompt);
|
|
|
|
|
|
let payload = json!({
|
|
|
- "model": self.model,
|
|
|
- "prompt": formatted_prompt,
|
|
|
- "n": NUM_IMAGES,
|
|
|
- "size": SIZE,
|
|
|
- "response_format": RESPONSE_FORMAT,
|
|
|
+ "input": {
|
|
|
+ "workflow": {
|
|
|
+ "6": {
|
|
|
+ "inputs": {
|
|
|
+ "text": formatted_prompt,
|
|
|
+ "clip": ["11", 0]
|
|
|
+ },
|
|
|
+ "class_type": "CLIPTextEncode",
|
|
|
+ "_meta": {
|
|
|
+ "title": "CLIP Text Encode (Positive Prompt)"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "8": {
|
|
|
+ "inputs": {
|
|
|
+ "samples": ["13", 0],
|
|
|
+ "vae": ["10", 0]
|
|
|
+ },
|
|
|
+ "class_type": "VAEDecode",
|
|
|
+ "_meta": {
|
|
|
+ "title": "VAE Decode"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "9": {
|
|
|
+ "inputs": {
|
|
|
+ "filename_prefix": "ComfyUI",
|
|
|
+ "images": ["8", 0]
|
|
|
+ },
|
|
|
+ "class_type": "SaveImage",
|
|
|
+ "_meta": {
|
|
|
+ "title": "Save Image"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "10": {
|
|
|
+ "inputs": {
|
|
|
+ "vae_name": "ae.safetensors"
|
|
|
+ },
|
|
|
+ "class_type": "VAELoader",
|
|
|
+ "_meta": {
|
|
|
+ "title": "Load VAE"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "11": {
|
|
|
+ "inputs": {
|
|
|
+ "clip_name1": "t5xxl_fp8_e4m3fn.safetensors",
|
|
|
+ "clip_name2": "clip_l.safetensors",
|
|
|
+ "type": "flux"
|
|
|
+ },
|
|
|
+ "class_type": "DualCLIPLoader",
|
|
|
+ "_meta": {
|
|
|
+ "title": "DualCLIPLoader"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "12": {
|
|
|
+ "inputs": {
|
|
|
+ "unet_name": "flux1-dev.safetensors",
|
|
|
+ "weight_dtype": "default"
|
|
|
+ },
|
|
|
+ "class_type": "UNETLoader",
|
|
|
+ "_meta": {
|
|
|
+ "title": "Load Diffusion Model"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "13": {
|
|
|
+ "inputs": {
|
|
|
+ "noise": ["25", 0],
|
|
|
+ "guider": ["22", 0],
|
|
|
+ "sampler": ["16", 0],
|
|
|
+ "sigmas": ["17", 0],
|
|
|
+ "latent_image": ["27", 0]
|
|
|
+ },
|
|
|
+ "class_type": "SamplerCustomAdvanced",
|
|
|
+ "_meta": {
|
|
|
+ "title": "SamplerCustomAdvanced"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "16": {
|
|
|
+ "inputs": {
|
|
|
+ "sampler_name": "euler"
|
|
|
+ },
|
|
|
+ "class_type": "KSamplerSelect",
|
|
|
+ "_meta": {
|
|
|
+ "title": "KSamplerSelect"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "17": {
|
|
|
+ "inputs": {
|
|
|
+ "scheduler": "simple",
|
|
|
+ "steps": 20,
|
|
|
+ "denoise": 1,
|
|
|
+ "model": ["30", 0]
|
|
|
+ },
|
|
|
+ "class_type": "BasicScheduler",
|
|
|
+ "_meta": {
|
|
|
+ "title": "BasicScheduler"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "22": {
|
|
|
+ "inputs": {
|
|
|
+ "model": ["30", 0],
|
|
|
+ "conditioning": ["26", 0]
|
|
|
+ },
|
|
|
+ "class_type": "BasicGuider",
|
|
|
+ "_meta": {
|
|
|
+ "title": "BasicGuider"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "25": {
|
|
|
+ "inputs": {
|
|
|
+ "noise_seed": 219670278747233i64 // Using i64 instead of i32
|
|
|
+ },
|
|
|
+ "class_type": "RandomNoise",
|
|
|
+ "_meta": {
|
|
|
+ "title": "RandomNoise"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "26": {
|
|
|
+ "inputs": {
|
|
|
+ "guidance": 3.5,
|
|
|
+ "conditioning": ["6", 0]
|
|
|
+ },
|
|
|
+ "class_type": "FluxGuidance",
|
|
|
+ "_meta": {
|
|
|
+ "title": "FluxGuidance"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "27": {
|
|
|
+ "inputs": {
|
|
|
+ "width": 1024,
|
|
|
+ "height": 1024,
|
|
|
+ "batch_size": 1
|
|
|
+ },
|
|
|
+ "class_type": "EmptySD3LatentImage",
|
|
|
+ "_meta": {
|
|
|
+ "title": "EmptySD3LatentImage"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "30": {
|
|
|
+ "inputs": {
|
|
|
+ "max_shift": 1.15,
|
|
|
+ "base_shift": 0.5,
|
|
|
+ "width": 1024,
|
|
|
+ "height": 1024,
|
|
|
+ "model": ["12", 0]
|
|
|
+ },
|
|
|
+ "class_type": "ModelSamplingFlux",
|
|
|
+ "_meta": {
|
|
|
+ "title": "ModelSamplingFlux"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
});
|
|
|
|
|
|
info!("Sending image generation request to external API");
|
|
|
- let response = match client.post(&self.api_url)
|
|
|
+ let response = client.post(&self.api_url)
|
|
|
.header("Content-Type", "application/json")
|
|
|
- .bearer_auth(self.api_key.clone())
|
|
|
+ .bearer_auth(&self.api_key)
|
|
|
.json(&payload)
|
|
|
.timeout(std::time::Duration::from_secs(60))
|
|
|
- .send() {
|
|
|
- Ok(resp) => resp,
|
|
|
- Err(e) => {
|
|
|
- error!("Network error during image generation: {}", e);
|
|
|
- return Err(ImageGenerationError::NetworkError(e));
|
|
|
- }
|
|
|
- };
|
|
|
+ .send()
|
|
|
+ .map_err(ImageGenerationError::NetworkError)?;
|
|
|
|
|
|
- if response.status().is_success() {
|
|
|
- debug!("Successfully received response from image generation API");
|
|
|
- match response.json::<serde_json::Value>() {
|
|
|
- Ok(json) => {
|
|
|
- let image_url = json["data"][0][RESPONSE_FORMAT].as_str().unwrap();
|
|
|
- info!("Image generated successfully: {}", image_url);
|
|
|
- Ok(image_url.to_string())
|
|
|
- },
|
|
|
- Err(e) => {
|
|
|
- error!("Failed to parse API response: {}", e);
|
|
|
- Err(ImageGenerationError::NetworkError(e))
|
|
|
+ let status = response.status();
|
|
|
+ if !status.is_success() {
|
|
|
+ let error_text = response.text()
|
|
|
+ .unwrap_or_else(|_| "Failed to get error message".to_string());
|
|
|
+ error!("API returned error status: {}, error: {}", status, error_text);
|
|
|
+ return Err(ImageGenerationError::InvalidResponse(error_text));
|
|
|
+ }
|
|
|
+
|
|
|
+ debug!("Successfully received response from image generation API");
|
|
|
+ let json_response = response.json::<serde_json::Value>()
|
|
|
+ .map_err(|e| ImageGenerationError::NetworkError(e))?;
|
|
|
+
|
|
|
+ // Check if the status is COMPLETED
|
|
|
+ match json_response["status"].as_str() {
|
|
|
+ Some("COMPLETED") => {
|
|
|
+ // Check if the output status is success
|
|
|
+ match json_response["output"]["status"].as_str() {
|
|
|
+ Some("success") => {
|
|
|
+ // Get the base64 message
|
|
|
+ json_response["output"]["message"]
|
|
|
+ .as_str()
|
|
|
+ .map(String::from)
|
|
|
+ .ok_or(ImageGenerationError::MissingData)
|
|
|
+ },
|
|
|
+ Some(status) => Err(ImageGenerationError::ExecutionFailed(
|
|
|
+ format!("Output status was not successful: {}", status)
|
|
|
+ )),
|
|
|
+ None => Err(ImageGenerationError::MissingData)
|
|
|
}
|
|
|
- }
|
|
|
- } else {
|
|
|
- error!("API returned error status: {}, error: {:?}", response.status(), response.text());
|
|
|
- Err(ImageGenerationError::InvalidResponse)
|
|
|
+ },
|
|
|
+ Some(status) => Err(ImageGenerationError::ExecutionFailed(
|
|
|
+ format!("Job status was not COMPLETED: {}", status)
|
|
|
+ )),
|
|
|
+ None => Err(ImageGenerationError::MissingData)
|
|
|
}
|
|
|
}
|
|
|
}
|