Skip to content

Commit 301a3d9

Browse files
committed
第一次提交
0 parents  commit 301a3d9

16 files changed

+1637
-0
lines changed

.gitignore

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Created by .ignore support plugin (hsz.mobi)
2+
### Example user template template
3+
### Example user template
4+
5+
# IntelliJ project files
6+
.idea
7+
*.iml
8+
out
9+
gen
### Python template
10+
# Byte-compiled / optimized / DLL files
11+
__pycache__/
12+
*.py[cod]
13+
*$py.class
14+
15+
# C extensions
16+
*.so
17+
18+
# Distribution / packaging
19+
.Python
20+
build/
21+
develop-eggs/
22+
dist/
23+
downloads/
24+
eggs/
25+
.eggs/
26+
lib/
27+
lib64/
28+
parts/
29+
sdist/
30+
var/
31+
wheels/
32+
*.egg-info/
33+
.installed.cfg
34+
*.egg
35+
MANIFEST
36+
37+
# PyInstaller
38+
# Usually these files are written by a python script from a template
39+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
40+
*.manifest
41+
*.spec
42+
43+
# Installer logs
44+
pip-log.txt
45+
pip-delete-this-directory.txt
46+
47+
# Unit test / coverage reports
48+
htmlcov/
49+
.tox/
50+
.coverage
51+
.coverage.*
52+
.cache
53+
nosetests.xml
54+
coverage.xml
55+
*.cover
56+
.hypothesis/
57+
.pytest_cache/
58+
59+
# Translations
60+
*.mo
61+
*.pot
62+
63+
# Django stuff:
64+
*.log
65+
local_settings.py
66+
db.sqlite3
67+
68+
# Flask stuff:
69+
instance/
70+
.webassets-cache
71+
72+
# Scrapy stuff:
73+
.scrapy
74+
75+
# Sphinx documentation
76+
docs/_build/
77+
78+
# PyBuilder
79+
target/
80+
81+
# Jupyter Notebook
82+
.ipynb_checkpoints
83+
84+
# pyenv
85+
.python-version
86+
87+
# celery beat schedule file
88+
celerybeat-schedule
89+
90+
# SageMath parsed files
91+
*.sage.py
92+
93+
# Environments
94+
.env
95+
.venv
96+
env/
97+
venv/
98+
ENV/
99+
env.bak/
100+
venv.bak/
101+
102+
# Spyder project settings
103+
.spyderproject
104+
.spyproject
105+
106+
# Rope project settings
107+
.ropeproject
108+
109+
# mkdocs documentation
110+
/site
111+
112+
# mypy
113+
.mypy_cache/
114+
/data

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2019 OpenSourceAI
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# 简介
2+
本项目是对[yolov3的tensorflow实现](https://github.com/YunYang1994/tensorflow-yolov3)项目的"整合"吧,做了一些细微的修改,添加大量的中文注释,帮助进行快速阅读理解. 基础好的可以直接阅读原代码.
3+
4+
[yolov3的tensorflow实现](https://github.com/YunYang1994/tensorflow-yolov3)这个项目,应该是作为菜鸟的我到目前为止在原理和代码实现上最复杂的深度学习项目了. 项目代码量大,shape变换,维度广播,看着看着一不小心就迷失了,反反复复的看了好几遍,感觉才把整个项目代码的逻辑给拉通,整个过程反复调试,计算维度变换,运算的处理过程,总之收获巨大.
5+
6+
欢迎交流,指出错误等.
7+
# 开箱即用
8+
9+
下载[data]()
10+
11+
```
12+
$ python video_dome.py # 默认使用0摄像头, 也可以通过局域网调用手机摄像头
13+
```
14+
![截图]()
15+
# 学习
16+
17+
通过快速训练[quick_train.py]()开始,阅读项目代码开始学习yolov3的细节. 在之前
18+
- 下载[data](),使用浣熊数据集
19+
20+
![]()|![]()
21+
- [pic_vis.py] 可视化数据
22+
- 使用[core.convert_tfrecord.py](),转换为tfrecord文件
23+
- [show_image_from_tfrecord.py](),检查文件是否正常
24+
- [quick_train.py]()开始训练调试
25+
- [show_train_result.py]() 检测所训练的模型效果.
26+
27+
# 使用其他数据集进行训练
28+
待更新....
29+
30+
>https://github.com/YunYang1994/tensorflow-yolov3

core/common.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import tensorflow as tf
2+
3+
# 构建模型的基本组件
4+
slim = tf.contrib.slim
5+
6+
7+
def _conv2d_fixed_padding(inputs, filters, kernel_size, strides=1):
    """Run ``slim.conv2d`` with padding chosen from the stride.

    When ``strides`` is greater than 1 the input is first padded explicitly
    with ``_fixed_padding`` (an amount that depends only on ``kernel_size``).
    The convolution uses 'SAME' padding when ``strides == 1`` and 'VALID'
    padding for every other stride value.

    Args:
        inputs: Input tensor handed to ``slim.conv2d``.
        filters: Number of output filters, passed through to ``slim.conv2d``.
        kernel_size: Kernel size, used for both the explicit padding and the
            convolution.
        strides: Convolution stride; defaults to 1.

    Returns:
        The output of ``slim.conv2d``.
    """
    if strides > 1:
        # Pad by hand so the kernel fully covers the borders before striding.
        inputs = _fixed_padding(inputs, kernel_size)

    padding_mode = 'SAME' if strides == 1 else 'VALID'
    return slim.conv2d(inputs, filters, kernel_size,
                       stride=strides, padding=padding_mode)
12+
13+
14+
@tf.contrib.framework.add_arg_scope
def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
    """Pad axes 1 and 2 of ``inputs`` by an amount fixed by the kernel size.

    The amount of padding depends only on ``kernel_size``, never on the size
    of the input. A total of ``kernel_size - 1`` elements is added along
    axis 1 and along axis 2; the total is split between the beginning and the
    end, with the end receiving the extra element when the total is odd.

    Args:
        inputs: A 4-D tensor. Axes 0 and 3 are left unpadded (presumably the
            layout is [batch, height, width, channels] -- confirm with
            callers).
        kernel_size: Kernel size of the conv2d or max_pool2d operation that
            follows. Should be a positive integer.
        *args: Accepted but not used by this function.
        mode: The mode forwarded to ``tf.pad``.
        **kwargs: Accepted but not used by this function.

    Returns:
        The padded tensor; the data is left intact when ``kernel_size == 1``.
    """
    # Padding so that the kernel can travel completely over the borders.
    total = kernel_size - 1
    before = total // 2
    after = total - before

    paddings = [
        [0, 0],
        [before, after],
        [before, after],
        [0, 0],
    ]
    return tf.pad(inputs, paddings, mode=mode)

core/convert_tfrecord.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import sys
2+
import argparse
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
7+
# 将训练图片转换为tfrecord文件
8+
9+
def main(argv):
    """Convert the images listed in an annotation file into one TFRecord file.

    Every non-blank line of the annotation file is expected to look like
    ``image_path v1 v2 v3 v4 v5 [v1 v2 v3 v4 v5 ...]``: an image path followed
    by whitespace-separated groups of five numbers, one group per box
    (presumably four box coordinates plus a class id -- confirm against the
    code that produces the file). Trailing values that do not fill a complete
    group of five are ignored. If the same image path appears on several
    lines, the last line wins.

    For each image one ``tf.train.Example`` is written with two bytes
    features: ``image`` (the raw file content) and ``boxes`` (the float32
    ``[boxes_num, 5]`` array serialized to bytes).

    The output file name is ``<tfrecord_path_prefix>_<suffix>.tfrecords``
    where ``<suffix>`` is the part of ``--dataset_txt`` after its last
    underscore, cut at the first dot.

    Args:
        argv: Command line arguments without the program name, e.g.
            ``sys.argv[1:]``. Supports ``--dataset_txt`` and
            ``--tfrecord_path_prefix``.
    """
    parser = argparse.ArgumentParser()
    # Annotation file: image path followed by boxes / class ids.
    parser.add_argument("--dataset_txt", default='../data/train_dome_data/new_test.txt')
    parser.add_argument("--tfrecord_path_prefix",
                        default='../data/train_dome_data/images')
    # Parse the arguments that were actually handed to main().
    flags = parser.parse_args(argv)

    dataset = {}
    with open(flags.dataset_txt, 'r') as f:
        for line in f:
            # split() without a separator drops the trailing newline and
            # tolerates repeated blanks; blank lines yield no fields.
            fields = line.split()
            if not fields:
                continue
            image_path, values = fields[0], fields[1:]
            boxes_num = len(values) // 5  # number of complete boxes
            boxes = np.asarray(values[:boxes_num * 5], dtype=np.float32)
            dataset[image_path] = boxes.reshape(boxes_num, 5)

    images_num = len(dataset)
    print(">> Processing %d images" % images_num)

    tfrecord_file = flags.tfrecord_path_prefix + "_" + flags.dataset_txt.split("_")[-1].split(".")[0] + ".tfrecords"
    with tf.python_io.TFRecordWriter(tfrecord_file) as record_writer:
        for index, (image_path, boxes) in enumerate(dataset.items(), start=1):
            # Read the raw image bytes and close the handle right away.
            with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                image = image_file.read()
            # tobytes() is the non-deprecated spelling of tostring().
            boxes_raw = boxes.tobytes()

            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    'boxes': tf.train.Feature(bytes_list=tf.train.BytesList(value=[boxes_raw])),
                }
            ))
            sys.stdout.write("\r>> %d / %d" % (index, images_num))
            sys.stdout.flush()
            record_writer.write(example.SerializeToString())
    print(">> Saving %d images in %s" % (images_num, tfrecord_file))


if __name__ == "__main__":
    main(sys.argv[1:])

0 commit comments

Comments
 (0)