Skip to content

Commit 4752e3a

Browse files
author
WangGuobao
committed
first commit
1 parent 75e8359 commit 4752e3a

6 files changed

+910
-0
lines changed

.gitignore

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
.ipynb_checkpoints
2+
training_checkpoints
3+
cifar-10-batches-py
4+
5+
*.npz
6+
*.png
7+
*.gif
8+
*.jpg
9+
*.gz
10+
*.zip
11+
*.rar

Keras_CNN_MNIST.ipynb

+247
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import os\n",
10+
"import sys\n",
11+
"import numpy as np\n",
12+
"import keras as K\n",
13+
"import pickle\n",
14+
"import tarfile\n",
15+
"from urllib.request import urlretrieve\n",
16+
"from keras.models import Sequential\n",
17+
"from keras.layers import Dense, Dropout, Flatten\n",
18+
"from keras.layers import Conv2D, MaxPooling2D, Dropout\n",
19+
"from sklearn.preprocessing import OneHotEncoder"
20+
]
21+
},
22+
{
23+
"cell_type": "markdown",
24+
"metadata": {},
25+
"source": [
26+
"### Set env"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": null,
32+
"metadata": {},
33+
"outputs": [],
34+
"source": [
35+
"# TensorFlow,Theano,CNTK\n",
36+
"os.environ['KERAS_BACKEND'] = \"tensorflow\" #Use TF1,some incompatibilities with TF2.\n",
37+
"# Force one-gpu\n",
38+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
39+
"# Performance Improvement\n",
40+
"# Make sure channels-first (not last)\n",
41+
"K.backend.set_image_data_format('channels_first')"
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"### Load dataset"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"metadata": {},
55+
"outputs": [],
56+
"source": [
57+
"def read_batch(src):\n",
58+
" '''Unpack the pickle files'''\n",
59+
" with open(src, 'rb') as f:\n",
60+
" if sys.version_info.major == 2:\n",
61+
" data = pickle.load(f)\n",
62+
" else:\n",
63+
" data = pickle.load(f, encoding='latin1') # Contains the numpy array\n",
64+
" return data\n",
65+
"\n",
66+
"def process_cifar():\n",
67+
" '''Read data into RAM'''\n",
68+
" print('Preparing train set...')\n",
69+
" train_list = [read_batch('./cifar-10-batches-py/data_batch_{0}'.format(i + 1)) for i in range(5)]\n",
70+
" x_train = np.concatenate([x['data'] for x in train_list])\n",
71+
" y_train = np.concatenate([y['labels'] for y in train_list])\n",
72+
" print('Preparing test set...')\n",
73+
" tst = read_batch('./cifar-10-batches-py/test_batch')\n",
74+
" x_test = tst['data']\n",
75+
" y_test = np.asarray(tst['labels'])\n",
76+
" return x_train, x_test, y_train, y_test\n",
77+
"\n",
78+
"def maybe_download_cifar(src=\"http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\"):\n",
79+
" '''Load the training and testing data'''\n",
80+
" try:\n",
81+
" return process_cifar()\n",
82+
" except:\n",
83+
" # Catch the exception that file doesn't exist & Download\n",
84+
" print('Data does not exist. Downloading ' + src)\n",
85+
" filename = src.split('/')[-1]\n",
86+
" filepath = os.path.join(\"./\",filename)\n",
87+
" def _recall_func(num,block_size,total_size):\n",
88+
" sys.stdout.write('\\r>> downloading %s %.1f%%' % (filename,float(num*block_size)/float(total_size)*100.0))\n",
89+
" sys.stdout.flush()\n",
90+
" fname, h = urlretrieve(src, filepath,_recall_func)\n",
91+
" file_info = os.stat(filepath)\n",
92+
" print('Successfully download.',filename,file_info.st_size,'bytes')\n",
93+
" print('Extracting files...')\n",
94+
" with tarfile.open(fname) as tar:\n",
95+
" tar.extractall()\n",
96+
" os.remove(fname)\n",
97+
" return process_cifar()\n",
98+
" \n",
99+
"def cifar_for_library(channel_first=True, one_hot=False):\n",
100+
" # Raw data\n",
101+
" x_train, x_test, y_train, y_test = maybe_download_cifar()\n",
102+
" # Scale pixel intensity\n",
103+
" x_train = x_train / 255.0\n",
104+
" x_test = x_test / 255.0\n",
105+
" # Reshape\n",
106+
" x_train = x_train.reshape(-1, 3, 32, 32)\n",
107+
" x_test = x_test.reshape(-1, 3, 32, 32)\n",
108+
" # Channel last\n",
109+
" if not channel_first:\n",
110+
" x_train = np.swapaxes(x_train, 1, 3)\n",
111+
" x_test = np.swapaxes(x_test, 1, 3)\n",
112+
" # One-hot encode y\n",
113+
" if one_hot:\n",
114+
" y_train = np.expand_dims(y_train, axis=-1)\n",
115+
" y_test = np.expand_dims(y_test, axis=-1)\n",
116+
" enc = OneHotEncoder(categorical_features='all')\n",
117+
" fit = enc.fit(y_train)\n",
118+
" y_train = fit.transform(y_train).toarray()\n",
119+
" y_test = fit.transform(y_test).toarray()\n",
120+
" # dtypes\n",
121+
" x_train = x_train.astype(np.float32)\n",
122+
" x_test = x_test.astype(np.float32)\n",
123+
" y_train = y_train.astype(np.int32)\n",
124+
" y_test = y_test.astype(np.int32)\n",
125+
" return x_train, x_test, y_train, y_test\n",
126+
"\n",
127+
"# Data into format for library\n",
128+
"x_train, x_test, y_train, y_test = cifar_for_library(channel_first=True, one_hot=True)\n",
129+
"print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)\n",
130+
"print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)"
131+
]
132+
},
133+
{
134+
"cell_type": "markdown",
135+
"metadata": {},
136+
"source": [
137+
"### Init model"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"# Hyperparams\n",
147+
"EPOCHS = 10\n",
148+
"BATCHSIZE = 64\n",
149+
"LR = 0.01\n",
150+
"MOMENTUM = 0.9\n",
151+
"N_CLASSES = 10\n",
152+
"GPU = True\n",
153+
"BATCH_SIZE = 32\n",
154+
"\n",
155+
"def create_model(n_classes=N_CLASSES):\n",
156+
" model = Sequential()\n",
157+
" model.add(Conv2D(50, kernel_size=(3, 3), padding='same', activation='relu',\n",
158+
" input_shape=(3, 32, 32)))\n",
159+
" model.add(Conv2D(50, kernel_size=(3, 3), padding='same', activation='relu')) \n",
160+
" model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
161+
" model.add(Dropout(0.25))\n",
162+
" \n",
163+
" model.add(Conv2D(100, kernel_size=(3, 3), padding='same', activation='relu'))\n",
164+
" model.add(Conv2D(100, kernel_size=(3, 3), padding='same', activation='relu')) \n",
165+
" model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
166+
" model.add(Dropout(0.25))\n",
167+
" \n",
168+
" model.add(Flatten())\n",
169+
" model.add(Dense(512, activation='relu'))\n",
170+
" model.add(Dropout(0.5))\n",
171+
" model.add(Dense(n_classes, activation='softmax'))\n",
172+
" return model\n",
173+
"\n",
174+
"def init_model(m, lr=LR, momentum=MOMENTUM):\n",
175+
" m.compile(\n",
176+
" loss = \"categorical_crossentropy\",\n",
177+
" optimizer = K.optimizers.SGD(lr, momentum),\n",
178+
" metrics = ['accuracy'])\n",
179+
" return m\n",
180+
"\n",
181+
"model = create_model()\n",
182+
"model = init_model(model)\n",
183+
"model.summary()"
184+
]
185+
},
186+
{
187+
"cell_type": "markdown",
188+
"metadata": {},
189+
"source": [
190+
"### Main training loop"
191+
]
192+
},
193+
{
194+
"cell_type": "code",
195+
"execution_count": null,
196+
"metadata": {},
197+
"outputs": [],
198+
"source": [
199+
"model.fit(x_train,\n",
200+
" y_train,\n",
201+
" batch_size=BATCHSIZE,\n",
202+
" epochs=EPOCHS,\n",
203+
" verbose=1)"
204+
]
205+
},
206+
{
207+
"cell_type": "markdown",
208+
"metadata": {},
209+
"source": [
210+
"### Main evaluation loop"
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": null,
216+
"metadata": {},
217+
"outputs": [],
218+
"source": [
219+
"y_guess = model.predict(x_test, batch_size=BATCHSIZE)\n",
220+
"y_guess = np.argmax(y_guess, axis=-1)\n",
221+
"y_truth = np.argmax(y_test, axis=-1)\n",
222+
"print(\"Accuracy: \", 1.*sum(y_guess == y_truth)/len(y_guess))"
223+
]
224+
}
225+
],
226+
"metadata": {
227+
"kernelspec": {
228+
"display_name": "Python 3",
229+
"language": "python",
230+
"name": "python3"
231+
},
232+
"language_info": {
233+
"codemirror_mode": {
234+
"name": "ipython",
235+
"version": 3
236+
},
237+
"file_extension": ".py",
238+
"mimetype": "text/x-python",
239+
"name": "python",
240+
"nbconvert_exporter": "python",
241+
"pygments_lexer": "ipython3",
242+
"version": "3.6.6"
243+
}
244+
},
245+
"nbformat": 4,
246+
"nbformat_minor": 4
247+
}

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,12 @@
11
# DeepLearning-ETS
22
Deep learning easy to start.(TF1,TF2,Pytorch,Keras)(Interactive presentation online).
3+
4+
# Start
5+
- python 3.6
6+
- pip install -r requirements.txt
7+
## Native
8+
- pip install jupyterlab
9+
- jupyter-lab
10+
- [http://localhost:8888](http://localhost:8888)
11+
## Online
12+
- [https://mybinder.org/](https://mybinder.org/)

0 commit comments

Comments
 (0)