Commit 22b6b774 authored by Fabio Navarrete's avatar Fabio Navarrete
Browse files

Added wrapper module for decision trees

parent 8f1e027d
# -*- coding: utf-8 -*-
"""Decision tree implementation for the Garbage Robot."""
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
DIRECTION = 'direction'
RIGHT_STR = 'RIGHT'
LEFT_STR = 'LEFT'
RIGHT = 0
LEFT = 1
COLOR = 'color'
RED_STR = 'RED'
GREEN_STR = 'GREEN'
BLUE_STR = 'BLUE'
RED = 0
GREEN = 1
BLUE = 2
SIZE = 'size'
SMALL_STR = 'SMALL'
MEDIUM_STR = 'MEDIUM'
BIG_STR = 'BIG'
SMALL = 0
MEDIUM = 1
BIG = 2
SHAPE = 'shape'
CUBE_STR = 'CUBE'
SPHERE_STR = 'SPEHRE'
TRIANGLE_STR = 'TRIANGLE'
CUBE = 0
SPHERE = 1
TRIANGLE = 2
map_to_int = {
RIGHT_STR: RIGHT,
LEFT_STR: LEFT,
RED_STR: RED,
GREEN_STR: GREEN,
BLUE_STR: BLUE,
SMALL_STR: SMALL,
MEDIUM_STR: MEDIUM,
BIG_STR: BIG,
CUBE_STR: CUBE,
SPHERE_STR: SPHERE,
TRIANGLE_STR: TRIANGLE
}
class DecisionTreeGR(object):
"""Wrapper class for the SciKit Learn decision tree implementation."""
def __init__(self, training_samples=None):
"""Constructor.
Constructor method for the Decision Tree wraper class. If provided a
list of training samples would be used to train the classifier.
"""
self.__decision_tree = DecisionTreeClassifier(criterion='gini')
self.__training_samples = training_samples
if not self.__training_samples:
self.__training_samples = []
def add_training_sample(self, training_sample):
"""Add a training sample for the classifier.
Expected structure of the training sample in a dictionary form:
{
'color': color_value,
'shape': shape_value,
'size': size_value,
'direction': direction_value
}
where the values are strings with the following possible values:
color_value: 'RED', 'GREEN', 'BLUE'
shape_value: 'SPHERE', 'CUBE', 'TRIANGLE'
size_value: 'SMALL', 'BIG'
direction_value: 'LEFT', 'RIGHT'
"""
self.__training_samples.append(training_sample)
def train(self):
"""Train the tree classifier using the class training samples.
Tree samples used are expected to be provided in the constructor or by
using the add_training_sample method.
"""
df = pd.DataFrame(self.__training_samples)
df = self.__map_string_to_int(df)
features = [COLOR, SIZE, SHAPE]
y = df[DIRECTION]
X = df[features]
self.__decision_tree.fit(X, y)
def __map_string_to_int(self, data_frame):
"""Map string to integer for training dataframe."""
return data_frame.replace(map_to_int)
def classify(self, input):
"""Classify an input.
Input expected to have a dictionary format containing the following
structure:
{
'color': color_value,
'shape': shape_value,
'size': size_value
}
where the values are strings with the following possible values:
color_value: 'RED', 'GREEN', 'BLUE'
shape_value: 'SPHERE', 'CUBE', 'TRIANGLE'
size_value: 'SMALL', 'BIG'
Make sure that the tree is trained before calling this method (Call
train method before)
"""
input = {
key: [map_to_int[value]]
for (key, value) in input.iteritems()
}
res = self.__decision_tree.predict(pd.DataFrame(input))
if res == RIGHT:
return RIGHT_STR
if res == LEFT:
return LEFT_STR
# USAGE EXAMPLE
# dt = DecisionTreeGR()
# ts = {
# COLOR: GREEN_STR,
# SIZE: MEDIUM_STR,
# SHAPE: TRIANGLE_STR,
# DIRECTION: RIGHT_STR
# }
# dt.add_training_sample(ts)
# dt.train()
# inp = {COLOR: RED_STR, SIZE: BIG_STR, SHAPE: CUBE_STR}
# res = dt.classify(inp)
# print(res)
# -*- coding: utf-8 -*-
"""Tests for the decision tree implementation."""
from GarbageBot.decision_api import decision_tree_gr
# Implement proper unit testing here
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment