Commit 19620bea authored by Ardillen66's avatar Ardillen66
Browse files

Save and load baxter agent decision tree

parent c76c38d5
......@@ -28,15 +28,14 @@ class BaxterAgent(object):
self.tree = DecisionTreeGR()
self.train()
else :
training_samples = self.load_trainig_samples()
self.tree = DecisionTreeGR(training_samples)
self.tree = self.load_decision_tree()
self.run()
def train(self):
self.drawTrainingUI()
while self.is_train:
######
#TODO# take snapshot and extract features if object was added to the scene. Update scene accordingly
#TODO# take snapshot and extract features if object was added to the scene. Update scene accordingly, send image to baxte face
######
features = {}
right_button = rospy.wait_for_message('/robot/digital_io/right_shoulder_button/state', DigitalIOState, 3)
......@@ -55,12 +54,11 @@ class BaxterAgent(object):
elif left_nav.buttons[1] or right_nav.buttons[1]:
self.is_train = False
self.init_move() #Go back to default position
self.save_training_samples() #Store trained samples to retrain decision tree
self.tree.train()
self.save_decision_tree() #Store trained decision tree
def run(self):
self.drawProductionUI()
#For now we retrain decision tree from saved training data, but maybe better to save trained tree?
self.tree.train()
run = True
while run:
#TODO# Check if object on table and retrieve features
......@@ -79,15 +77,11 @@ class BaxterAgent(object):
self.init_move()
def load_decision_tree(self):
return self.tree.load(self.config['tree_dir'])
def load_trainig_samples(self):
#TODO#: load training samles from file specified in config
pass
def save_training_samples(self):
#TODO#: save decision tree training samples to file specified in config
pass
def save_decision_tree(self):
self.tree.save(self.config['tree_dir'])
def init_scene(self):
#TODO#: initialize scene with description in config
......@@ -102,7 +96,8 @@ class BaxterAgent(object):
rospy.sleep(3)
def drawTrainingUI(self):
#TODO#: draw training interface on baxter face screen
print 'Garbagebot now training. Please choose a side to sort after putting down an object'
xdisplay_image.send_image('path/to/initial/image')
def drawProductionUI(self):
#TODO#
......
......@@ -2,6 +2,7 @@
"""Decision tree implementation for the Garbage Robot."""
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.externals import joblib
DIRECTION = 'direction'
RIGHT_STR = 'RIGHT'
......@@ -128,6 +129,12 @@ class DecisionTreeGR(object):
if res == LEFT:
return LEFT_STR
def save(self, directory):
joblib.dump(self.__decision_tree, directory + "decision_tree.pkl")
def load(self, directory):
joblib.load(directory + "descision_tree.pkl")
# USAGE EXAMPLE
# dt = DecisionTreeGR()
# ts = {
......
......@@ -6,7 +6,7 @@ from baxter_agent import BaxterAgent
config = {
'is_train': True,
'tree_dir': './saved_trees',
'tree_dir': './saved_trees/',
'scene': 'scene description' #Should contain a description of the scene to be loaded
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment