
Building a Large-Model App for My Wife That Removes Image Backgrounds

平常心
2024-12-05

1. Model

Many large models can remove image backgrounds. The one used here is briaai/RMBG-2.0, which performs considerably better than its predecessor, briaai/RMBG-1.4.

Model download: https://huggingface.co/briaai/RMBG-2.0/resolve/main/onnx/model.onnx?download=true

2. Which language to use

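Python is the usual first choice for large-model development: compared with other languages, it has far more libraries and tutorials. Here we use the model in ONNX format, which widens the choice of languages, because ONNX Runtime also supports JavaScript, Java, C#, and C++. We implement the program in both Python and Java.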

3. Python

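The Python version comes in several variants: some use the ONNX model and some use other model formats, and it uses gradio to build a UI and is packaged as an .exe. Only the ONNX code is shown below. If you need the other parts of the code, leave a comment.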

import numpy as np
import onnxruntime
import onnx
import os
from PIL import Image
onnxruntime.set_default_logger_severity(3)
ONNX_DEVICE = onnxruntime.get_device()
ONNX_PROVIDER = (
    "CUDAExecutionProvider" if ONNX_DEVICE == "GPU" else "CPUExecutionProvider"
)

def load_onnx_model(checkpoint_path, set_cpu=False):
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if ONNX_PROVIDER == "CUDAExecutionProvider"
        else ["CPUExecutionProvider"]
    )

    if set_cpu:
        sess = onnxruntime.InferenceSession(
            checkpoint_path, providers=["CPUExecutionProvider"]
        )
        model = onnx.load(checkpoint_path)
        print(f"Model IR version: {model.ir_version}")
    else:
        try:
            sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
        except Exception as e:
            # ONNX_DEVICE is "GPU"/"CPU"; the provider name lives in ONNX_PROVIDER
            if ONNX_PROVIDER == "CUDAExecutionProvider":
                print(f"Failed to load model with CUDAExecutionProvider: {e}")
                print("Falling back to CPUExecutionProvider")
                # Retry loading the model on the CPU
                sess = onnxruntime.InferenceSession(
                    checkpoint_path, providers=["CPUExecutionProvider"]
                )
            else:
                raise  # Loading on the CPU failed, so re-raise the original exception

    return sess

def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):

    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        return None

    def resize_rmbg_image(image):
        image = image.convert("RGB")
        model_input_size = (ref_size, ref_size)
        image = image.resize(model_input_size, Image.BILINEAR)
        return image

    RMBG_SESS = load_onnx_model(checkpoint_path, set_cpu=True)

    orig_image = Image.fromarray(input_image)
    image = resize_rmbg_image(orig_image)
    im_np = np.array(image).astype(np.float32)
    im_np = im_np.transpose(2, 0, 1)  # Change to CxHxW format
    im_np = np.expand_dims(im_np, axis=0)  # Add batch dimension
    im_np = im_np / 255.0  # Normalize to [0, 1]
    im_np = (im_np - 0.5) / 0.5  # Normalize to [-1, 1]

    # Inference
    result = RMBG_SESS.run(None, {RMBG_SESS.get_inputs()[0].name: im_np})[0]

    # Post process
    result = np.squeeze(result)
    ma = np.max(result)
    mi = np.min(result)
    # Normalize to [0, 1]; the max() guard avoids dividing by zero on a constant mask
    result = (result - mi) / max(ma - mi, 1e-8)

    # Convert to PIL image
    im_array = (result * 255).astype(np.uint8)
    pil_im = Image.fromarray(
        im_array, mode="L"
    )  # Ensure mask is single channel (L mode)

    # Resize the mask to match the original image size
    pil_im = pil_im.resize(orig_image.size, Image.BILINEAR)

    # Paste the mask on the original image
    new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
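    new_im.save("output1.png")
    print("Done")

    return new_im


def read_modnet_image(input_image):
    # Convert to RGBA so palette ("P" mode) and grayscale PNGs become real colors
    # instead of raw palette indices, while any transparency is kept
    im = Image.open(input_image).convert("RGBA")

    return np.array(im)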

if __name__ == '__main__':
    # Input image
    input_image = r"D:\users\Desktop\v5notes\tray.png"
    processing_image = read_modnet_image(input_image)
    # Run the model on the image
    get_rmbg_matting(processing_image, r'.\models\model.onnx')
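To make the preprocessing and post-processing in get_rmbg_matting concrete, here is a stdlib-only sketch of the same arithmetic on a tiny image. It is an illustration, not the author's code: the helper names to_chw_normalized and minmax_normalize are hypothetical, and plain lists stand in for NumPy arrays. Element (c, h, w) lands at flat index c*H*W + h*W + w, which is the layout that transpose(2, 0, 1) produces and also the order in which the Java preprocessImage loop fills its float array.

```python
from typing import List, Tuple

# Hypothetical helpers illustrating the math in get_rmbg_matting (not the author's code).
Pixel = Tuple[int, int, int]  # (R, G, B), each in 0..255


def to_chw_normalized(image: List[List[Pixel]]) -> List[float]:
    """Flatten an H x W RGB image in CHW order and map 0..255 to [-1, 1].

    Element (c, h, w) ends up at index c*H*W + h*W + w.
    """
    height, width = len(image), len(image[0])
    out: List[float] = []
    for c in range(3):  # one full plane per channel: R, then G, then B
        for h in range(height):
            for w in range(width):
                value = image[h][w][c] / 255.0   # scale to [0, 1]
                out.append((value - 0.5) / 0.5)  # shift to [-1, 1]
    return out


def minmax_normalize(mask: List[float], eps: float = 1e-8) -> List[float]:
    """Rescale raw model output to [0, 1]; eps avoids dividing by zero."""
    lo, hi = min(mask), max(mask)
    span = max(hi - lo, eps)  # a constant mask maps to all zeros
    return [(v - lo) / span for v in mask]


if __name__ == "__main__":
    tiny = [[(0, 128, 255), (255, 0, 0)]]  # H=1, W=2
    print(to_chw_normalized(tiny))  # R plane, then G plane, then B plane
    print(minmax_normalize([-2.0, 0.0, 2.0]))  # [0.0, 0.5, 1.0]
```

In the real pipeline NumPy does this work vectorized; the sketch only pins down the ordering and value ranges. Note that min-max normalization rescales each mask relative to itself, so the model's absolute confidence is not preserved.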

4. Java version

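There are two Java versions, one that uses OpenCV and one that does not; the code below is the version without OpenCV. If you need the other parts of the code, leave a comment.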

ImageBackgroundRemover.java

package cn.v5cn.rebg.jdk;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.awt.image.BufferedImage;
import java.nio.FloatBuffer;
import java.util.Collections;

/**
 * @author ZYW
 */
public class ImageBackgroundRemover {
    private static final int INPUT_SIZE = 1024;
    private final OrtEnvironment env;
    private final OrtSession session;

    public ImageBackgroundRemover(String modelPath) throws OrtException {
        // Initialize the ONNX Runtime environment
        env = OrtEnvironment.getEnvironment();
        session = env.createSession(modelPath);
    }

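    public BufferedImage removeBackground(BufferedImage originalImage) throws OrtException {
        // Preprocess the image
        float[] inputData = preprocessImage(originalImage);

        // Read the input name from the model instead of hard-coding it
        String inputName = session.getInputNames().iterator().next();

        // Create the input tensor and run inference; both hold native memory, so close them
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env,
                FloatBuffer.wrap(inputData),
                new long[]{1, 3, INPUT_SIZE, INPUT_SIZE});
             OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor))) {

            // Assumes an output of shape [1, 1, 1024, 1024]
            float[][][][] output = (float[][][][]) result.get(0).getValue();
            // Extract the 2D mask and track its range
            float[][] mask = new float[INPUT_SIZE][INPUT_SIZE];
            float min = Float.MAX_VALUE;
            float max = -Float.MAX_VALUE;
            for (int i = 0; i < INPUT_SIZE; i++) {
                for (int j = 0; j < INPUT_SIZE; j++) {
                    float v = output[0][0][i][j];
                    mask[i][j] = v;
                    min = Math.min(min, v);
                    max = Math.max(max, v);
                }
            }
            // Min-max normalize to [0, 1], as the Python version does
            float range = Math.max(max - min, 1e-8f);
            for (int i = 0; i < INPUT_SIZE; i++) {
                for (int j = 0; j < INPUT_SIZE; j++) {
                    mask[i][j] = (mask[i][j] - min) / range;
                }
            }

            // Build the alpha mask and apply it to the original image
            return applyMaskToImage(originalImage, mask);
        }
    }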

    private float[] preprocessImage(BufferedImage image) {
        // Resize the image to 1024x1024
        BufferedImage resized = new BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_INT_RGB);
        java.awt.Graphics g = resized.getGraphics();
        g.drawImage(image, 0, 0, INPUT_SIZE, INPUT_SIZE, null);
        g.dispose();
      
        float[] inputData = new float[3 * INPUT_SIZE * INPUT_SIZE];
        int idx = 0;
      
        // Normalize to [-1, 1] and write in CHW order (R plane, then G, then B)
        for (int c = 0; c < 3; c++) {
            for (int h = 0; h < INPUT_SIZE; h++) {
                for (int w = 0; w < INPUT_SIZE; w++) {
                    int pixel = resized.getRGB(w, h);
                    float value = ((pixel >> (16 - c * 8)) & 0xFF) / 255.0f;
                    inputData[idx++] = (value - 0.5f) / 0.5f;
                }
            }
        }
      
        return inputData;
    }

    private BufferedImage applyMaskToImage(BufferedImage original, float[][] mask) {
        int width = original.getWidth();
        int height = original.getHeight();
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
      
        // Resize the mask to the original image size
        float[][] resizedMask = resizeMask(mask, width, height);

        // Apply the mask as the alpha channel
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = original.getRGB(x, y);
                // Clamp to 0..255 so the value cannot spill outside the alpha byte
                int alpha = Math.max(0, Math.min(255, Math.round(resizedMask[y][x] * 255)));
                int newPixel = (alpha << 24) | (rgb & 0x00FFFFFF);
                result.setRGB(x, y, newPixel);
            }
        }
      
        return result;
    }

    private float[][] resizeMask(float[][] mask, int targetWidth, int targetHeight) {
        float[][] resized = new float[targetHeight][targetWidth];
        // Simple bilinear interpolation
        for (int y = 0; y < targetHeight; y++) {
            for (int x = 0; x < targetWidth; x++) {
                float srcX = x * (float)INPUT_SIZE / targetWidth;
                float srcY = y * (float)INPUT_SIZE / targetHeight;
                resized[y][x] = bilinearInterpolate(mask, srcX, srcY);
            }
        }
        return resized;
    }

    private float bilinearInterpolate(float[][] mask, float x, float y) {
        int x1 = (int) Math.floor(x);
        int x2 = Math.min(x1 + 1, INPUT_SIZE - 1);
        int y1 = (int) Math.floor(y);
        int y2 = Math.min(y1 + 1, INPUT_SIZE - 1);
      
        float dx = x - x1;
        float dy = y - y1;
      
        float q11 = mask[y1][x1];
        float q21 = mask[y1][x2];
        float q12 = mask[y2][x1];
        float q22 = mask[y2][x2];
      
        return (1 - dx) * (1 - dy) * q11 + dx * (1 - dy) * q21 + 
               (1 - dx) * dy * q12 + dx * dy * q22;
    }

    public void close() throws OrtException {
        if (session != null) {
            session.close();
        }
        if (env != null) {
            env.close();
        }
    }
} 
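resizeMask and bilinearInterpolate above upsample the 1024x1024 mask back to the original resolution. The following Python sketch mirrors that Java logic (same source-coordinate mapping, same clamping of the right and bottom neighbours) so it can be checked on a tiny mask. resize_mask and bilinear_sample are hypothetical names, and unlike the Java version the sketch also accepts a non-square source mask.

```python
from typing import List

Mask = List[List[float]]


def bilinear_sample(mask: Mask, x: float, y: float) -> float:
    """Sample mask at fractional (x, y), clamping the right/bottom neighbours."""
    src_h, src_w = len(mask), len(mask[0])
    x1, y1 = int(x), int(y)  # floor, since x and y are never negative here
    x2 = min(x1 + 1, src_w - 1)
    y2 = min(y1 + 1, src_h - 1)
    dx, dy = x - x1, y - y1
    q11, q21 = mask[y1][x1], mask[y1][x2]
    q12, q22 = mask[y2][x1], mask[y2][x2]
    return ((1 - dx) * (1 - dy) * q11 + dx * (1 - dy) * q21
            + (1 - dx) * dy * q12 + dx * dy * q22)


def resize_mask(mask: Mask, target_w: int, target_h: int) -> Mask:
    """Resize with the same corner-aligned mapping as the Java resizeMask."""
    src_h, src_w = len(mask), len(mask[0])
    return [
        [bilinear_sample(mask, x * src_w / target_w, y * src_h / target_h)
         for x in range(target_w)]
        for y in range(target_h)
    ]


if __name__ == "__main__":
    # A 2x2 mask: left column is background, right column is foreground
    small = [[0.0, 1.0], [0.0, 1.0]]
    for row in resize_mask(small, 4, 2):
        print(row)  # [0.0, 0.5, 1.0, 1.0]
```

The mapping x * src_w / target_w aligns the top-left corners of the two grids; many image libraries instead sample at pixel centers, (x + 0.5) * scale - 0.5. The two conventions differ only by a fraction of a pixel.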

Main.java

package cn.v5cn.rebg.jdk;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;

public class Main {
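    public static void main(String[] args) {
        final String modelPath = "src/main/resources/models/rmbg-2.0.onnx";  // Replace with the actual model path
        final String inputImage = "src/main/resources/images/tray.png";     // Replace with the input image path
        final String outputImage = "src/main/resources/images/output.png";   // Replace with the output image path

        ImageBackgroundRemover remover = null;
        try {
            // Initialize the background remover
            remover = new ImageBackgroundRemover(modelPath);
            // Read the input image (ImageIO.read returns null for unsupported formats)
            BufferedImage input = ImageIO.read(new File(inputImage));
            if (input == null) {
                throw new IllegalArgumentException("Unsupported or unreadable image: " + inputImage);
            }
            // Remove the background
            BufferedImage output = remover.removeBackground(input);
            // Save the result as PNG so the alpha channel is kept
            ImageIO.write(output, "PNG", new File(outputImage));

            System.out.println("JDK========================Background removed successfully!============================JDK");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Release ONNX Runtime resources even if an error occurred
            if (remover != null) {
                try {
                    remover.close();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }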
} 
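applyMaskToImage writes each output pixel as (alpha << 24) | (rgb & 0x00FFFFFF). If alpha is not kept within 0..255, Java's 32-bit int silently drops the high bits, so the mask value should be brought into [0, 1] and the alpha clamped before packing. The Python sketch below reproduces that arithmetic; pack_argb, pack_argb_java_unclamped, and alpha_of are hypothetical helpers, and the & 0xFFFFFFFF emulates Java's 32-bit wrap-around.

```python
def pack_argb(rgb: int, mask_value: float) -> int:
    """Put mask_value (expected in [0, 1]) into the alpha byte of an ARGB pixel."""
    alpha = max(0, min(255, round(mask_value * 255)))  # clamp to one byte
    return (alpha << 24) | (rgb & 0x00FFFFFF)


def pack_argb_java_unclamped(rgb: int, mask_value: float) -> int:
    """Emulate the unclamped Java expression with 32-bit int wrap-around."""
    alpha = int(mask_value * 255)  # like Java's (int) cast: truncates toward zero
    return ((alpha << 24) | (rgb & 0x00FFFFFF)) & 0xFFFFFFFF


def alpha_of(argb: int) -> int:
    """Extract the alpha byte of a 0xAARRGGBB pixel."""
    return (argb >> 24) & 0xFF


if __name__ == "__main__":
    rgb = 0x123456
    print(alpha_of(pack_argb(rgb, 1.5)))                  # 255: clamped
    print(alpha_of(pack_argb_java_unclamped(rgb, 1.5)))   # 126: 382 wrapped around
    print(alpha_of(pack_argb_java_unclamped(rgb, -0.5)))  # 129: background turns half-opaque
```

Replacing only the alpha byte while keeping the RGB bits is what lets the PNG written by Main keep the original colors while the background becomes transparent.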

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>cn.v5cn.rebg</groupId>
    <artifactId>v5cn-rebg</artifactId>
    <version>1.0.0</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.20.0</version>
        </dependency>
        <dependency>
            <groupId>org.openpnp</groupId>
            <artifactId>opencv</artifactId>
            <version>4.7.0-0</version>
        </dependency>
    </dependencies>
</project> 

The project structure is shown below:

[Figure: project code structure]

5. Summary

briaai/RMBG-1.4 can also remove image backgrounds directly in the browser, using Transformers.js with WebGPU.

The site https://images.batchtool.com/zh uses briaai/RMBG-1.4 for its image processing.

Thanks for reading; I hope this helps.
