LongLong's Blog

分享IT技术,分享生活感悟,热爱摄影,热爱航天。

使用Tensorflow检测自定义物体

随着深度学习方法的不断发展,越来越多的问题可以借助深度学习的模型来处理。对于一些物体检测特别是人脸、车辆等的检测,使用OpenCV等图像处理工具已经可以很好的处理,而对于自定义的物体,如有些标记,物体的特征等,需要利用自行训练神经网络来进行处理,儿在Tensorflow中给出了一个进行物体检测模型,可以用来训练自己的神经网络。关于该模型使用的文章和教程虽然已经比较多,但在使用的过程中还是遇到了不少问题,这里简单总结。

1. Tensorflow版本

虽然官方的文档中说可以支持Tensorflow2,但实际使用的过程中发现在运行到训练模型的时候,会因为模块tensorflow.contrib在Tensorflow2中被移除而在使用过程中出现错误,并且暂时没有查到可以兼容的方法,以下的代码让人比较无奈。为此暂时只能使用Tensorflow 1.14版本

try:
  from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
except:  # pylint: disable=bare-except
  pass

2. 手动标记数据

为了训练神经网络,需要首先手动标记需要识别的目标的数据,这里使用的工具是LabelImg。使用该工具打开需要进行标记的图片,将图面中需要进行识别的部分用方框选中,并添加对应的标签,每标记一张图并进行保存后会生成一个xml文件记录对应的标记信息。最后 将标记完成后的图片和对应的xml分别放入train和test两个目录中作为训练集和测试集。

3. 生成训练数据

之后需要将得到的数据转化为Tensorflow训练中所用的TFRecord格式,尝试了自带的脚本create_pascal_tf_record.py,但由于其没有详细说明具体使用方法,只是给了一个针对2012 PASCAL VOC数据的转换命令,没有个给出参数的具体含义,这里使用了一般的教程中给出的转换方法,即首先将标记的数据转换为csv文件,然后再将csv文件和图片一起转换为TFRecord格式(TFRecord文件中就已经包含了图片本身的信息)。涉及两个转换的脚本如下

#XML转csv
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET

def xml_to_csv(path):
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        print(xml_file)
        tree = ET.parse(xml_file)
        root = tree.getroot()
        i = 0;
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height',
                   'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df

def main():
    for directory in ['train', 'test']:
        image_path = os.path.join(os.getcwd(), 'images/{}'.format(directory))
        xml_df = xml_to_csv(image_path)
        print(image_path);
        print('data/{}_labels.csv'.format(directory));
        xml_df.to_csv('data/{}_labels.csv'.format(directory), index=None)
        print('Successfully converted xml to csv.')

if __name__ == '__main__':
    main()

其中需要注意将一个xml文件转为value数组时,需要根据实际xml的情况设定member对应的位置,这里使用的是4。转换的图片文件放在images/train和images/test目录下,运行完成后会在data目录下生成train_label.csv和test_label.csv。

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow.compat.v1 as tf;

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict

flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS

# The index of the first class has to be 1!
# Do not use 0. DO NOT!
def class_text_to_int(row_label):
#这里XXXXXX为标记图片时所使用的标签名称
    if row_label == 'XXXXXX':
        return 1
    else:
        return 0;

def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]

def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

def main(_):
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    #运行前将YYYYYY改为train或test
    path = os.path.join(os.getcwd(), 'images/YYYYY/')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))

if __name__ == '__main__':
    tf.app.run()

使用方法如下,运行后将得到train.record和test.record两个文件

python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record

4. 定义模型参数

模型参数可以参考ssd_mobilenet_v2_coco.config,并对其中一些部分做修改,因为这里只需要检测自定义的一种物体,为此将num_classes修改为1,另外将fine_tune_checkpoint、train_input_reader,eval_input_reader中的相关路径改为实际对应的路径。

另外可以根据计算机的性能适当调整batch_size的大小,更大的batch_size会导致内存占用的大幅上升,但能够提升计算的收敛速度。

5. 训练和导出模型

完成以上准备后可以开始运行训练过程,参数分别为训练使用的配置,训练模型保存的路径,训练的步数和进行效果预估的步数

python model_main.py \
    --pipeline_config_path=training/ssd_mobilenet_v2_coco.config \
    --model_dir=training \
    --num_train_steps=15000 \
    --num_eval_steps=2000 \
    --alsologtostderr

完成训练后需要对训练的模型进行导出,导出的命令如下,其中model.ckpt选择对应目录中后缀值最大的一个,导出后会得到模型文件result/frozen_inference_graph.pb

python3 export_inference_graph.py \
    --input_type=image_tensor \
    --pipeline_config_path=training/ssd_mobilenet_v2_coco.config \
    --trained_checkpoint_prefix=training/model.ckpt-15000 \
    --output_directory=result

6. 使用模型进行物体检测

在完成以上全部步骤后就可以开始对图片中的物体进行检测了,检测的脚本如下,检测的图片保存为in.jpg,运行后会输出进行标记的文件out.jpg

import numpy as np
import tensorflow as tf
from utils import label_map_util
from utils import visualization_utils as vis_util
import cv2

PATH_TO_CKPT = 'result/frozen_inference_graph.pb'
PATH_TO_LABELS = 'training/pascal_label_map.pbtxt';
NUM_CLASSES = 1

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        image_np = cv2.imread('in.jpg');
        image_np_expanded = np.expand_dims(image_np, axis=0)
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        print(scores);
        vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8)
        cv2.imwrite('out.jpg', image_np)

实现了一个检测图面中哆啦A梦的模型,其检测的效果如下,总体上还是可以的。

Android中使用蓝牙通信

蓝牙(Bluetooth)是一种无线数据和语音通信开放的全球规范,基于低成本的近距离无线连接,为固定和移动设备建立通信环境的一种特殊的近距离无线技术连接。相比WIFI,蓝牙的优势在于连接不需要借助路由器或AP设备,可通过简单配对过程实现设备之间的快速连接。

以下将通过将Android系统作为服务端实现一个基于蓝牙的Socket服务器,客户端通过蓝牙连接后与其实现Socket通信。

1. Android的Service

为了让蓝牙服务在启动后一直能够保持运行状态,需要使用Android中提供的Service,其能够在后台也保持运行状态。如实现的BluetoothService需要继承自Service类,并覆盖onBind和onStartCommand两个方法来实现服务的逻辑。

public class BluetoothService extends Service {
    public IBinder onBind(Intent intent) {
        return null;
    }
    public int onStartCommand(Intent intent, int flags, int startId) {
        return START_STICKY;
    }
}

在Activity中可以启动Service,启动的代码如下

startService(new Intent(getBaseContext(), BluetoothService.class));

2. 启动蓝牙

首先需要在App中设置访问设备蓝牙的权限,在AndroidManifest.xml中增加以下权限申请

    <uses-permission android:name="android.permission.BLUETOOTH"/>
    <uses-permission android:name="android.permission.BLUETOOTH_ADMIN"/>

使用BluetoothAdapter类中的getDefaultAdapter方法可以获得系统的蓝牙适配器,如果系统没有蓝牙设备将会返回null,需要对此进行判断防止发生空指针异常。正常获得系统的蓝牙适配器后调用enable方法开启系统的蓝牙功能

BluetoothAdapter adapter = BluetoothAdapter.getDefaultAdapter();
if (adapter == null) {
    return -1;
}
//这里可以判断是否已经开启,不过重复开启也不会导致程序出现错误
adapter.enable();

为了让其他设备能够发现此蓝牙服务,需要请求开启蓝牙的可发现性,使用申请提示待用户需要点击确认后,可启动蓝牙的可发现性一段时间(可设置,最大300秒)

Intent intent = new Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE);
//早期版本的Android,支持设置0时长期保持可发现状态
intent.putExtra(BluetoothAdapter.EXTRA_DISCOVERABLE_DURATION, 300);
startActivity(intent);

3. 发现其他蓝牙设备

调用蓝牙适配器的startDiscovery方法可以搜索附近的蓝牙设备,但从Android 6开始启动蓝牙搜索需要获取定位的权限,同时定位的权限是不支持通过AndroidManifest中进行预先申请的,需要使用申请提示待用户需要点击确认,在Activity中增加以下代码

ActivityCompat.requestPermissions(this,
        new String[]{Manifest.permission.ACCESS_COARSE_LOCATION, 
        Manifest.permission.ACCESS_FINE_LOCATION, 
        Manifest.permission.ACCESS_BACKGROUND_LOCATION}, 1);

startDiscovery方法为异步运行,若想要在发现设备后执行相应的操作,需要设置回调类

//过滤器,仅关注发现事件
IntentFilter filter = new IntentFilter();
filter.addAction(BluetoothDevice.ACTION_FOUND);
//注册回调类
registerReceiver(new BroadcastReceiver() {
    //实现回调函数,在发现设备后输出设备的名称或地址
    @Override
    public void onReceive(Context context, Intent intent) {
        BluetoothDevice device = intent.getParcelableExtra(BluetoothDevice.EXTRA_DEVICE);
        if (device != null) {
            //当没有设置设备名称是getName方法将会返回null
            String name = device.getName();
            Log.i("bluetooth", name != null ? name : device.toString());
        }
    }
}, filter);

4. 实现蓝牙的SocketServer

实现与普通的TCP SocketServer非常类似,首先通过listen方法得到一个Socket,然后循环调用accept方法等待连接,当连接成功后获得输入输出流实现接收和发送数据,并通过一个线程来启动Server

new Thread(new Runnable() {
    @Override
    public void run() {
        try {
            UUID uuid = UUID.fromString("b383ee98-fd44-4972-8e3e-48bf542b0a7b");
            //设置名称和uid来供客户端连接时进行识别
            BluetoothServerSocket socket = adapter.listenUsingRfcommWithServiceRecord("Rfc Server", uuid);
            Log.i("bluetooth", uuid.toString());
            Log.i("bluetooth", adapter.getName());
            Log.i("bluetooth", adapter.getAddress());
            while (true) {
                BluetoothSocket client = socket.accept();
                Log.i("bluetooth", "Accept");
                if (client != null) {
                    OutputStream os = client.getOutputStream();
                    PrintStream out = new PrintStream(os);
                    out.println("Hello!");
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}).start();

5. 通过RPC与蓝牙Service通信

可以在蓝牙Service中再启动一个HTTP Server,然后通过HTTP进行RPC,同时在接到RPC调用后需要使用建立的蓝牙连接则需要对蓝牙Server和HTTP Server之间对连接操作进行同步。大致方法就是将建立的蓝牙连接保存在一个Set容器中,将此容器做为两个线程的共享对象,并以此进行操作连接的同步。具体实现如下

//实例化共享的连接集合
final Set<BluetoothSocket> clients = new HashSet<BluetoothSocket>();

//在蓝牙服务线程中
while (true) {
    BluetoothSocket client = socket.accept();
    Log.i("bluetooth", "Accept");
    if (client != null) {
        //进行同步
        synchronized (clients) {
            clients.add(client);
            OutputStream os = client.getOutputStream();
            PrintStream out = new PrintStream(os);
            out.println("Hello from Bluetooth!");
        }
    }
}

//在HTTP服务中
AsyncHttpServer server = new AsyncHttpServer();
server.get("/", new HttpServerRequestCallback() {
    @Override
    public void onRequest(AsyncHttpServerRequest request, AsyncHttpServerResponse response) {
        response.send("Hello!!!");
        //进行同步
        synchronized (clients) {
            //对每个连接发送消息
            for (BluetoothSocket c : clients) {
                try {
                    OutputStream os = c.getOutputStream();
                    PrintStream out = new PrintStream(os);
                    out.println("Hello from Http!");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
});

6. 实现蓝牙的Client

这里使用Python实现了一个蓝牙的Client,使用pybluez库作为操作蓝牙的基础库,连接完成后接收数据,并关闭连接

import bluetooth;
sock = bluetooth.BluetoothSocket(bluetooth.RFCOMM);
host = 'D4:12:43:66:C4:A3';
port = 3;
sock.connect((host, port));
print(sock.recv(1024));
print(sock.recv(1024));
sock.close();

Client将会从Server收到一个Hello from Bluetooth!的字符串,并输出,同时通过HTTP调用HTTP Server,会再输出一个Hello from Http!

7. 总结

总体感觉Android的坑还是挺多的,各个版本之间的差异还是比较明显的,同时在相关的文档中又没有很清楚的说明,如果需要适配多个版本的Android系统的额外工作量还是比较大的,同时需要的各种访问权限需要在UI中手动确认,作为服务程序每次启动和重启都要手动点击确认确实还是不太方便。