-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
65 lines (64 loc) · 2.03 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
from datetime import datetime
# Project-wide configuration constants: data/model locations, input geometry,
# network architecture, optimisation settings and Keras-style callback options.

# ---------------------------------------------------------------------------
# Filesystem locations (absolute, rooted at the directory of this file)
# ---------------------------------------------------------------------------
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
CSV_PATH = os.path.join(BASE_PATH, u'data')
CSV_PATH_DS3 = os.path.join(CSV_PATH, u'ds3')
IMG_PATH = os.path.join(CSV_PATH, u'img')
IMG_PATH_DS3 = os.path.join(CSV_PATH_DS3, u'img')
HDF5_PATH = os.path.join(CSV_PATH, u'hdf5')
HDF5_PATH_DS3 = os.path.join(HDF5_PATH, u'ds3')
MODEL_PATH = os.path.join(BASE_PATH, u'models')

# ---------------------------------------------------------------------------
# Input geometry
# ---------------------------------------------------------------------------
# Judging by the input shapes below, IMAGE_DIM appears to be ordered
# (width, height) — NOTE(review): confirm against the resize call site.
IMAGE_DIM = (160, 70)  # changed from (320, 70)
TIMESTEPS = 15
# Derived from IMAGE_DIM / TIMESTEPS so the model input cannot drift out of
# sync with the preprocessing size. Current values: (70, 160, 1) and
# (15, 70, 160, 1); the trailing 1 is the channel count.
NOTIMESTEPS_INPUT_SHAPE = (IMAGE_DIM[1], IMAGE_DIM[0], 1)
TIMESTEPS_INPUT_SHAPE = (TIMESTEPS,) + NOTIMESTEPS_INPUT_SHAPE

# ---------------------------------------------------------------------------
# Runtime / data split
# ---------------------------------------------------------------------------
# Kept as a string; presumably written to CUDA_VISIBLE_DEVICES — TODO confirm.
GPU_ID_TO_USE = '1'
SKLEARN_TRAIN_TEST_SPLIT_SIZE = 0.2
SKLEARN_RANDOM_STATE = 42
SAMPLE_WEIGHT = None

# Per-loss weights; NOTE(review): presumably the weights of the
# classification, regression and steering outputs — confirm in the model code.
WEIGHT_FOR_CCE = 1
WEIGHT_FOR_MSE = 0.001
WEIGHT_FOR_ST = 5

# ---------------------------------------------------------------------------
# Convolutional feature extractor
# ---------------------------------------------------------------------------
CONV2D_FILTERS_1 = 24
CONV2D_FILTERS_2 = 36
CONV2D_FILTERS_3 = 48
CONV2D_FILTERS_4 = 64
CONV2D_FILTERS_5 = 80
CONV2D_FILTERS_6 = 96
KERNEL_SIZE_1 = (5, 5)
KERNEL_SIZE_2 = (3, 3)
STRIDE_DIM_1 = (2, 2)
STRIDE_DIM_2 = (1, 1)
PADDING = 'same'
CONV2D_ACTIVATION_FN = 'relu'

# ---------------------------------------------------------------------------
# Recurrent and dense heads
# ---------------------------------------------------------------------------
LSTM_OUTPUT_UNITS = 100
LSTM_ACTIVATION_FN = 'tanh'
LSTM_RETURN_SEQ = False
DROPOUT_VALUE = 0.5
DENSE_HIDDEN_UNITS_1 = 100
DENSE_HIDDEN_UNITS_2 = 50
DENSE_HIDDEN_UNITS_3 = 10
DENSE_ACTIVATION_FN = 'relu'
DENSE_OUTPUT_ACTIVATION_FN_STEERING = 'tanh'
DENSE_OUTPUT_ACTIVATION_FN_VELOCITY = 'relu'
DENSE_OUTPUT_ACTIVATION_FN_ACCELERATION = 'tanh'
DENSE_OUTPUT_ACTIVATION_FN_CLASSIFICATION = 'softmax'

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
MODEL_LEARNING_RATE = 1e-04
MODEL_LEARNING_DECAY = 0.0
MODEL_LOSS_FN1 = 'mse'
# Presumably "categorical cross-entropy"; NOTE(review): confirm the training
# code maps this short name, as it may not be a built-in loss identifier.
MODEL_LOSS_FN2 = 'cce'

# ---------------------------------------------------------------------------
# Plotting, callbacks and logging
# ---------------------------------------------------------------------------
PLOT_MODEL_SAVE_FILE = 'evaluate_ts15_ds3_Class6_cseg2.png'
PLOT_MODEL_SHOW_SHAPES = True
CALLBACKS_MONITOR = 'val_loss'
CALLBACKS_MONITOR_MODE = 'min'
# Evaluated once, at import time. NOTE(review): the ':' in this timestamp is
# not a legal filename character on Windows.
SAVE_FORMAT = datetime.now().strftime("%Y-%m-%dT%H:%M")
EARLYSTOPPING_PATIENCE = 15
MODEL_CHECKPOINT_FILENAME = 'model_shuffled_2LSTM_tanh.h5'
TENSORBOARD_LOG_FILENAME = "ts1_shuffled_2LSTM_tanh_"
# Joined with os.path.join: the previous '{}tb_logs/' format omitted the
# separator after BASE_PATH (which has no trailing slash), so logs were written
# to a sibling "<project>tb_logs" directory instead of "<project>/tb_logs".
TENSORBOARD_LOG_PATH = os.path.join(
    BASE_PATH, u'tb_logs', TENSORBOARD_LOG_FILENAME + SAVE_FORMAT)
CALLBACKS_VERBOSITY = 1

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
MODEL_FIT_VERBOSITY = 1
MODEL_CHECKPOINT_SAVE_BEST = True
MODEL_FIT_SHUFFLE = True
TRAINING_EPOCH = 50
BATCH_SIZE = 128