当前位置:主页 > 资料 >

[译] 使用 TensorFlow 一步步进行目标检测
栏目分类:资料   发布日期:2018-08-02   浏览次数:

导读:本文为去找网小编(www.7zhao.net)为您推荐的[译] 使用 TensorFlow 一步步进行目标检测,希望对您有所帮助,谢谢! 将检查点模型(.ckpt)保存为.pb文件 回到TensorFlow目标检测文件夹,并将expor

本文为去找网小编(www.7zhao.net)为您推荐的[译] 使用 TensorFlow 一步步进行目标检测,希望对您有所帮助,谢谢! www.7zhao.net



将检查点模型(.ckpt)保存为.pb文件

回到TensorFlow目标检测文件夹,并将export_inference_graph.py文件复制到包含模型配置文件的文件夹中。

本文来自去找www.7zhao.net

python export_inference_graph.py --input_type image_tensor --pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model 

去找(www.7zhao.net欢迎您

这将创建一个新目录 fine_tuned_model ,里面名为 frozen_inference_graph.pb 的模型就是您训练出来的模型。 copyright www.7zhao.net

在项目中使用模型

我在本教程中一直在研究的项目是创建一个红绿灯分类器。在Python中,我将此分类器实现为一个类。 在类的初始化部分,我创建了一个TensorFlow会话,这样就不需要在每次需要分类时创建它。

copyright www.7zhao.net

class TrafficLightClassifier(object):
    def __init__(self):
        PATH_TO_MODEL = 'frozen_inference_graph.pb'
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            # Works up to here.
            with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
        self.sess = tf.Session(graph=self.detection_graph) 

www.7zhao.net

在该类中,我创建了一个函数,该函数对图像进行分类,并返回图像中分类的边界框、分数和类别。

本文来自去找www.7zhao.net

def get_classification(self, img):
    # Bounding Box Detection.
    with self.detection_graph.as_default():
        # Expand dimension since the model expects image to have shape [1, None, None, 3].
        img_expanded = np.expand_dims(img, axis=0)  
        (boxes, scores, classes, num) = self.sess.run(
            [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
            feed_dict={self.image_tensor: img_expanded})
    return boxes, scores, classes, num 内容来自www.7zhao.net 

此时,您需要过滤低于指定分数阈值的结果。结果会自动从最高分数到最低分数排序,因此这很容易实现。通过上面的函数返回分类结果,就是这样,您做到了! copyright www.7zhao.net

您可以在下图中看到我实现的红绿灯分类器。 本文来自去找www.7zhao.net

www.7zhao.net

我最初创建本教程是因为我很难找到有关如何使用Object Detection API的资讯。我希望通过阅读本教程,您可以启动项目,让项目快速实现,这样您可以将更多时间集中在您真正感兴趣的内容上!

copyright www.7zhao.net

www.7zhao.net


本文原文地址:https://mp.weixin.qq.com/s/qOeyocU2Yz7zZrpcerCTew

以上为[译] 使用 TensorFlow 一步步进行目标检测文章的全部内容,若您也有好的文章,欢迎与我们分享!

欢迎访问www.7zhao.net

Copyright ©2008-2017去找网版权所有   皖ICP备12002049号-2 皖公网安备 34088102000435号   关于我们|联系我们| 免责声明|友情链接|网站地图|手机版