mirror of
https://github.com/shadcn-ui/taxonomy
synced 2026-05-24 09:48:32 +00:00
Init upgrade scenario endpoint
This commit is contained in:
parent
b58aa7f5f5
commit
8f7523e65e
5 changed files with 282 additions and 186 deletions
|
|
@ -24,26 +24,41 @@ export const gridSizeToScenarioPixelMap = {
|
|||
}
|
||||
|
||||
export const uploadImage = async (base64String: string) => {
|
||||
const base64FileData = base64String.split("base64,")?.[1]
|
||||
try {
|
||||
const base64FileData = base64String.split("base64,")?.[1]
|
||||
if (!base64FileData) {
|
||||
throw new Error('Invalid base64 string format');
|
||||
}
|
||||
|
||||
const uuid = uuidv4()
|
||||
const { data: upload, error } = await supabase.storage
|
||||
.from("pixelated")
|
||||
.upload(`${uuid}.png`, decode(base64FileData), {
|
||||
contentType: "image/png",
|
||||
cacheControl: "3600",
|
||||
upsert: false,
|
||||
})
|
||||
const uuid = uuidv4()
|
||||
const { data: upload, error } = await supabase.storage
|
||||
.from("pixelated")
|
||||
.upload(`${uuid}.png`, decode(base64FileData), {
|
||||
contentType: "image/png",
|
||||
cacheControl: "3600",
|
||||
upsert: false,
|
||||
})
|
||||
|
||||
const { data } = await supabase.storage
|
||||
.from("pixelated")
|
||||
.getPublicUrl(`${uuid}.png`)
|
||||
console.log("upload", upload)
|
||||
|
||||
if (error) {
|
||||
throw new Error(error.message)
|
||||
if (error) {
|
||||
console.error('Supabase upload error:', error);
|
||||
throw new Error(error.message)
|
||||
}
|
||||
|
||||
const { data: urlData } = await supabase.storage
|
||||
.from("pixelated")
|
||||
.getPublicUrl(`${uuid}.png`)
|
||||
|
||||
|
||||
console.log("urlData", urlData)
|
||||
|
||||
// Return an object with publicUrl property to match expected structure
|
||||
return { publicUrl: urlData.publicUrl }
|
||||
} catch (error) {
|
||||
console.error('Upload image error:', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
type PixelateImageParams = {
|
||||
|
|
@ -63,15 +78,15 @@ export const pixelateImageScenario = async ({
|
|||
removeBackground = false,
|
||||
}: PixelateImageParams) => {
|
||||
const pixelateResponse = await fetch(
|
||||
`https://api.cloud.scenario.com/v1/images/pixelate`,
|
||||
`https://api.cloud.scenario.com/v1/generate/pixelate`,
|
||||
{
|
||||
method: "PUT",
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Basic ${scenarioAuthToken}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
assetId,
|
||||
image: assetId,
|
||||
pixelGridSize: gridSizeToScenarioPixelMap[pixelGridSize],
|
||||
returnImage: true,
|
||||
removeNoise: true,
|
||||
|
|
@ -84,6 +99,7 @@ export const pixelateImageScenario = async ({
|
|||
}
|
||||
)
|
||||
|
||||
|
||||
if (!pixelateResponse.ok) {
|
||||
return pixelateImage({
|
||||
remoteUrl,
|
||||
|
|
@ -93,7 +109,17 @@ export const pixelateImageScenario = async ({
|
|||
|
||||
const pixelateData: ScenarioPixelateResponse = await pixelateResponse.json()
|
||||
|
||||
return pixelateData.image
|
||||
|
||||
console.log("pixelate data response", pixelateData)
|
||||
// @ts-ignore
|
||||
return pixelateData.asset
|
||||
}
|
||||
|
||||
async function urlToBase64(url: string) {
|
||||
console.log(url)
|
||||
const response = await fetch(url);
|
||||
const buffer = await response.arrayBuffer();
|
||||
return `data:image/png;base64,${Buffer.from(buffer).toString('base64')}`;
|
||||
}
|
||||
|
||||
export async function GET(
|
||||
|
|
@ -115,128 +141,165 @@ export async function GET(
|
|||
return new Response(null, { status: 403 })
|
||||
}
|
||||
|
||||
// Track the status of our inference progress here
|
||||
const inferenceProgress: ScenarioInferenceProgressResponse =
|
||||
await fetch(
|
||||
`https://api.cloud.scenario.com/v1/models/${modelId}/inferences/${params.inferenceId}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Basic ${scenarioAuthToken}`,
|
||||
},
|
||||
}
|
||||
).then((res) => res.json())
|
||||
|
||||
// If the inference was a success, decrement the user's credits.
|
||||
if (inferenceProgress.inference.status === "succeeded") {
|
||||
const generation = await db.generation.findUniqueOrThrow({
|
||||
where: {
|
||||
uniqueGeneration: {
|
||||
inferenceId: params.inferenceId,
|
||||
modelId: modelId,
|
||||
},
|
||||
// Use the new jobs endpoint instead of inference progress
|
||||
const jobProgress = await fetch(
|
||||
`https://api.cloud.scenario.com/v1/jobs/${params.inferenceId}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Basic ${scenarioAuthToken}`,
|
||||
},
|
||||
include: {
|
||||
outputImages: true,
|
||||
},
|
||||
})
|
||||
|
||||
// If the generation is already complete, return the output images
|
||||
if (generation.status === "COMPLETE") {
|
||||
let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse =
|
||||
{
|
||||
...inferenceProgress,
|
||||
outputImages: generation.outputImages,
|
||||
}
|
||||
return new Response(
|
||||
JSON.stringify(copiedInferenceProgressWithImagesPixelated),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
).then((res) => res.json())
|
||||
|
||||
const pixelatedImagesScenario = await Promise.all(
|
||||
inferenceProgress.inference.images.map((image) => {
|
||||
if (generation.pixelSize === 32) {
|
||||
return pixelateImage({
|
||||
remoteUrl: image.url,
|
||||
pixelSize: generation.pixelSize,
|
||||
})
|
||||
}
|
||||
return pixelateImageScenario({
|
||||
remoteUrl: image.url,
|
||||
assetId: image.id,
|
||||
pixelGridSize: generation.pixelSize,
|
||||
colorPaletteEnabled: generation.colorPaletteEnabled,
|
||||
colors: generation.colors as number[][],
|
||||
})
|
||||
})
|
||||
)
|
||||
console.log(jobProgress)
|
||||
|
||||
const pixelatedImages = await Promise.all(
|
||||
pixelatedImagesScenario.map((image) => {
|
||||
return uploadImage(image)
|
||||
})
|
||||
)
|
||||
|
||||
const imagesWithPixelated = inferenceProgress.inference.images.map(
|
||||
(image, index) => {
|
||||
return {
|
||||
...image,
|
||||
pixelated: pixelatedImages[index].publicUrl,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await db.$transaction([
|
||||
db.user.update({
|
||||
// If the job was successful, process the images and update credits
|
||||
if (jobProgress.job.status === "success") {
|
||||
console.log("Job success")
|
||||
// Wrap this in try-catch since findUniqueOrThrow can fail
|
||||
try {
|
||||
const generation = await db.generation.findUniqueOrThrow({
|
||||
where: {
|
||||
id: session.user.id,
|
||||
},
|
||||
data: {
|
||||
credits: {
|
||||
decrement: generation.numSamples / 4,
|
||||
uniqueGeneration: {
|
||||
inferenceId: jobProgress.job.metadata.inferenceId,
|
||||
modelId: modelId,
|
||||
},
|
||||
},
|
||||
}),
|
||||
db.generation.update({
|
||||
where: {
|
||||
id: generation.id,
|
||||
include: {
|
||||
outputImages: true,
|
||||
},
|
||||
data: {
|
||||
status: "COMPLETE",
|
||||
outputImages: {
|
||||
createMany: {
|
||||
data: imagesWithPixelated.map((image) => {
|
||||
return {
|
||||
scenarioImageId: image.id,
|
||||
image: image.url,
|
||||
seed: image.seed,
|
||||
pixelatedImage: image.pixelated,
|
||||
}
|
||||
}),
|
||||
})
|
||||
|
||||
console.log("generateion jere", generation)
|
||||
|
||||
// If already processed, return existing output images
|
||||
if (generation.status === "COMPLETE") {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
...jobProgress,
|
||||
outputImages: generation.outputImages,
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
|
||||
// Process images using assetIds from job metadata
|
||||
const assetIds = jobProgress.job.metadata.assetIds
|
||||
const pixelatedImagesScenario = await Promise.all(
|
||||
assetIds.map((assetId) => {
|
||||
if (generation.pixelSize === 32) {
|
||||
return pixelateImage({
|
||||
remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`,
|
||||
pixelSize: generation.pixelSize,
|
||||
})
|
||||
}
|
||||
return pixelateImageScenario({
|
||||
remoteUrl: `https://api.cloud.scenario.com/v1/assets/${assetId}`,
|
||||
assetId,
|
||||
pixelGridSize: generation.pixelSize,
|
||||
colorPaletteEnabled: generation.colorPaletteEnabled,
|
||||
colors: generation.colors as number[][],
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
const pixelatedImages = await Promise.all(
|
||||
pixelatedImagesScenario.map(async (imageResponse) => {
|
||||
try {
|
||||
// If it's a base64 string (from pixelateImage function)
|
||||
if (typeof imageResponse === 'string') {
|
||||
return uploadImage(imageResponse);
|
||||
}
|
||||
// If it's a response from scenario API
|
||||
if (imageResponse?.url) {
|
||||
const base64Data = await urlToBase64(imageResponse.url);
|
||||
console.log("base64Data", base64Data)
|
||||
return uploadImage(base64Data);
|
||||
}
|
||||
console.error('Invalid imageResponse:', imageResponse);
|
||||
throw new Error('Invalid image response format');
|
||||
} catch (error) {
|
||||
console.error('Error processing image:', error);
|
||||
throw error;
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
console.log("pixelateImages", pixelatedImages)
|
||||
console.log("jobProgress.job.metadata.assetIds", jobProgress.job.metadata.assetIds)
|
||||
|
||||
const imagesWithPixelated = jobProgress.job.metadata.assetIds.map(
|
||||
(assetId, index) => {
|
||||
return {
|
||||
scenarioImageId: assetId,
|
||||
seed: "seed",
|
||||
image: pixelatedImages[index].publicUrl,
|
||||
pixelated: pixelatedImages[index].publicUrl,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
console.log("imagesWithPixelated", imagesWithPixelated)
|
||||
console.log("generation", generation)
|
||||
|
||||
await db.$transaction([
|
||||
db.user.update({
|
||||
where: {
|
||||
id: session.user.id,
|
||||
},
|
||||
data: {
|
||||
credits: {
|
||||
decrement: generation.numSamples / 4,
|
||||
},
|
||||
},
|
||||
}),
|
||||
db.generation.update({
|
||||
where: {
|
||||
id: generation.id,
|
||||
},
|
||||
data: {
|
||||
status: "COMPLETE",
|
||||
outputImages: {
|
||||
createMany: {
|
||||
data: imagesWithPixelated.map((image) => {
|
||||
return {
|
||||
scenarioImageId: image.scenarioImageId,
|
||||
image: image.url,
|
||||
seed: image.seed,
|
||||
pixelatedImage: image.pixelated,
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
])
|
||||
|
||||
const outputImages = await db.outputImage.findMany({
|
||||
where: {
|
||||
generationId: generation.id,
|
||||
},
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
const outputImages = await db.outputImage.findMany({
|
||||
where: {
|
||||
generationId: generation.id,
|
||||
},
|
||||
})
|
||||
console.log("outputImages", outputImages)
|
||||
|
||||
let copiedInferenceProgressWithImagesPixelated: ScenarioInferenceProgressResponse =
|
||||
{
|
||||
...inferenceProgress,
|
||||
outputImages,
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
...jobProgress,
|
||||
outputImages,
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error) {
|
||||
// If generation record doesn't exist, return just the job progress
|
||||
if (error.code === 'P2025') {
|
||||
return new Response(JSON.stringify(jobProgress), { status: 200 })
|
||||
}
|
||||
return new Response(
|
||||
JSON.stringify(copiedInferenceProgressWithImagesPixelated),
|
||||
{ status: 200 }
|
||||
)
|
||||
} else if (inferenceProgress.inference.status === "failed") {
|
||||
throw error // Re-throw other errors
|
||||
}
|
||||
} else if (jobProgress.job.status === "failed") {
|
||||
const generation = await db.generation.findUniqueOrThrow({
|
||||
where: {
|
||||
uniqueGeneration: {
|
||||
|
|
@ -256,7 +319,7 @@ export async function GET(
|
|||
})
|
||||
}
|
||||
|
||||
return new Response(JSON.stringify(inferenceProgress), { status: 200 })
|
||||
return new Response(JSON.stringify(jobProgress), { status: 200 })
|
||||
} catch (error) {
|
||||
console.log("Error", error)
|
||||
if (error instanceof z.ZodError) {
|
||||
|
|
|
|||
|
|
@ -74,38 +74,36 @@ export async function POST(req: Request) {
|
|||
)
|
||||
}
|
||||
|
||||
const generation: ScenarioInferenceResponse = await fetch(
|
||||
`https://api.cloud.scenario.com/v1/models/${parameters.modelId}/inferences`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Basic ${scenarioAuthToken}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
parameters: {
|
||||
enableSafetyCheck: false,
|
||||
type: parameters?.referenceImage
|
||||
? "img2img"
|
||||
: "txt2img",
|
||||
prompt: `${parameters.prompt} ${
|
||||
supplementalPromptMap[parameters.modelId]
|
||||
}`,
|
||||
negativePrompt: "trading cards, cards",
|
||||
numInferenceSteps: 30,
|
||||
guidance: parameters.guidance,
|
||||
width: 512,
|
||||
height: 512,
|
||||
numSamples: parameters.numImages,
|
||||
image: parameters?.referenceImage ?? undefined,
|
||||
modality: modalityMap[parameters.modelId] ?? undefined,
|
||||
strength: parameters?.referenceImage
|
||||
? (100 - parameters?.influence) / 100
|
||||
: undefined,
|
||||
},
|
||||
const endpoint = parameters?.referenceImage
|
||||
? "https://api.cloud.scenario.com/v1/generate/img2img"
|
||||
: "https://api.cloud.scenario.com/v1/generate/txt2img"
|
||||
|
||||
const generation: ScenarioInferenceResponse = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Basic ${scenarioAuthToken}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
modelId: parameters.modelId,
|
||||
prompt: `${parameters.prompt} ${supplementalPromptMap[parameters.modelId]
|
||||
}`,
|
||||
negativePrompt: "trading cards, cards",
|
||||
numInferenceSteps: 30,
|
||||
guidance: parameters.guidance,
|
||||
width: 512,
|
||||
height: 512,
|
||||
numSamples: parameters.numImages,
|
||||
...(parameters?.referenceImage && {
|
||||
image: parameters.referenceImage,
|
||||
strength: (100 - parameters?.influence) / 100,
|
||||
}),
|
||||
}
|
||||
).then((res) => res.json())
|
||||
...(modalityMap[parameters.modelId] && {
|
||||
modality: modalityMap[parameters.modelId],
|
||||
}),
|
||||
}),
|
||||
}).then((res) => res.json())
|
||||
|
||||
|
||||
await db.generation.create({
|
||||
data: {
|
||||
|
|
|
|||
|
|
@ -135,9 +135,8 @@ export function GenerationForm({
|
|||
|
||||
Base the entire prompt on this context: ${getValues(
|
||||
"prompt"
|
||||
)} making sure to keep the style in mind which is: ${
|
||||
supplementalPromptMap[modelId]
|
||||
}`
|
||||
)} making sure to keep the style in mind which is: ${supplementalPromptMap[modelId]
|
||||
}`
|
||||
|
||||
const response = await fetch("/api/generate/prompt-generate", {
|
||||
method: "POST",
|
||||
|
|
@ -246,7 +245,7 @@ export function GenerationForm({
|
|||
...prev,
|
||||
{
|
||||
guidance: guidance[0],
|
||||
inferenceId: responseData.inference.id,
|
||||
inferenceId: responseData.job.jobId,
|
||||
modelId,
|
||||
prompt: data.prompt,
|
||||
numImages,
|
||||
|
|
@ -393,13 +392,13 @@ export function GenerationForm({
|
|||
key as keyof typeof scenarioGenerators
|
||||
]
|
||||
.featuredArtist && (
|
||||
<div className="inline-flex items-center ml-2">
|
||||
<Badge variant="secondary">
|
||||
Featured
|
||||
artist
|
||||
</Badge>
|
||||
</div>
|
||||
)}
|
||||
<div className="inline-flex items-center ml-2">
|
||||
<Badge variant="secondary">
|
||||
Featured
|
||||
artist
|
||||
</Badge>
|
||||
</div>
|
||||
)}
|
||||
</SelectItem>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,9 +64,9 @@ export const GenerationSet = ({
|
|||
let secondCount = 0
|
||||
let showedPatienceModal = false
|
||||
setIsSaving(true)
|
||||
|
||||
while (!generatedImages) {
|
||||
// Loop in 1s intervals until the alt text is ready
|
||||
let finalResponse = await fetch(
|
||||
const response = await fetch(
|
||||
`/api/generate/${inferenceId}?modelId=${modelId}`,
|
||||
{
|
||||
method: "GET",
|
||||
|
|
@ -76,19 +76,23 @@ export const GenerationSet = ({
|
|||
signal: controller.signal,
|
||||
}
|
||||
)
|
||||
let jsonFinalResponse: ScenarioInferenceProgressResponse =
|
||||
await finalResponse.json()
|
||||
setProgress(jsonFinalResponse.inference.progress)
|
||||
const jsonResponse = await response.json()
|
||||
|
||||
setProgress(jsonResponse.job?.progress || 0)
|
||||
|
||||
if (
|
||||
jsonFinalResponse.inference.status === "succeeded" &&
|
||||
jsonFinalResponse?.outputImages
|
||||
jsonResponse.job?.status === "success" &&
|
||||
jsonResponse?.outputImages
|
||||
) {
|
||||
generatedImages = jsonFinalResponse.outputImages
|
||||
generatedImages = jsonResponse.outputImages
|
||||
|
||||
setImages(generatedImages)
|
||||
} else if (
|
||||
jsonFinalResponse.inference.status === "failed"
|
||||
) {
|
||||
} else if (jsonResponse.job?.status === "failed") {
|
||||
toast({
|
||||
title: "Generation failed",
|
||||
description: "Please try again",
|
||||
variant: "destructive",
|
||||
})
|
||||
break
|
||||
} else {
|
||||
if (secondCount >= 60 && !showedPatienceModal) {
|
||||
|
|
@ -101,15 +105,18 @@ export const GenerationSet = ({
|
|||
showedPatienceModal = true
|
||||
}
|
||||
secondCount++
|
||||
await new Promise((resolve) =>
|
||||
setTimeout(resolve, 1000)
|
||||
)
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000))
|
||||
}
|
||||
}
|
||||
setIsSaving(false)
|
||||
router.refresh()
|
||||
} catch (e) {
|
||||
console.log(e)
|
||||
toast({
|
||||
title: "Something went wrong",
|
||||
description: "Please try again",
|
||||
variant: "destructive",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,36 @@
|
|||
import { Generation, OutputImage } from "@prisma/client"
|
||||
|
||||
export interface ScenarioInferenceResponse {
|
||||
inference: Inference
|
||||
job: Job
|
||||
}
|
||||
|
||||
export interface Job {
|
||||
jobId: string
|
||||
jobType: string
|
||||
metadata: JobMetadata
|
||||
ownerId: string
|
||||
authorId: string
|
||||
createdAt: string
|
||||
updatedAt: string
|
||||
status: string
|
||||
statusHistory: StatusHistoryItem[]
|
||||
progress: number
|
||||
}
|
||||
|
||||
export interface JobMetadata {
|
||||
baseModelId: string
|
||||
inferenceId: string
|
||||
input: any // You might want to type this more specifically
|
||||
modelId: string
|
||||
modelType: string
|
||||
priority: number
|
||||
assetIds: string[]
|
||||
}
|
||||
|
||||
export interface StatusHistoryItem {
|
||||
// Add specific fields based on your needs
|
||||
status?: string
|
||||
timestamp?: string
|
||||
}
|
||||
|
||||
export interface Inference {
|
||||
|
|
|
|||
Loading…
Reference in a new issue