1. 模型
去除图片背景的大模型有很多，这里使用的是 briaai/RMBG-2.0，它比 briaai/RMBG-1.4 版本性能提升了不少。
模型下载地址:https://huggingface.co/briaai/RMBG-2.0/resolve/main/onnx/model.onnx?download=true
2. 使用什么语言编写
大模型开发首选 Python，相对于其他语言，Python 的库和开发教程都比较多。这里我们使用 ONNX 格式的模型，因此可选的语言就多了：ONNX Runtime 支持 JavaScript、Java、C# 和 C++ 等语言。这里我们使用 Python 和 Java 来实现这个程序。
3. Python 版
Python 版本既有使用 ONNX 模型的实现，也有使用其他模型的实现，并用 Gradio 创建了 UI 界面、打包成 exe 可执行文件。下面只给出使用 ONNX 模型的代码，如需其他部分的代码，请留言。
import numpy as np
import onnxruntime
import onnx
import os
from PIL import Image
# Keep ONNX Runtime quiet: severity 3 means only ERROR (and FATAL) messages are logged.
onnxruntime.set_default_logger_severity(3)

# Device the installed onnxruntime build targets, e.g. "GPU" for GPU builds.
ONNX_DEVICE = onnxruntime.get_device()

# Preferred execution provider: CUDA when the build targets a GPU, CPU otherwise.
if ONNX_DEVICE == "GPU":
    ONNX_PROVIDER = "CUDAExecutionProvider"
else:
    ONNX_PROVIDER = "CPUExecutionProvider"
def load_onnx_model(checkpoint_path, set_cpu=False):
    """Create an ONNX Runtime inference session for ``checkpoint_path``.

    Args:
        checkpoint_path: Path to the ``.onnx`` model file.
        set_cpu: If True, force ``CPUExecutionProvider`` and also print the
            model's IR version. Otherwise use CUDA (with CPU as secondary
            provider) when the module-level ``ONNX_PROVIDER`` selected it,
            and retry on CPU alone if creating that session fails.

    Returns:
        An ``onnxruntime.InferenceSession``.

    Raises:
        Exception: Whatever ``onnxruntime.InferenceSession`` raises when the
            model cannot be loaded with the CPU provider.
    """
    if set_cpu:
        sess = onnxruntime.InferenceSession(
            checkpoint_path, providers=["CPUExecutionProvider"]
        )
        # Parses the model a second time only to report its IR version.
        model = onnx.load(checkpoint_path)
        print(f"Model IR version: {model.ir_version}")
        return sess

    # Decide from the chosen provider, not ONNX_DEVICE: get_device() reports
    # "GPU"/"CPU", so comparing it with a provider name never matched and
    # the CPU fallback below could never run.
    use_cuda = ONNX_PROVIDER == "CUDAExecutionProvider"
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if use_cuda
        else ["CPUExecutionProvider"]
    )
    try:
        return onnxruntime.InferenceSession(checkpoint_path, providers=providers)
    except Exception as e:
        if not use_cuda:
            # The CPU attempt itself failed; retrying would fail the same way.
            raise
        print(f"Failed to load model with CUDAExecutionProvider: {e}")
        print("Falling back to CPUExecutionProvider")
        # Retry the model load on the CPU provider alone.
        return onnxruntime.InferenceSession(
            checkpoint_path, providers=["CPUExecutionProvider"]
        )
def _rmbg_preprocess(image, ref_size, mean, std):
    """Turn a PIL image into the (1, 3, ref_size, ref_size) float32 model input.

    The image is converted to RGB, resized bilinearly to ``ref_size`` x
    ``ref_size``, scaled to [0, 1] and normalized per channel as
    ``(x - mean) / std``.
    """
    resized = image.convert("RGB").resize((ref_size, ref_size), Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # (H, W, 3)
    pixels = (pixels - np.asarray(mean, dtype=np.float32)) / np.asarray(
        std, dtype=np.float32
    )
    # HWC -> CHW plus a batch axis; hand ONNX Runtime a C-contiguous buffer.
    return np.ascontiguousarray(pixels.transpose(2, 0, 1)[np.newaxis])


def _rmbg_mask_to_uint8(raw_mask):
    """Min-max normalize the model's raw mask to [0, 1] and convert it to uint8.

    A constant mask cannot be min-max normalized (that would be 0 / 0 and
    yield NaNs), so it is clipped to [0, 1] instead.

    Raises:
        ValueError: If the squeezed mask is not two-dimensional.
    """
    mask = np.squeeze(raw_mask)
    if mask.ndim != 2:
        raise ValueError(f"Expected a single-channel mask, got shape {mask.shape}")
    lo = mask.min()
    hi = mask.max()
    if hi > lo:
        mask = (mask - lo) / (hi - lo)
    else:
        mask = np.clip(mask, 0.0, 1.0)
    return (mask * 255).astype(np.uint8)


def get_rmbg_matting(
    input_image: np.ndarray,
    checkpoint_path,
    ref_size=1024,
    *,
    output_path="output1.png",
    set_cpu=True,
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
):
    """Remove the background of ``input_image`` with an RMBG ONNX model.

    The model's mask is resized back to the original image size and used to
    composite the image over a transparent canvas; the result is saved to
    ``output_path`` and returned.

    Args:
        input_image: Image pixels accepted by ``PIL.Image.fromarray``
            (e.g. the output of ``read_modnet_image``).
        checkpoint_path: Path to the ONNX model.
        ref_size: Side length of the square model input.
        output_path: File the RGBA result is saved to.
        set_cpu: Forwarded to ``load_onnx_model``; True forces the CPU provider.
        mean, std: Per-channel (R, G, B) normalization applied after scaling
            pixels to [0, 1]. The defaults map pixels to [-1, 1].
            NOTE(review): the RMBG-2.0 model card's reference code normalizes
            with ImageNet statistics, mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225) -- confirm and pass those explicitly.

    Returns:
        The RGBA result as a NumPy array, or None if ``checkpoint_path``
        does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        return None

    # A new session is loaded on every call.
    sess = load_onnx_model(checkpoint_path, set_cpu=set_cpu)

    orig_image = Image.fromarray(input_image)
    model_input = _rmbg_preprocess(orig_image, ref_size, mean, std)
    raw_mask = sess.run(None, {sess.get_inputs()[0].name: model_input})[0]

    # The mask is predicted at ref_size x ref_size; scale it back to the original size.
    mask = Image.fromarray(_rmbg_mask_to_uint8(raw_mask)).resize(
        orig_image.size, Image.BILINEAR
    )

    # Composite the original over a fully transparent canvas, using the mask as alpha.
    new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=mask)
    new_im.save(output_path)
    print('执行完毕')
    return np.array(new_im)
def read_modnet_image(input_image):
    """Load an image file into a NumPy array suitable for ``get_rmbg_matting``.

    Images already in "L", "RGB" or "RGBA" mode are returned as-is. Any other
    mode is converted first, to RGBA when the image carries transparency and
    to RGB otherwise. Without this, a palette ("P") image would come back as
    a 2-D array of palette indices, which ``Image.fromarray`` later treats as
    grayscale; a CMYK image would be read back as RGBA.

    Args:
        input_image: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The decoded pixels as a NumPy array.
    """
    with Image.open(input_image) as im:
        if im.mode not in ("L", "RGB", "RGBA"):
            has_alpha = "A" in im.getbands() or "transparency" in im.info
            im = im.convert("RGBA" if has_alpha else "RGB")
        return np.array(im)
if __name__ == '__main__':
    import sys

    # Usage: python <script> [input_image] [model_path]
    # Omitted arguments fall back to the sample paths below.
    input_image = sys.argv[1] if len(sys.argv) > 1 else r"D:\users\Desktop\v5notes\tray.png"
    model_path = sys.argv[2] if len(sys.argv) > 2 else r'.\models\model.onnx'
    processing_image = read_modnet_image(input_image)
    # Run the model on the image; the result is written to output1.png.
    get_rmbg_matting(processing_image, model_path)
4. Java 版
Java 代码既有使用 OpenCV 的版本，也有不使用 OpenCV 的版本，下面给出的是不使用 OpenCV 的版本。如需其他部分的代码，请留言。
ImageBackgroundRemover.java
package cn.v5cn.rebg.jdk;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.nio.FloatBuffer;
import java.util.Collections;
/**
 * Removes image backgrounds with an RMBG-style ONNX segmentation model, using only the JDK and
 * ONNX Runtime (no OpenCV).
 *
 * <p>The image is resized to {@value #INPUT_SIZE}x{@value #INPUT_SIZE}, normalized to [-1, 1] per
 * channel in planar (CHW) RGB order, and run through the model. The first output's mask is resized
 * back to the original size with bilinear interpolation and written into the alpha channel of the
 * result.
 *
 * <p>Instances hold a native ONNX Runtime session; call {@link #close()} (or use
 * try-with-resources) when done.
 *
 * @author ZYW
 */
public class ImageBackgroundRemover implements AutoCloseable {
    /** Side length of the square image the model is fed. */
    private static final int INPUT_SIZE = 1024;
    /** Input name used when the model declares it; otherwise the model's first input is used. */
    private static final String DEFAULT_INPUT_NAME = "pixel_values";

    private final OrtEnvironment env;
    private final OrtSession session;
    private final String inputName;

    /**
     * Loads the model.
     *
     * @param modelPath path to the {@code .onnx} file
     * @throws OrtException if ONNX Runtime cannot create the session
     */
    public ImageBackgroundRemover(String modelPath) throws OrtException {
        // Initialise the ONNX Runtime environment and session.
        env = OrtEnvironment.getEnvironment();
        session = env.createSession(modelPath);
        // Look the input name up instead of hard-coding it, so exports that name it differently work.
        inputName = session.getInputNames().contains(DEFAULT_INPUT_NAME)
                ? DEFAULT_INPUT_NAME
                : session.getInputNames().iterator().next();
    }

    /**
     * Returns a copy of {@code originalImage} whose alpha channel is the predicted foreground mask.
     *
     * @param originalImage image to process; must not be {@code null}
     * @return a new {@code TYPE_INT_ARGB} image of the same size
     * @throws IllegalArgumentException if {@code originalImage} is {@code null}
     * @throws OrtException if inference fails
     */
    public BufferedImage removeBackground(BufferedImage originalImage) throws OrtException {
        if (originalImage == null) {
            // ImageIO.read returns null for unsupported formats; fail with a clear message.
            throw new IllegalArgumentException("originalImage must not be null");
        }
        float[][] mask = predictMask(preprocessImage(originalImage));
        return applyMaskToImage(originalImage, mask);
    }

    /** Runs the model on a preprocessed input and returns the first output's 2-D mask. */
    private float[][] predictMask(float[] inputData) throws OrtException {
        // Tensors and results wrap native memory, so close them as soon as the mask is copied out.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env,
                     FloatBuffer.wrap(inputData),
                     new long[]{1, 3, INPUT_SIZE, INPUT_SIZE});
             OrtSession.Result result =
                     session.run(Collections.singletonMap(inputName, inputTensor))) {
            // getValue() copies into Java arrays, so the mask stays valid after close().
            float[][][][] output = (float[][][][]) result.get(0).getValue();
            return output[0][0];
        }
    }

    /** Resizes the image to INPUT_SIZE x INPUT_SIZE and flattens it to normalized planar RGB. */
    private float[] preprocessImage(BufferedImage image) {
        BufferedImage resized = new BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        try {
            // Request bilinear scaling, closer to the Python version's bilinear resize.
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(image, 0, 0, INPUT_SIZE, INPUT_SIZE, null);
        } finally {
            g.dispose();
        }

        int plane = INPUT_SIZE * INPUT_SIZE;
        // One bulk read, row-major: pixels[y * INPUT_SIZE + x].
        int[] pixels = resized.getRGB(0, 0, INPUT_SIZE, INPUT_SIZE, null, 0, INPUT_SIZE);
        float[] inputData = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int pixel = pixels[i];
            inputData[i] = normalize((pixel >> 16) & 0xFF);          // R plane
            inputData[plane + i] = normalize((pixel >> 8) & 0xFF);   // G plane
            inputData[2 * plane + i] = normalize(pixel & 0xFF);      // B plane
        }
        return inputData;
    }

    /** Maps an 8-bit channel value to [-1, 1]. */
    private static float normalize(int channel) {
        float value = channel / 255.0f;
        return (value - 0.5f) / 0.5f;
    }

    /** Copies the RGB of {@code original} and takes alpha from the mask resized to its size. */
    private BufferedImage applyMaskToImage(BufferedImage original, float[][] mask) {
        int width = original.getWidth();
        int height = original.getHeight();
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        float[][] resizedMask = resizeMask(mask, width, height);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = original.getRGB(x, y);
                int alpha = toAlpha(resizedMask[y][x]);
                result.setRGB(x, y, (alpha << 24) | (rgb & 0x00FFFFFF));
            }
        }
        return result;
    }

    /**
     * Converts a mask value to an 8-bit alpha. Clamps to [0, 255] so values slightly outside
     * [0, 1] don't wrap modulo 256 (e.g. 256 would otherwise become fully transparent).
     */
    private static int toAlpha(float maskValue) {
        int alpha = (int) (maskValue * 255);
        return Math.max(0, Math.min(255, alpha));
    }

    /** Bilinearly resizes {@code mask} (of any size) to targetWidth x targetHeight. */
    private float[][] resizeMask(float[][] mask, int targetWidth, int targetHeight) {
        int srcHeight = mask.length;
        int srcWidth = mask[0].length;
        float[][] resized = new float[targetHeight][targetWidth];
        for (int y = 0; y < targetHeight; y++) {
            float srcY = y * (float) srcHeight / targetHeight;
            for (int x = 0; x < targetWidth; x++) {
                float srcX = x * (float) srcWidth / targetWidth;
                resized[y][x] = bilinearInterpolate(mask, srcX, srcY);
            }
        }
        return resized;
    }

    /** Samples {@code mask} at fractional (x, y), clamping neighbours to the last row/column. */
    private float bilinearInterpolate(float[][] mask, float x, float y) {
        int x1 = (int) Math.floor(x);
        int x2 = Math.min(x1 + 1, mask[0].length - 1);
        int y1 = (int) Math.floor(y);
        int y2 = Math.min(y1 + 1, mask.length - 1);
        float dx = x - x1;
        float dy = y - y1;
        float q11 = mask[y1][x1];
        float q21 = mask[y1][x2];
        float q12 = mask[y2][x1];
        float q22 = mask[y2][x2];
        return (1 - dx) * (1 - dy) * q11 + dx * (1 - dy) * q21
                + (1 - dx) * dy * q12 + dx * dy * q22;
    }

    /**
     * Closes the session and the environment.
     *
     * @throws OrtException if ONNX Runtime fails to release them
     */
    @Override
    public void close() throws OrtException {
        if (session != null) {
            session.close();
        }
        if (env != null) {
            env.close();
        }
    }
}
Main.java
package cn.v5cn.rebg.jdk;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
public class Main {
    /**
     * Removes the background of a single image and writes the result as PNG.
     *
     * <p>Usage: {@code Main [modelPath] [inputImage] [outputImage]}; omitted arguments fall back to
     * the sample paths under {@code src/main/resources}.
     *
     * @param args optional model, input and output paths
     */
    public static void main(String[] args) {
        // Replace these defaults with real paths, or pass them as arguments.
        final String modelPath = args.length > 0 ? args[0] : "src/main/resources/models/rmbg-2.0.onnx";
        final String inputImage = args.length > 1 ? args[1] : "src/main/resources/images/tray.png";
        final String outputImage = args.length > 2 ? args[2] : "src/main/resources/images/output.png";
        try {
            ImageBackgroundRemover remover = new ImageBackgroundRemover(modelPath);
            try {
                BufferedImage input = ImageIO.read(new File(inputImage));
                if (input == null) {
                    // ImageIO.read returns null when no reader supports the file's format.
                    System.err.println("Unsupported or unreadable image: " + inputImage);
                    return;
                }
                BufferedImage output = remover.removeBackground(input);
                ImageIO.write(output, "PNG", new File(outputImage));
            } finally {
                // Release the native session even if reading or inference failed.
                remover.close();
            }
            System.out.println("JDK========================背景移除成功!============================JDK");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the Java background-removal example (ImageBackgroundRemover + Main). -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>cn.v5cn.rebg</groupId>
<artifactId>v5cn-rebg</artifactId>
<version>1.0.0</version>
<properties>
<!-- Compile the sources for Java 11. -->
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
<dependencies>
<!-- ONNX Runtime Java API (the ai.onnxruntime.* classes used by ImageBackgroundRemover). -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.20.0</version>
</dependency>
<!-- OpenCV bindings: not imported by the JDK-only ImageBackgroundRemover/Main shown above;
     presumably kept for the OpenCV variant mentioned in the article (confirm before removing). -->
<dependency>
<groupId>org.openpnp</groupId>
<artifactId>opencv</artifactId>
<version>4.7.0-0</version>
</dependency>
</dependencies>
</project>
代码结构如下:
5. 总结
briaai/RMBG-1.4 可以借助 Transformers.js 在浏览器中通过 WebGPU 实现图片背景移除。https://images.batchtool.com/zh 这个站点就是使用 briaai/RMBG-1.4 完成图片处理的。
感谢阅读希望对您有所帮助。
评论区