To extract the decision rules from a fitted scikit-learn decision tree, try the code below:
from sklearn.tree import _tree
def tree_to_code(tree, fnames, file=None):
    """Print a fitted scikit-learn decision tree as the source of a Python function.

    The generated function is named ``tree`` and takes one parameter per
    entry of *fnames*.  Every internal node becomes an
    ``if <feature> <= <threshold>:`` / ``else:`` pair, and every leaf becomes
    ``return <value>``, where ``<value>`` is the printed form of
    ``tree.tree_.value[leaf]``.

    Args:
        tree: A fitted scikit-learn decision-tree estimator; only its
            ``tree_`` attribute is read.
        fnames: Feature names indexed by feature number (the integers stored
            in ``tree.tree_.feature``).  They are written verbatim as the
            generated function's parameter names, so they must be valid
            Python identifiers for the output to be valid Python.
        file: Optional text stream to write to.  ``None`` (the default) means
            ``sys.stdout``, exactly as for the built-in ``print``.
    """
    tree_ = tree.tree_
    # Name of the feature each node splits on.  Leaves store TREE_UNDEFINED
    # as their feature and are never looked up here, hence the placeholder.
    # This list is kept separate from ``fnames``: reusing ``fnames`` for it
    # (as an earlier version did) made the generated signature list one
    # parameter per *node* instead of one per *feature*.
    node_names = [
        fnames[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(fnames)), file=file)

    def recurse(node, depth):
        """Emit the code for the subtree rooted at *node*, nested *depth* levels."""
        ind = " " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: samples go left when feature <= threshold.
            name = node_names[node]
            th = tree_.threshold[node]
            print("{}if {} <= {}:".format(ind, name, th), file=file)
            recurse(tree_.children_left[node], depth + 1)
            print("{}else: # if {} > {}".format(ind, name, th), file=file)
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: return the stored value array as printed.
            print("{}return {}".format(ind, tree_.value[node]), file=file)

    # Node 0 is the root; its body sits one level inside ``def tree``.
    recurse(0, 1)
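The code above prints the source of a valid Python function that reproduces the tree's decisions. Each leaf returns the printed form of the leaf's value array, which is why the results look like `[[ 0.]]`.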
Example: the output for a tree that is trying to return a number between 0 and 10:
# Example output of tree_to_code for a single-feature tree whose feature is
# named "m0" (this header comment is illustrative, not part of the output).
def tree(m0):
    if m0 <= 6.0:
        if m0 <= 1.5:
            return [[ 0.]]
        else: # if m0 > 1.5
            if m0 <= 4.5:
                if m0 <= 3.5:
                    return [[ 3.]]
                else: # if m0 > 3.5
                    return [[ 4.]]
            else: # if m0 > 4.5
                return [[ 5.]]
    else: # if m0 > 6.0
        if m0 <= 8.5:
            if m0 <= 7.5:
                return [[ 7.]]
            else: # if m0 > 7.5
                return [[ 8.]]
        else: # if m0 > 8.5
            return [[ 9.]]