感谢电子发烧友和爱芯元智公司提供的测试机会。
前面介绍了分割图像的SAM框架在爱芯派 Pro (AXera-Pi Pro)开发板的测试结果,今天来展示一下对SAM程序的修改,使它成为一个交互式抠图软件。
SAM的编译
我们先介绍一下如何在开发板上编译原始的SAM程序原始代码由爱芯官方开源于 GITHUB:https://github.com/AXERA-TECH/SAM-ONNX-AX650-CPP。
编译前,需要先安装一些相关的软件:
apt update
apt install build-essential
apt install cmake
如果希望使用Qt6开发用户界面,还需要安装以下软件:
apt install qt6-base-dev qtcreator
将程序代码解压,假定解压到/root/Desktop/SAM-ONNX-AX650-CPP-1.1目录。设置如下环境变量:
export onnxruntime_dir=/root/Desktop/SAM-ONNX-AX650-CPP-1.1/third_party/onnxruntime-aarch64-none-gnu-1.16.0/
export opencv_cmake_file_dir=/root/Desktop/SAM-ONNX-AX650-CPP-1.1/third_party/libopencv-4.6-aarch64-none/lib/cmake/opencv4/
export msp_out_dir=/soc/
如果只想编译无界面版本,先进入AXERA-TECH/SAM-ONNX-AX650-CPP目录,然后执行以下步骤:
mkdir build
cd build
cmake -DONNXRUNTIME_DIR=${onnxruntime_dir} -DOpenCV_DIR=${opencv_cmake_file_dir} -DBSP_MSP_DIR=${msp_out_dir} -DBUILD_WITH_AX650=ON ..
make -j4
下面是编译成功的截图。
如果只想编译Qt6版本,先进入AXERA-TECH/SAM-ONNX-AX650-CPP目录,然后执行以下步骤:
cd qtproj
mkdir build
cd build
cmake -DONNXRUNTIME_DIR=${onnxruntime_dir} -DOpenCV_DIR=${opencv_cmake_file_dir} -DBSP_MSP_DIR=${msp_out_dir} -DBUILD_WITH_AX650=ON ../SAMQT/
make -j4
CMake在生成Makefile的过程中会有一个找不到Qt的警告,可以忽略。下面是编译成功的截图。
SAM程序的工作原理
SAM-ONNX-AX650-CPP-1.1/src目录下是SAM和LaMa的核心代码,它们会被编译成为libsam.a库文件,这部分不是我们要修改的。
目录qtprj/SAMQT下的代码是Qt6部分,我们只需要在其中做一些修改即可。
其中mainwindow.ui是用户界面文件,可以使用QtCreator进行修改。文件mainwindow.cpp是这个界面所对应的响应代码,其中的关于SAM的代码很少,主要是Qt的各种事件处理。
程序的核心代码其实是myqlabel.h。这里面包括对SAM的调用,我们需要分析一下这部分代码并进行修改。
paintEvent函数主要用于绘制图片,包括和对用户选择框和选择对象的处理。
samDecode函数调用libsam库进行解码操作,将SAM识别出来的对象中可信度最高的掩码都加入掩码向量中。该函数是在mouseReleaseEvent事件中被调用。
ShowRemoveObject函数调用LamaInpaint对象,对指定掩码区域进行图像填充。我们重点要修改这个函数。
抠图软件的实现
抠图软件的的工作原理就是只复制掩码区域的图像,而不复制掩码区域以外的图像。对此我们可以借助于OpenCV的bitwise_and函数。该函数的作用是将两幅图像进行按位与操作。函数的定义如下:
bitwise_and(InputArray src1, InputArray src2,OutputArray dst, InputArray mask=noArray()); //dst = src1 & src2
前两个参数是参与运算的图像,第3个参数是保存结果的图像,最后一个参数是掩码。我们可以先建立一个全白色的图像和原始图像进行与操作,此处的掩码就可以采用SAM输出的掩码图像。利用我们对足球场的图像进行选取其中两名运动员的操作,然后执行抠图操作就可以得到去除背景的图片。
对.ui的修改可以借助Qt Creator完成,主要是修改了窗口和按钮的标题。重点是修改了ShowRemoveObject函数,这里给出完整的代码。
void ShowRemoveObject(int dilate_size, QProgressBar *bar, bool remove_mask_by_merge = true)
{
if (!cur_image.bits() || !grab_masks.size())
{
return;
}
int channel = cur_image.format() == QImage::Format_BGR888 ? 3 : 4;
int stride = cur_image.bytesPerLine();
cv::Mat src(cur_image.height(), cur_image.width(), CV_8UC(channel), cur_image.bits(), stride);
cv::Mat rgb;
if (channel == 3)
src.copyTo(rgb);
else if (channel == 4)
cv::cvtColor(src, rgb, cv::COLOR_RGBA2RGB);
pt_img_first = QPoint(-10000, -10000);
pt_img_secend = QPoint(-10000, -10000);
cv::Mat cropped(rgb.size(), CV_8UC(3));
cropped.setTo(cv::Scalar(255, 255, 255));
if (bar)
{
bar->setValue(0);
bar->setMinimum(0);
bar->setMaximum(grab_masks.size());
}
if (remove_mask_by_merge)
{
if (grab_masks.size())
{
auto base_mask = grab_masks[0];
if (bar)
bar->setValue(bar->value() + 1);
for (size_t i = 1; i < grab_masks.size(); i++)
{
base_mask |= grab_masks[i];
if (bar)
bar->setValue(bar->value() + 1);
}
auto time_start = std::chrono::high_resolution_clock::now();
cv::bitwise_and(cropped, rgb, cropped, base_mask);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Crop Cost time : " << diff.count() << "s" << std::endl;
QImage qcropped(cropped.data, cropped.cols, cropped.rows, cropped.step1(), QImage::Format_BGR888);
cur_image = qcropped.copy();
cur_masks.clear();
repaint();
}
}
else
{
cv::Mat r;
for (auto grab_mask : grab_masks)
{
auto time_start = std::chrono::high_resolution_clock::now();
cv::bitwise_and(cropped, rgb, cropped, grab_mask);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Crop Cost time : " << diff.count() << "s" << std::endl;
QImage qcropped(cropped.data, cropped.cols, cropped.rows, cropped.step1(), QImage::Format_BGR888);
cur_image = qcropped.copy();
if (cur_masks.size())
cur_masks.removeFirst();
repaint();
if (bar)
bar->setValue(bar->value() + 1);
}
}
cur_masks.clear();
rgba_masks.clear();
grab_masks.clear();
mSam.Encode(cropped);
}
最后的测试结果见下面的视频。更完整的版本请参考我的B站视频。