Commit dced312d authored by Mathieu Reymond's avatar Mathieu Reymond
Browse files

separate network for color and shape

parent 73156285
import sys
import cv2
import numpy as np
import tensorflow as tf
import tflearn
from os import walk
from tqdm import tqdm
from imgaug import augmenters as iaa
from uuid import uuid4
from math import sqrt
def name_to_features(name):
color = {'RED': 0, 'BLUE': 1, 'GREEN': 2}
shape = {'CUBE': 0, 'SPHERE': 1, 'TRIANGLE': 2}
features = np.zeros(len(color.keys())*len(shape.keys()))
# features = np.zeros(len(color.keys())*len(shape.keys()))
colors = np.zeros(len(color.keys()))
shapes = np.zeros(len(shape.keys()))
color = [v for k, v in color.items() if name.startswith(k)][0]
shape = [v for k, v in shape.items() if name.endswith(k)][0]
features[color*3+shape] = 1
# features[color*3+shape] = 1
# return features
# print(str(color) + ' ' + str(shape) + ': ' + str(features))
return features
colors[color] = 1
shapes[shape] = 1
return np.append(colors, shapes)
def out_to_features(out):
color = ['RED', 'BLUE', 'GREEN']
shape = ['CUBE', 'SPHERE', 'TRIANGLE']
# outputs ranked according to probabilities,
# take the most probable one
print(out)
out = out[0]
c_i = out/3
s_i = out%3
......@@ -64,8 +73,8 @@ def augment(x, y, nb=20):
def preprocess_sample(img, crop_x=(210,430), crop_y=(0, 200)):
img = img[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1], :]
img = cv2.resize(img, (0,0), fx=0.5, fy=0.5)
img = cv2.GaussianBlur(img,(5,5),0)
# cv2.imwrite('test.jpg', img)
# img = cv2.GaussianBlur(img,(5,5),0)
cv2.imwrite('test.jpg', img)
img = img/255.
return img
......@@ -81,23 +90,32 @@ def preprocess(data):
return x, y
def create_model(x_shape=[None, 100, 110, 3], out=9, path='model/model.m'):
net = tflearn.input_data(shape=x_shape)
net = tflearn.conv_2d(net, 32, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.conv_2d(net, 32, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.conv_2d(net, 64, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.flatten(net)
net = tflearn.fully_connected(net, 64, activation='relu')
net = tflearn.dropout(net, 0.5)
net = tflearn.fully_connected(net, out, activation='sigmoid')
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy')
color = tflearn.input_data(shape=x_shape)
gray = tflearn.custom_layer(color, tf.image.rgb_to_grayscale)
final = []
for name, net in {'color': color, 'shape': gray}.items():
net = tflearn.conv_2d(net, 32, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.conv_2d(net, 32, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.conv_2d(net, 64, (3, 3), activation='relu')
net = tflearn.max_pool_2d(net, (2,2))
net = tflearn.flatten(net)
net = tflearn.fully_connected(net, 64, activation='relu')
net = tflearn.dropout(net, 0.5)
net = tflearn.fully_connected(net, out, activation='sigmoid')
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy')
final.append(net)
final = tflearn.merge(final, 'concat')
return tflearn.DNN(final, tensorboard_verbose=0) #, checkpoint_path=path)
return tflearn.DNN(net, tensorboard_verbose=0) #, checkpoint_path=path)
def train_net(model, x, y):
model.fit(x, y, n_epoch=60, shuffle=True, show_metric=True, batch_size=50)
def train_model(model, x, y):
model.fit(x, y, n_epoch=60, shuffle=True, show_metric=True, batch_size=50, run_id='gb')
outs = [y[:,3:], y[:,:3]]
train_net(model, x, outs)
def load_model(path, x_shape, out):
model = create_model(x_shape, out)
......@@ -121,12 +139,13 @@ if __name__ == '__main__':
s = [None]
s.extend(x.shape[1:])
model = create_model(x_shape=s, out=y.shape[1])
model = create_model(x_shape=s, out=y.shape[1]/2)
train_model(model, x, y)
model.save(path + '../checkpoint/model.m')
else:
model = load_model(model_path, [None, 100, 110, 3], 9)
model = load_model(model_path, [None, 100, 110, 3], 3)
img = cv2.imread(pred_path)
img = preprocess_sample(img)
print(model.predict([img]))
pred = model.predict_label([img])[0]
print(out_to_features(pred))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment