diff --git a/app/api/generate/[inferenceId]/route.ts b/app/api/generate/[inferenceId]/route.ts index cee04a5..a8c46e1 100644 --- a/app/api/generate/[inferenceId]/route.ts +++ b/app/api/generate/[inferenceId]/route.ts @@ -24,26 +24,41 @@ export const gridSizeToScenarioPixelMap = { } export const uploadImage = async (base64String: string) => { - const base64FileData = base64String.split("base64,")?.[1] + try { + const base64FileData = base64String.split("base64,")?.[1] + if (!base64FileData) { + throw new Error('Invalid base64 string format'); + } - const uuid = uuidv4() - const { data: upload, error } = await supabase.storage - .from("pixelated") - .upload(`${uuid}.png`, decode(base64FileData), { - contentType: "image/png", - cacheControl: "3600", - upsert: false, - }) + const uuid = uuidv4() + const { data: upload, error } = await supabase.storage + .from("pixelated") + .upload(`${uuid}.png`, decode(base64FileData), { + contentType: "image/png", + cacheControl: "3600", + upsert: false, + }) - const { data } = await supabase.storage - .from("pixelated") - .getPublicUrl(`${uuid}.png`) + console.log("upload", upload) - if (error) { - throw new Error(error.message) + if (error) { + console.error('Supabase upload error:', error); + throw new Error(error.message) + } + + const { data: urlData } = await supabase.storage + .from("pixelated") + .getPublicUrl(`${uuid}.png`) + + + console.log("urlData", urlData) + + // Return an object with publicUrl property to match expected structure + return { publicUrl: urlData.publicUrl } + } catch (error) { + console.error('Upload image error:', error); + throw error; } - - return data } type PixelateImageParams = { @@ -63,15 +78,15 @@ export const pixelateImageScenario = async ({ removeBackground = false, }: PixelateImageParams) => { const pixelateResponse = await fetch( - `https://api.cloud.scenario.com/v1/images/pixelate`, + `https://api.cloud.scenario.com/v1/generate/pixelate`, { - method: "PUT", + method: "POST", headers: { "Content-Type": "application/json", Authorization: `Basic ${scenarioAuthToken}`, }, body: JSON.stringify({ - assetId, + image: assetId, pixelGridSize: gridSizeToScenarioPixelMap[pixelGridSize], returnImage: true, removeNoise: true, @@ -84,6 +99,7 @@ export const pixelateImageScenario = async ({ } ) + if (!pixelateResponse.ok) { return pixelateImage({ remoteUrl, @@ -93,7 +109,17 @@ export const pixelateImageScenario = async ({ const pixelateData: ScenarioPixelateResponse = await pixelateResponse.json() - return pixelateData.image + + console.log("pixelate data response", pixelateData) + // @ts-ignore + return pixelateData.asset +} + +async function urlToBase64(url: string) { + console.log(url) + const response = await fetch(url); + const buffer = await response.arrayBuffer(); + return `data:image/png;base64,${Buffer.from(buffer).toString('base64')}`; } export async function GET( @@ -115,128 +141,165 @@ export async function GET( return new Response(null, { status: 403 }) } - // Track the status of our inference progress here - const inferenceProgress: ScenarioInferenceProgressResponse = - await fetch( - `https://api.cloud.scenario.com/v1/models/${modelId}/inferences/${params.inferenceId}`, - { - method: "GET", - headers: { - "Content-Type": "application/json", - Authorization: `Basic ${scenarioAuthToken}`, - }, - } - ).then((res) => res.json()) - - // If the inference was a success, decrement the user's credits. - if (inferenceProgress.inference.status === "succeeded") { - const generation = await db.generation.findUniqueOrThrow({ - where: { - uniqueGeneration: { - inferenceId: params.inferenceId, - modelId: modelId, - }, + // Use the new jobs endpoint instead of inference progress + const jobProgress = await fetch( + `https://api.cloud.scenario.com/v1/jobs/${params.inferenceId}`, + { + method: "GET", + headers: { + "Content-Type": "application/json", + Authorization: `Basic ${scenarioAuthToken}`, }, - include: { - outputImages: true, - }, - }) - - // If the generation is already complete, return the output images - if (generation.status === "COMPLETE") { - let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse = - { - ...inferenceProgress, - outputImages: generation.outputImages, - } - return new Response( - JSON.stringify(copiedInferenceProgressWithImagesPixelated), - { status: 200 } - ) } + ).then((res) => res.json()) - const pixelatedImagesScenario = await Promise.all( - inferenceProgress.inference.images.map((image) => { - if (generation.pixelSize === 32) { - return pixelateImage({ - remoteUrl: image.url, - pixelSize: generation.pixelSize, - }) - } - return pixelateImageScenario({ - remoteUrl: image.url, - assetId: image.id, - pixelGridSize: generation.pixelSize, - colorPaletteEnabled: generation.colorPaletteEnabled, - colors: generation.colors as number[][], - }) - }) - ) + console.log(jobProgress) - const pixelatedImages = await Promise.all( - pixelatedImagesScenario.map((image) => { - return uploadImage(image) - }) - ) - - const imagesWithPixelated = inferenceProgress.inference.images.map( - (image, index) => { - return { - ...image, - pixelated: pixelatedImages[index].publicUrl, - } - } - ) - - await db.$transaction([ - db.user.update({ + // If the job was successful, process the images and update credits + if (jobProgress.job.status === "success") { + console.log("Job success") + // Wrap this in try-catch since findUniqueOrThrow can fail + try { + const generation = await db.generation.findUniqueOrThrow({ where: { - id: session.user.id, - }, - data: { - credits: { - decrement: generation.numSamples / 4, + uniqueGeneration: { + inferenceId: jobProgress.job.metadata.inferenceId, + modelId: modelId, }, }, - }), - db.generation.update({ - where: { - id: generation.id, + include: { + outputImages: true, }, - data: { - status: "COMPLETE", - outputImages: { - createMany: { - data: imagesWithPixelated.map((image) => { - return { - scenarioImageId: image.id, - image: image.url, - seed: image.seed, - pixelatedImage: image.pixelated, - } - }), + }) + + console.log("generateion jere", generation) + + // If already processed, return existing output images + if (generation.status === "COMPLETE") { + return new Response( + JSON.stringify({ + ...jobProgress, + outputImages: generation.outputImages, + }), + { status: 200 } + ) + } + + // Process images using assetIds from job metadata + const assetIds = jobProgress.job.metadata.assetIds + const pixelatedImagesScenario = await Promise.all( + assetIds.map((assetId) => { + if (generation.pixelSize === 32) { + return pixelateImage({ + remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`, + pixelSize: generation.pixelSize, + }) + } + return pixelateImageScenario({ + remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`, + assetId, + pixelGridSize: generation.pixelSize, + colorPaletteEnabled: generation.colorPaletteEnabled, + colors: generation.colors as number[][], + }) + }) + ) + + const pixelatedImages = await Promise.all( + pixelatedImagesScenario.map(async (imageResponse) => { + try { + // If it's a base64 string (from pixelateImage function) + if (typeof imageResponse === 'string') { + return uploadImage(imageResponse); + } + // If it's a response from scenario API + if (imageResponse?.url) { + const base64Data = await urlToBase64(imageResponse.url); + console.log("base64Data", base64Data) + return uploadImage(base64Data); + } + console.error('Invalid imageResponse:', imageResponse); + throw new Error('Invalid image response format'); + } catch (error) { + console.error('Error processing image:', error); + throw error; + } + }) + ) + + console.log("pixelateImages", pixelatedImages) + console.log("jobProgress.job.metadata.assetIds", jobProgress.job.metadata.assetIds) + + const imagesWithPixelated = jobProgress.job.metadata.assetIds.map( + (assetId, index) => { + return { + scenarioImageId: assetId, + seed: "seed", + image: pixelatedImages[index].publicUrl, + pixelated: pixelatedImages[index].publicUrl, + } + } + ) + + console.log("imagesWithPixelated", imagesWithPixelated) + console.log("generation", generation) + + await db.$transaction([ + db.user.update({ + where: { + id: session.user.id, + }, + data: { + credits: { + decrement: generation.numSamples / 4, }, }, + }), + db.generation.update({ + where: { + id: generation.id, + }, + data: { + status: "COMPLETE", + outputImages: { + createMany: { + data: imagesWithPixelated.map((image) => { + return { + scenarioImageId: image.scenarioImageId, + image: image.url, + seed: image.seed, + pixelatedImage: image.pixelated, + } + }), + }, + }, + }, + }), + ]) + + const outputImages = await db.outputImage.findMany({ + where: { + generationId: generation.id, }, - }), - ]) + }) - const outputImages = await db.outputImage.findMany({ - where: { - generationId: generation.id, - }, - }) + console.log("outputImages", outputImages) - let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse = - { - ...inferenceProgress, - outputImages, + return new Response( + JSON.stringify({ + ...jobProgress, + outputImages, + }), + { status: 200 } + ) + } catch (error) { + // If generation record doesn't exist, return just the job progress + if (error.code === 'P2025') { + return new Response(JSON.stringify(jobProgress), { status: 200 }) } - return new Response( - JSON.stringify(copiedInferenceProgressWithImagesPixelated), - { status: 200 } - ) - } else if (inferenceProgress.inference.status === "failed") { + throw error // Re-throw other errors + } + } else if (jobProgress.job.status === "failed") { const generation = await db.generation.findUniqueOrThrow({ where: { uniqueGeneration: { @@ -256,7 +319,7 @@ export async function GET( }) } - return new Response(JSON.stringify(inferenceProgress), { status: 200 }) + return new Response(JSON.stringify(jobProgress), { status: 200 }) } catch (error) { console.log("Error", error) if (error instanceof z.ZodError) { diff --git a/app/api/generate/route.ts b/app/api/generate/route.ts index bdba8c1..18fb389 100644 --- a/app/api/generate/route.ts +++ b/app/api/generate/route.ts @@ -74,38 +74,36 @@ export async function POST(req: Request) { ) } - const generation: ScenarioInferenceResponse = await fetch( - `https://api.cloud.scenario.com/v1/models/${parameters.modelId}/inferences`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Basic ${scenarioAuthToken}`, - }, - body: JSON.stringify({ - parameters: { - enableSafetyCheck: false, - type: parameters?.referenceImage - ? "img2img" - : "txt2img", - prompt: `${parameters.prompt} ${ - supplementalPromptMap[parameters.modelId] - }`, - negativePrompt: "trading cards, cards", - numInferenceSteps: 30, - guidance: parameters.guidance, - width: 512, - height: 512, - numSamples: parameters.numImages, - image: parameters?.referenceImage ?? undefined, - modality: modalityMap[parameters.modelId] ?? undefined, - strength: parameters?.referenceImage - ? (100 - parameters?.influence) / 100 - : undefined, - }, + const endpoint = parameters?.referenceImage + ? "https://api.cloud.scenario.com/v1/generate/img2img" + : "https://api.cloud.scenario.com/v1/generate/txt2img" + + const generation: ScenarioInferenceResponse = await fetch(endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Basic ${scenarioAuthToken}`, + }, + body: JSON.stringify({ + modelId: parameters.modelId, + prompt: `${parameters.prompt} ${supplementalPromptMap[parameters.modelId] + }`, + negativePrompt: "trading cards, cards", + numInferenceSteps: 30, + guidance: parameters.guidance, + width: 512, + height: 512, + numSamples: parameters.numImages, + ...(parameters?.referenceImage && { + image: parameters.referenceImage, + strength: (100 - parameters?.influence) / 100, }), - } - ).then((res) => res.json()) + ...(modalityMap[parameters.modelId] && { + modality: modalityMap[parameters.modelId], + }), + }), + }).then((res) => res.json()) + await db.generation.create({ data: { diff --git a/components/create/generation-form.tsx b/components/create/generation-form.tsx index a1ca650..a53b04c 100644 --- a/components/create/generation-form.tsx +++ b/components/create/generation-form.tsx @@ -135,9 +135,8 @@ export function GenerationForm({ Base the entire prompt on this context: ${getValues( "prompt" - )} making sure to keep the style in mind which is: ${ - supplementalPromptMap[modelId] - }` + )} making sure to keep the style in mind which is: ${supplementalPromptMap[modelId] + }` const response = await fetch("/api/generate/prompt-generate", { method: "POST", @@ -246,7 +245,7 @@ export function GenerationForm({ ...prev, { guidance: guidance[0], - inferenceId: responseData.inference.id, + inferenceId: responseData.job.jobId, modelId, prompt: data.prompt, numImages, @@ -393,13 +392,13 @@ export function GenerationForm({ key as keyof typeof scenarioGenerators ] .featuredArtist && ( -
- - Featured - artist - -
- )} +
+ + Featured + artist + +
+ )} ) } diff --git a/components/create/generation-set.tsx b/components/create/generation-set.tsx index f88f2bf..86dcdac 100644 --- a/components/create/generation-set.tsx +++ b/components/create/generation-set.tsx @@ -64,9 +64,9 @@ export const GenerationSet = ({ let secondCount = 0 let showedPatienceModal = false setIsSaving(true) + while (!generatedImages) { - // Loop in 1s intervals until the alt text is ready - let finalResponse = await fetch( + const response = await fetch( `/api/generate/${inferenceId}?modelId=${modelId}`, { method: "GET", @@ -76,19 +76,23 @@ export const GenerationSet = ({ signal: controller.signal, } ) - let jsonFinalResponse: ScenarioInferenceProgressResponse = - await finalResponse.json() - setProgress(jsonFinalResponse.inference.progress) + const jsonResponse = await response.json() + + setProgress(jsonResponse.job?.progress || 0) if ( - jsonFinalResponse.inference.status === "succeeded" && - jsonFinalResponse?.outputImages + jsonResponse.job?.status === "success" && + jsonResponse?.outputImages ) { - generatedImages = jsonFinalResponse.outputImages + generatedImages = jsonResponse.outputImages + setImages(generatedImages) - } else if ( - jsonFinalResponse.inference.status === "failed" - ) { + } else if (jsonResponse.job?.status === "failed") { + toast({ + title: "Generation failed", + description: "Please try again", + variant: "destructive", + }) break } else { if (secondCount >= 60 && !showedPatienceModal) { @@ -101,15 +105,18 @@ export const GenerationSet = ({ showedPatienceModal = true } secondCount++ - await new Promise((resolve) => - setTimeout(resolve, 1000) - ) + await new Promise((resolve) => setTimeout(resolve, 1000)) } } setIsSaving(false) router.refresh() } catch (e) { console.log(e) + toast({ + title: "Something went wrong", + description: "Please try again", + variant: "destructive", + }) } } diff --git a/types/scenario.ts b/types/scenario.ts index 173a820..8c1e359 100644 --- a/types/scenario.ts +++ b/types/scenario.ts @@ -1,7 +1,36 @@ import { Generation, OutputImage } from "@prisma/client" export interface ScenarioInferenceResponse { - inference: Inference + job: Job +} + +export interface Job { + jobId: string + jobType: string + metadata: JobMetadata + ownerId: string + authorId: string + createdAt: string + updatedAt: string + status: string + statusHistory: StatusHistoryItem[] + progress: number +} + +export interface JobMetadata { + baseModelId: string + inferenceId: string + input: any // You might want to type this more specifically + modelId: string + modelType: string + priority: number + assetIds: string[] +} + +export interface StatusHistoryItem { + // Add specific fields based on your needs + status?: string + timestamp?: string } export interface Inference {