Init upgrade scenario endpoint

This commit is contained in:
davidtparks 2025-01-11 15:33:58 -06:00
parent b58aa7f5f5
commit 8f7523e65e
5 changed files with 282 additions and 186 deletions

View file

@ -24,26 +24,41 @@ export const gridSizeToScenarioPixelMap = {
}
export const uploadImage = async (base64String: string) => {
const base64FileData = base64String.split("base64,")?.[1]
try {
const base64FileData = base64String.split("base64,")?.[1]
if (!base64FileData) {
throw new Error('Invalid base64 string format');
}
const uuid = uuidv4()
const { data: upload, error } = await supabase.storage
.from("pixelated")
.upload(`${uuid}.png`, decode(base64FileData), {
contentType: "image/png",
cacheControl: "3600",
upsert: false,
})
const uuid = uuidv4()
const { data: upload, error } = await supabase.storage
.from("pixelated")
.upload(`${uuid}.png`, decode(base64FileData), {
contentType: "image/png",
cacheControl: "3600",
upsert: false,
})
const { data } = await supabase.storage
.from("pixelated")
.getPublicUrl(`${uuid}.png`)
console.log("upload", upload)
if (error) {
throw new Error(error.message)
if (error) {
console.error('Supabase upload error:', error);
throw new Error(error.message)
}
const { data: urlData } = await supabase.storage
.from("pixelated")
.getPublicUrl(`${uuid}.png`)
console.log("urlData", urlData)
// Return an object with publicUrl property to match expected structure
return { publicUrl: urlData.publicUrl }
} catch (error) {
console.error('Upload image error:', error);
throw error;
}
return data
}
type PixelateImageParams = {
@ -63,15 +78,15 @@ export const pixelateImageScenario = async ({
removeBackground = false,
}: PixelateImageParams) => {
const pixelateResponse = await fetch(
`https://api.cloud.scenario.com/v1/images/pixelate`,
`https://api.cloud.scenario.com/v1/generate/pixelate`,
{
method: "PUT",
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Basic ${scenarioAuthToken}`,
},
body: JSON.stringify({
assetId,
image: assetId,
pixelGridSize: gridSizeToScenarioPixelMap[pixelGridSize],
returnImage: true,
removeNoise: true,
@ -84,6 +99,7 @@ export const pixelateImageScenario = async ({
}
)
if (!pixelateResponse.ok) {
return pixelateImage({
remoteUrl,
@ -93,7 +109,17 @@ export const pixelateImageScenario = async ({
const pixelateData: ScenarioPixelateResponse = await pixelateResponse.json()
return pixelateData.image
console.log("pixelate data response", pixelateData)
// @ts-ignore
return pixelateData.asset
}
async function urlToBase64(url: string) {
console.log(url)
const response = await fetch(url);
const buffer = await response.arrayBuffer();
return `data:image/png;base64,${Buffer.from(buffer).toString('base64')}`;
}
export async function GET(
@ -115,128 +141,165 @@ export async function GET(
return new Response(null, { status: 403 })
}
// Track the status of our inference progress here
const inferenceProgress: ScenarioInferenceProgressResponse =
await fetch(
`https://api.cloud.scenario.com/v1/models/${modelId}/inferences/${params.inferenceId}`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
Authorization: `Basic ${scenarioAuthToken}`,
},
}
).then((res) => res.json())
// If the inference was a success, decrement the user's credits.
if (inferenceProgress.inference.status === "succeeded") {
const generation = await db.generation.findUniqueOrThrow({
where: {
uniqueGeneration: {
inferenceId: params.inferenceId,
modelId: modelId,
},
// Use the new jobs endpoint instead of inference progress
const jobProgress = await fetch(
`https://api.cloud.scenario.com/v1/jobs/${params.inferenceId}`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
Authorization: `Basic ${scenarioAuthToken}`,
},
include: {
outputImages: true,
},
})
// If the generation is already complete, return the output images
if (generation.status === "COMPLETE") {
let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse =
{
...inferenceProgress,
outputImages: generation.outputImages,
}
return new Response(
JSON.stringify(copiedInferenceProgressWithImagesPixelated),
{ status: 200 }
)
}
).then((res) => res.json())
const pixelatedImagesScenario = await Promise.all(
inferenceProgress.inference.images.map((image) => {
if (generation.pixelSize === 32) {
return pixelateImage({
remoteUrl: image.url,
pixelSize: generation.pixelSize,
})
}
return pixelateImageScenario({
remoteUrl: image.url,
assetId: image.id,
pixelGridSize: generation.pixelSize,
colorPaletteEnabled: generation.colorPaletteEnabled,
colors: generation.colors as number[][],
})
})
)
console.log(jobProgress)
const pixelatedImages = await Promise.all(
pixelatedImagesScenario.map((image) => {
return uploadImage(image)
})
)
const imagesWithPixelated = inferenceProgress.inference.images.map(
(image, index) => {
return {
...image,
pixelated: pixelatedImages[index].publicUrl,
}
}
)
await db.$transaction([
db.user.update({
// If the job was successful, process the images and update credits
if (jobProgress.job.status === "success") {
console.log("Job success")
// Wrap this in try-catch since findUniqueOrThrow can fail
try {
const generation = await db.generation.findUniqueOrThrow({
where: {
id: session.user.id,
},
data: {
credits: {
decrement: generation.numSamples / 4,
uniqueGeneration: {
inferenceId: jobProgress.job.metadata.inferenceId,
modelId: modelId,
},
},
}),
db.generation.update({
where: {
id: generation.id,
include: {
outputImages: true,
},
data: {
status: "COMPLETE",
outputImages: {
createMany: {
data: imagesWithPixelated.map((image) => {
return {
scenarioImageId: image.id,
image: image.url,
seed: image.seed,
pixelatedImage: image.pixelated,
}
}),
})
console.log("generateion jere", generation)
// If already processed, return existing output images
if (generation.status === "COMPLETE") {
return new Response(
JSON.stringify({
...jobProgress,
outputImages: generation.outputImages,
}),
{ status: 200 }
)
}
// Process images using assetIds from job metadata
const assetIds = jobProgress.job.metadata.assetIds
const pixelatedImagesScenario = await Promise.all(
assetIds.map((assetId) => {
if (generation.pixelSize === 32) {
return pixelateImage({
remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`,
pixelSize: generation.pixelSize,
})
}
return pixelateImageScenario({
remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`,
assetId,
pixelGridSize: generation.pixelSize,
colorPaletteEnabled: generation.colorPaletteEnabled,
colors: generation.colors as number[][],
})
})
)
const pixelatedImages = await Promise.all(
pixelatedImagesScenario.map(async (imageResponse) => {
try {
// If it's a base64 string (from pixelateImage function)
if (typeof imageResponse === 'string') {
return uploadImage(imageResponse);
}
// If it's a response from scenario API
if (imageResponse?.url) {
const base64Data = await urlToBase64(imageResponse.url);
console.log("base64Data", base64Data)
return uploadImage(base64Data);
}
console.error('Invalid imageResponse:', imageResponse);
throw new Error('Invalid image response format');
} catch (error) {
console.error('Error processing image:', error);
throw error;
}
})
)
console.log("pixelateImages", pixelatedImages)
console.log("jobProgress.job.metadata.assetIds", jobProgress.job.metadata.assetIds)
const imagesWithPixelated = jobProgress.job.metadata.assetIds.map(
(assetId, index) => {
return {
scenarioImageId: assetId,
seed: "seed",
image: pixelatedImages[index].publicUrl,
pixelated: pixelatedImages[index].publicUrl,
}
}
)
console.log("imagesWithPixelated", imagesWithPixelated)
console.log("generation", generation)
await db.$transaction([
db.user.update({
where: {
id: session.user.id,
},
data: {
credits: {
decrement: generation.numSamples / 4,
},
},
}),
db.generation.update({
where: {
id: generation.id,
},
data: {
status: "COMPLETE",
outputImages: {
createMany: {
data: imagesWithPixelated.map((image) => {
return {
scenarioImageId: image.scenarioImageId,
image: image.url,
seed: image.seed,
pixelatedImage: image.pixelated,
}
}),
},
},
},
}),
])
const outputImages = await db.outputImage.findMany({
where: {
generationId: generation.id,
},
}),
])
})
const outputImages = await db.outputImage.findMany({
where: {
generationId: generation.id,
},
})
console.log("outputImages", outputImages)
let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse =
{
...inferenceProgress,
outputImages,
return new Response(
JSON.stringify({
...jobProgress,
outputImages,
}),
{ status: 200 }
)
} catch (error) {
// If generation record doesn't exist, return just the job progress
if (error.code === 'P2025') {
return new Response(JSON.stringify(jobProgress), { status: 200 })
}
return new Response(
JSON.stringify(copiedInferenceProgressWithImagesPixelated),
{ status: 200 }
)
} else if (inferenceProgress.inference.status === "failed") {
throw error // Re-throw other errors
}
} else if (jobProgress.job.status === "failed") {
const generation = await db.generation.findUniqueOrThrow({
where: {
uniqueGeneration: {
@ -256,7 +319,7 @@ export async function GET(
})
}
return new Response(JSON.stringify(inferenceProgress), { status: 200 })
return new Response(JSON.stringify(jobProgress), { status: 200 })
} catch (error) {
console.log("Error", error)
if (error instanceof z.ZodError) {

View file

@ -74,38 +74,36 @@ export async function POST(req: Request) {
)
}
const generation: ScenarioInferenceResponse = await fetch(
`https://api.cloud.scenario.com/v1/models/${parameters.modelId}/inferences`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Basic ${scenarioAuthToken}`,
},
body: JSON.stringify({
parameters: {
enableSafetyCheck: false,
type: parameters?.referenceImage
? "img2img"
: "txt2img",
prompt: `${parameters.prompt} ${
supplementalPromptMap[parameters.modelId]
}`,
negativePrompt: "trading cards, cards",
numInferenceSteps: 30,
guidance: parameters.guidance,
width: 512,
height: 512,
numSamples: parameters.numImages,
image: parameters?.referenceImage ?? undefined,
modality: modalityMap[parameters.modelId] ?? undefined,
strength: parameters?.referenceImage
? (100 - parameters?.influence) / 100
: undefined,
},
const endpoint = parameters?.referenceImage
? "https://api.cloud.scenario.com/v1/generate/img2img"
: "https://api.cloud.scenario.com/v1/generate/txt2img"
const generation: ScenarioInferenceResponse = await fetch(endpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Basic ${scenarioAuthToken}`,
},
body: JSON.stringify({
modelId: parameters.modelId,
prompt: `${parameters.prompt} ${supplementalPromptMap[parameters.modelId]
}`,
negativePrompt: "trading cards, cards",
numInferenceSteps: 30,
guidance: parameters.guidance,
width: 512,
height: 512,
numSamples: parameters.numImages,
...(parameters?.referenceImage && {
image: parameters.referenceImage,
strength: (100 - parameters?.influence) / 100,
}),
}
).then((res) => res.json())
...(modalityMap[parameters.modelId] && {
modality: modalityMap[parameters.modelId],
}),
}),
}).then((res) => res.json())
await db.generation.create({
data: {

View file

@ -135,9 +135,8 @@ export function GenerationForm({
Base the entire prompt on this context: ${getValues(
"prompt"
)} making sure to keep the style in mind which is: ${
supplementalPromptMap[modelId]
}`
)} making sure to keep the style in mind which is: ${supplementalPromptMap[modelId]
}`
const response = await fetch("/api/generate/prompt-generate", {
method: "POST",
@ -246,7 +245,7 @@ export function GenerationForm({
...prev,
{
guidance: guidance[0],
inferenceId: responseData.inference.id,
inferenceId: responseData.job.jobId,
modelId,
prompt: data.prompt,
numImages,
@ -393,13 +392,13 @@ export function GenerationForm({
key as keyof typeof scenarioGenerators
]
.featuredArtist && (
<div className="inline-flex items-center ml-2">
<Badge variant="secondary">
Featured
artist
</Badge>
</div>
)}
<div className="inline-flex items-center ml-2">
<Badge variant="secondary">
Featured
artist
</Badge>
</div>
)}
</SelectItem>
)
}

View file

@ -64,9 +64,9 @@ export const GenerationSet = ({
let secondCount = 0
let showedPatienceModal = false
setIsSaving(true)
while (!generatedImages) {
// Loop in 1s intervals until the alt text is ready
let finalResponse = await fetch(
const response = await fetch(
`/api/generate/${inferenceId}?modelId=${modelId}`,
{
method: "GET",
@ -76,19 +76,23 @@ export const GenerationSet = ({
signal: controller.signal,
}
)
let jsonFinalResponse: ScenarioInferenceProgressResponse =
await finalResponse.json()
setProgress(jsonFinalResponse.inference.progress)
const jsonResponse = await response.json()
setProgress(jsonResponse.job?.progress || 0)
if (
jsonFinalResponse.inference.status === "succeeded" &&
jsonFinalResponse?.outputImages
jsonResponse.job?.status === "success" &&
jsonResponse?.outputImages
) {
generatedImages = jsonFinalResponse.outputImages
generatedImages = jsonResponse.outputImages
setImages(generatedImages)
} else if (
jsonFinalResponse.inference.status === "failed"
) {
} else if (jsonResponse.job?.status === "failed") {
toast({
title: "Generation failed",
description: "Please try again",
variant: "destructive",
})
break
} else {
if (secondCount >= 60 && !showedPatienceModal) {
@ -101,15 +105,18 @@ export const GenerationSet = ({
showedPatienceModal = true
}
secondCount++
await new Promise((resolve) =>
setTimeout(resolve, 1000)
)
await new Promise((resolve) => setTimeout(resolve, 1000))
}
}
setIsSaving(false)
router.refresh()
} catch (e) {
console.log(e)
toast({
title: "Something went wrong",
description: "Please try again",
variant: "destructive",
})
}
}

View file

@ -1,7 +1,36 @@
import { Generation, OutputImage } from "@prisma/client"
export interface ScenarioInferenceResponse {
inference: Inference
job: Job
}
export interface Job {
jobId: string
jobType: string
metadata: JobMetadata
ownerId: string
authorId: string
createdAt: string
updatedAt: string
status: string
statusHistory: StatusHistoryItem[]
progress: number
}
export interface JobMetadata {
baseModelId: string
inferenceId: string
input: any // You might want to type this more specifically
modelId: string
modelType: string
priority: number
assetIds: string[]
}
export interface StatusHistoryItem {
// Add specific fields based on your needs
status?: string
timestamp?: string
}
export interface Inference {