ashim/packages/ai/python/remove_bg.py
Ashim 08a7ffe403 Enhance logging and error handling across tools; add full tool audit and Playwright tests
- Added model mismatch warnings in colorize, enhance-faces, and upscale routes.
- Improved error handling in colorize, enhance_faces, remove_bg, restore, and upscale scripts with detailed logging.
- Updated Dockerfile to align NCCL versions for compatibility.
- Introduced a new full tool audit script to test all tools for functionality and GPU usage.
- Created Playwright E2E tests for GPU-dependent tools to ensure proper functionality and performance.
2026-04-17 23:06:31 +08:00

130 lines
4 KiB
Python

"""Background removal using rembg with state-of-the-art BiRefNet models."""
import sys
import json
import os
def emit_progress(percent, stage):
"""Emit structured progress to stderr for bridge.ts to capture."""
print(json.dumps({"progress": percent, "stage": stage}), file=sys.stderr, flush=True)
ALLOWED_MODELS = {
"u2net",
"isnet-general-use",
"bria-rmbg",
"birefnet-general-lite",
"birefnet-portrait",
"birefnet-general",
"birefnet-matting",
}
_matting_registered = False
def _register_matting_session(sessions_class):
"""Register the BiRefNet-matting ONNX session for Ultra quality mode."""
global _matting_registered
if _matting_registered:
return
_matting_registered = True
import os
import pooch
from rembg.sessions.birefnet_general import BiRefNetSessionGeneral
class BiRefNetMattingSession(BiRefNetSessionGeneral):
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-matting-epoch_100.onnx",
None, # Skip checksum for GitHub release assets
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
return "birefnet-matting"
sessions_class.append(BiRefNetMattingSession)
def main():
input_path = sys.argv[1]
output_path = sys.argv[2]
settings = json.loads(sys.argv[3]) if len(sys.argv) > 3 else {}
model = settings.get("model", "birefnet-general-lite")
if model not in ALLOWED_MODELS:
model = "birefnet-general-lite"
# Redirect stdout to stderr so library download/progress output
# cannot contaminate our JSON result on stdout.
stdout_fd = os.dup(1)
os.dup2(2, 1)
try:
from rembg import remove, new_session
from rembg.sessions import sessions_class
from gpu import onnx_providers
# Register BiRefNet-matting (Ultra quality) if not already present
_register_matting_session(sessions_class)
emit_progress(10, "Loading model")
session = new_session(model, providers=onnx_providers())
emit_progress(25, "Model loaded")
with open(input_path, "rb") as f:
input_data = f.read()
# Try with alpha matting for better edges, fall back without
emit_progress(30, "Analyzing image")
try:
output_data = remove(
input_data,
session=session,
alpha_matting=True,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
)
except Exception as e:
print(f"[remove-bg] Alpha matting failed ({e}), using standard removal", file=sys.stderr, flush=True)
output_data = remove(input_data, session=session)
emit_progress(80, "Background removed")
# Always return transparent PNG. All background compositing
# (solid color, gradient, blur, shadow) is handled by Node.js/Sharp.
emit_progress(95, "Saving result")
with open(output_path, "wb") as f:
f.write(output_data)
result = json.dumps({"success": True, "model": model})
except ImportError as e:
print(f"[remove-bg] Import failed: {e}", file=sys.stderr, flush=True)
result = json.dumps(
{
"success": False,
"error": f"rembg import failed: {e}",
}
)
except Exception as e:
result = json.dumps({"success": False, "error": str(e)})
# Restore original stdout and write only our JSON result
os.dup2(stdout_fd, 1)
os.close(stdout_fd)
sys.stdout.write(result + "\n")
sys.stdout.flush()
if __name__ == "__main__":
main()