Decision Trees are one of the most loved 😘 classification algorithms in the world of Machine Learning. They are used for both regression and classification. The fundamental idea behind a decision tree is to find a root node which splits our dataset into more homogeneous subsets, and to repeat this process until we are left with samples belonging to the same class ( 100% homogeneity ).
With Python packages like Scikit Learn, they are easy to build and run, with a couple of lines of code,
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import pandas as pd

# Load the Iris dataset into DataFrames
iris = load_iris()
X = pd.DataFrame(iris.data[:, :], columns=iris.feature_names[:])
y = pd.DataFrame(iris.target, columns=["Species"])

# Fit a shallow decision tree
tree = DecisionTreeClassifier(max_depth=2)
tree.fit(X, y)
```
Which makes our life easier when working with Python 😎, at least. But decision trees have a number of practical applications, and thereby they need to be implemented on various platforms, like Android! In this story, we’ll be implementing a decision tree ( for classification ) in Kotlin ( or Java ).
The Python implementation of this decision tree is adapted from the stories below by Rakend Dubba, which I highly recommend,
The sample data is also inherited from the stories above.
The code is available on GitHub along with a sample APK file → https://github.com/shubham0204/Decision_Tree_Android
First, as we all know, Kotlin doesn’t have a nice package like Pandas 🐼 to hold our dataset. We can create a simple DataFrame class which holds our data internally in the form of a HashMap<String, ArrayList<String>>, where the key represents the name of our column and the ArrayList<String> holds the data for each sample in the form of a String.
For simplicity, we keep all our features as well as labels in the form of Strings.
Later on, we’ll parse data from this class while building our tree.
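A minimal sketch of what such a DataFrame class could look like ( the method names here are assumptions; the actual class in the repository may differ ):

```kotlin
// A minimal DataFrame-like holder where everything is stored as a String.
class DataFrame {

    // Column name -> list of values for that column.
    private val data = HashMap<String, ArrayList<String>>()

    // Append a value to a column, creating the column if it does not exist yet.
    fun addEntry(columnName: String, value: String) {
        data.getOrPut(columnName) { ArrayList() }.add(value)
    }

    // Return the whole column for the given name.
    fun getColumn(columnName: String): ArrayList<String> = data[columnName] ?: ArrayList()

    // Names of all columns currently stored.
    fun getColumnNames(): Set<String> = data.keys
}
```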
Next, we build some internal helper methods. These methods facilitate calculations that are not available out of the box in Kotlin. The first one is argmax(), which returns the index of the greatest value present in an array. The second one, getFreq(), returns the count or frequency of each element present in an array. The third one, logbase2(), returns the logarithm with base 2 of a given number.
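Minimal sketches of these helpers, assuming they work on FloatArrays and lists of Strings ( the exact signatures in the repository may differ ):

```kotlin
import kotlin.math.ln

// Index of the greatest value in the array.
fun argmax(values: FloatArray): Int {
    var maxIndex = 0
    for (i in values.indices) {
        if (values[i] > values[maxIndex]) maxIndex = i
    }
    return maxIndex
}

// Count ( frequency ) of every distinct element in the list.
fun getFreq(values: List<String>): Map<String, Int> = values.groupingBy { it }.eachCount()

// Logarithm with base 2 of the given number.
fun logbase2(x: Float): Float = (ln(x.toDouble()) / ln(2.0)).toFloat()
```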
Now, in order to get the total entropy at the root node ( the labels ), we use the method below. We use Shannon’s entropy:
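E_label = − Σ_c p( c ) · log2( p( c ) ), where p( c ) is the fraction of samples in the current node belonging to class c.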
The above entropy will be used in the calculation of the Information Gain:
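For a feature F, the Information Gain is IG( F ) = E_label − E_feature( F ), where E_label is the entropy of the labels at the current node and E_feature( F ) is the entropy of the labels after splitting on F.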
Next comes the Kotlin method for calculating the root entropy E_label.
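A sketch of how such a method could look, assuming the labels come in as a list of Strings and reusing the getFreq() and logbase2() helpers from above ( the method name is an assumption ):

```kotlin
// Shannon entropy of a list of labels ( E_label ).
fun getLabelEntropy(labels: List<String>): Float {
    var entropy = 0.0f
    val numSamples = labels.size.toFloat()
    // For every distinct label, add -p * log2( p ).
    for ((_, count) in getFreq(labels)) {
        val p = count / numSamples
        entropy += -p * logbase2(p)
    }
    return entropy
}
```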
Next, we need to calculate E_feature for a specified feature. We would feed a featureName and data to the method and it will return the entropy for that feature.
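The explanation that follows refers to the original snippet; a sketch of the same logic, reusing the DataFrame and helpers assumed above ( the labelColumnName parameter is an addition of this sketch ), could look like:

```kotlin
// Entropy E_feature of the labels after splitting on a single feature column.
fun getFeatureEntropy(data: DataFrame, featureColumnName: String, labelColumnName: String): Float {
    val labels = data.getColumn(labelColumnName)
    val featureValues = data.getColumn(featureColumnName)
    var featureEntropy = 0.0f
    // Loop over the distinct ( sorted, no repetition ) values of this feature.
    for (featureValue in featureValues.distinct().sorted()) {
        // Rows of the dataset that carry this feature value.
        val rowIndices = featureValues.indices.filter { featureValues[it] == featureValue }
        val denomCount = rowIndices.size.toFloat()
        var entropy = 0.0f
        // Entropy of the labels within this subset.
        for (label in labels.slice(rowIndices).distinct()) {
            val numCount = rowIndices.count { labels[it] == label }.toFloat()
            val p = numCount / denomCount
            entropy += -p * logbase2(p)
        }
        // Weight the subset entropy by the fraction of samples in the subset.
        featureEntropy += (denomCount / labels.size.toFloat()) * entropy
    }
    return featureEntropy
}
```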
I’ll explain the above method so as to build an intuition of what’s going on. The labels list carries the labels for all samples. featureValues gets the whole column with featureColumnName from the dataset. We loop through each of the featureValues ( the sorted, distinct ones, no repetition of values! ). We define a variable inside the loop named entropy. This entropy is calculated by,
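entropy = − Σ_label ( numCount / denomCount ) · log2( numCount / denomCount ), where denomCount ( a name assumed from the sketch above ) is the number of samples having the current featureValue.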
The numCount variable refers to the number of samples which have that featureValue and the corresponding label.
These entropies are summed up and added to another variable, featureEntropy.
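In the standard formulation, each per-value entropy is weighted by the fraction of samples carrying that value, so that E_feature = Σ_value ( denomCount / N ) · entropy( value ), where N is the total number of samples.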
Next, we need to find the feature which gives us the highest value of Information Gain. The method is given a data object; for each feature it uses the labelEntropy and the featureEntropy to calculate the IG score, and stores the scores in an array. We use our argmax function to get the index of the greatest IG score and we finally return the corresponding feature name.
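A sketch of such a method, reusing the helpers above and assuming a LABEL_COLUMN_NAME constant that holds the name of the label column ( the constant’s value here is only an example ):

```kotlin
// Assumed name of the label column ( matches the "Species" column of the Iris example ).
const val LABEL_COLUMN_NAME = "Species"

// Name of the feature with the highest Information Gain.
fun getHighestIGFeature(data: DataFrame): String {
    val featureNames = data.getColumnNames().filter { it != LABEL_COLUMN_NAME }
    val labelEntropy = getLabelEntropy(data.getColumn(LABEL_COLUMN_NAME))
    // IG( feature ) = E_label - E_feature
    val igScores = FloatArray(featureNames.size) { i ->
        labelEntropy - getFeatureEntropy(data, featureNames[i], LABEL_COLUMN_NAME)
    }
    return featureNames[argmax(igScores)]
}
```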
We also need a helper that returns a HashMap containing a sub-table filtered by featureName and featureValue.
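A sketch of such a helper ( the name getSubTable is an assumption ):

```kotlin
// Sub-table of the rows whose given feature equals the given value,
// returned as column name -> column values.
fun getSubTable(data: DataFrame, featureName: String, featureValue: String): HashMap<String, ArrayList<String>> {
    val subTable = HashMap<String, ArrayList<String>>()
    val featureColumn = data.getColumn(featureName)
    // Indices of the rows having the requested feature value.
    val rowIndices = featureColumn.indices.filter { featureColumn[it] == featureValue }
    for (columnName in data.getColumnNames()) {
        val column = data.getColumn(columnName)
        subTable[columnName] = ArrayList(rowIndices.map { column[it] })
    }
    return subTable
}
```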
Constructing the Tree Recursively.
The most important part of our algorithm. Now, we define a method createTree which will be called recursively till we are left with homogeneous datasets, i.e. with an entropy of 0 ( a sketch of this method is given after the explanation below ).
Let us understand this method.
- When the method is called initially, inputTree is null and data contains the whole dataset ( no subsets ).
- We find the highestIGFeatureName and fetch the distinct attributes of that feature. Since our inputTree is null, we create a branch in our tree at line no. 8.
- The attributes which we fetched will now be used. We iterate through each of them and get a sub-table for that attribute. Following this, we create an entry in our tree p on lines 21 and 22. Since we aren’t left with homogeneous datasets, the method is called recursively at line 22.
- After a number of recursions, the length of the counts array will become 1. This denotes that the samples present in subTable[ LABEL_COLUMN_NAME ] are of the same kind, i.e. their entropy is 0. Here, we break the recursion at line 18. Also, we assign p[ attribute ] the value of clValue[0].
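The line numbers above refer to the original snippet. A rough sketch that follows the same logic, built on the helpers assumed earlier ( so it will not match the original line for line ), could look like this:

```kotlin
// Recursively build the tree as nested HashMaps.
// inputTree is null on the first call; data holds the ( sub )dataset for this branch.
fun createTree(data: DataFrame, inputTree: HashMap<String, Any>?): HashMap<String, Any> {
    // Feature with the highest Information Gain for this ( sub )dataset.
    val highestIGFeatureName = getHighestIGFeature(data)
    val attributes = data.getColumn(highestIGFeatureName).distinct()

    // Create a branch in the tree for this feature.
    val tree = inputTree ?: HashMap()
    val p = HashMap<String, Any>()
    tree[highestIGFeatureName] = p

    for (attribute in attributes) {
        val subTable = getSubTable(data, highestIGFeatureName, attribute)
        val subLabels = subTable[LABEL_COLUMN_NAME]!!
        // Distinct labels and their counts within this sub-table.
        val counts = getFreq(subLabels)
        if (counts.size == 1) {
            // Homogeneous subset ( entropy 0 ) -> stop recursing and store the label.
            p[attribute] = subLabels[0]
        } else {
            // Otherwise, rebuild a DataFrame from the sub-table and recurse on it.
            val subData = DataFrame()
            for ((columnName, values) in subTable) {
                values.forEach { subData.addEntry(columnName, it) }
            }
            p[attribute] = createTree(subData, null)
        }
    }
    return tree
}
```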
The above method will return an object of type HashMap<String, Any> which represents our tree.
We employ a similar method to predict a label for a given sample. This method takes in the tree ( which we produced in snippet 7 ) and returns a String label. This method is recursive too.
The method is called recursively until the object p[ value ] isn’t a HashMap. This means that we have reached the end of a branch, where we would find our prediction ( label ).
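A sketch of such a predict method, assuming the sample is given as a map from feature name to feature value ( the unchecked casts are kept short for brevity ):

```kotlin
// Walk down the tree until the stored value is no longer a HashMap;
// at that point it is the predicted label.
fun predict(sample: Map<String, String>, tree: HashMap<String, Any>): String {
    // The single key at this level is the feature this node splits on.
    val featureName = tree.keys.first()
    val branches = tree[featureName] as HashMap<String, Any>
    // Follow the branch matching the sample's value for this feature.
    val p = branches[sample.getValue(featureName)]
    return if (p is HashMap<*, *>) {
        // Still a subtree -> keep recursing.
        predict(sample, p as HashMap<String, Any>)
    } else {
        // Leaf reached -> this is the predicted label.
        p as String
    }
}
```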
Try running the algorithm in both Kotlin and Python. Compare the entropies at each step and you will find that they are equal, which indicates that our algorithm is working fine. Thanks for reading, and Happy Machine Learning 😃!
Credit: BecomingHuman, by Shubham Panchal