Commit 9ff9b6a5 authored by Fabio Navarrete's avatar Fabio Navarrete
Browse files

Added possibility of changing features in decision tree constructor.

parent f701cff4
......@@ -52,7 +52,7 @@ map_to_int = {
class DecisionTreeGR(object):
"""Wrapper class for the SciKit Learn decision tree implementation."""
def __init__(self, training_samples=None):
def __init__(self, training_samples=None, features=[COLOR, SIZE, SHAPE]):
"""Constructor.
Constructor method for the Decision Tree wraper class. If provided a
......@@ -60,6 +60,7 @@ class DecisionTreeGR(object):
"""
self.__decision_tree = DecisionTreeClassifier(criterion='gini')
self.__training_samples = training_samples
self._features = features
if not self.__training_samples:
self.__training_samples = []
......@@ -89,9 +90,8 @@ class DecisionTreeGR(object):
"""
df = pd.DataFrame(self.__training_samples)
df = self.__map_string_to_int(df)
features = [COLOR, SIZE, SHAPE]
y = df[DIRECTION]
X = df[features]
X = df[self._features]
self.__decision_tree.fit(X, y)
def __map_string_to_int(self, data_frame):
......@@ -118,7 +118,6 @@ class DecisionTreeGR(object):
Make sure that the tree is trained before calling this method (Call
train method before)
"""
input = {
key: [map_to_int[value]]
for (key, value) in input.iteritems()
......@@ -130,9 +129,11 @@ class DecisionTreeGR(object):
return LEFT_STR
def save(self, directory):
"""Save the object in the current state to a serialized version pkl."""
joblib.dump(self.__decision_tree, directory + "decision_tree.pkl")
def load(self, directory):
"""Load a serialized tree from a pre-defined location."""
joblib.load(directory + "descision_tree.pkl")
# USAGE EXAMPLE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment