Decision Trees are among the most loved 😘 classification algorithms in the world of Machine Learning. They can be used for both regression and classification. The fundamental idea behind a decision tree is to find a root node which splits our dataset into more homogeneous subsets, and to repeat this process until each leaf contains samples belonging to a single class ( 100% homogeneity ).
With Python packages like scikit-learn, they are easy to build and run in a couple of lines of code:

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the Iris dataset into DataFrames.
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.DataFrame(iris.target, columns=["Species"])

# Fit a shallow decision tree.
tree = DecisionTreeClassifier(max_depth=2)
tree.fit(X, y)
```
This makes our life easier, at least while working with Python 😎. But decision trees have a number of practical applications, and so they need to be implemented on other platforms too, like Android! In this story, we'll implement a decision tree ( for classification ) in Kotlin ( or Java ).
The Python implementation of this decision tree is adapted from the stories below by Rakend Dubba, which I highly recommend.
The sample data is also taken from those stories.
The code is available on GitHub, along with a sample APK -> https://github.com/shubham0204/Decision_Tree_Android
First, as we all know, Kotlin doesn't have a nice package like Pandas 🐼 to hold our dataset. Instead, we can create a simple DataFrame class which holds our data internally as a HashMap<String,ArrayList<String>>, where each key is the name of a column and the corresponding ArrayList<String> holds that column's value for each sample. For simplicity, we keep all our features as well as labels in the form of Strings. Later on, we'll parse data from this class while building our tree.
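As a sketch, a minimal version of such a class could look like the following. The method names ( addEntry , getColumn , getColumnNames ) are my assumptions for illustration, not necessarily those used in the repository.

```kotlin
// A minimal sketch of the DataFrame class described above.
// Method names are illustrative assumptions.
class DataFrame {

    // Column name -> the values of that column for every sample, all Strings.
    private val data = HashMap<String, ArrayList<String>>()

    // Append a value to a column, creating the column on first use.
    fun addEntry(columnName: String, value: String) {
        data.getOrPut(columnName) { ArrayList() }.add(value)
    }

    // Return the whole column for a given name (empty if absent).
    fun getColumn(columnName: String): ArrayList<String> {
        return data[columnName] ?: ArrayList()
    }

    // Names of all columns currently stored.
    fun getColumnNames(): Set<String> = data.keys
}
```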
Next, we build some internal helper methods which facilitate calculations not available in Kotlin's standard library. The first one, argmax() , returns the index of the greatest value present in an array. The second one, getFreq() , returns the count or frequency of each distinct element present in an array. The third one, logbase2() , returns the base-2 logarithm of a given number.
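These three helpers can be sketched as below; the exact signatures in the repository may differ, so treat them as assumptions.

```kotlin
// Illustrative sketches of the three helpers described above.

// Returns the index of the greatest value present in the array.
fun argmax(values: FloatArray): Int {
    var maxIndex = 0
    for (i in values.indices) {
        if (values[i] > values[maxIndex]) maxIndex = i
    }
    return maxIndex
}

// Returns the frequency of each distinct element present in the list.
fun getFreq(values: List<String>): HashMap<String, Int> {
    val freq = HashMap<String, Int>()
    for (v in values) {
        freq[v] = (freq[v] ?: 0) + 1
    }
    return freq
}

// Returns the base-2 logarithm, via the change-of-base rule.
fun logbase2(x: Double): Double = Math.log(x) / Math.log(2.0)
```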
Now, in order to get the total entropy at the root node ( over the labels ), we use Shannon's entropy:

E(S) = −Σᵢ pᵢ log₂( pᵢ )

where pᵢ is the fraction of samples belonging to class i. This entropy is used in the calculation of the Information Gain of a feature:

IG( S, feature ) = E_label − E_feature

where E_feature is the weighted average of the entropies of the subsets produced by splitting on that feature. We first write a Kotlin method to calculate the root entropy, E_label.
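A sketch of what that method can look like ( the function name entropy and its signature are assumptions ):

```kotlin
// Shannon entropy of a list of labels: E = -Σ p_i * log2(p_i).
// Function name and signature are illustrative assumptions.
fun entropy(labels: List<String>): Double {
    // Count how often each class appears.
    val freq = HashMap<String, Int>()
    for (label in labels) {
        freq[label] = (freq[label] ?: 0) + 1
    }
    var e = 0.0
    for (count in freq.values) {
        val p = count.toDouble() / labels.size   // fraction of samples in this class
        e += -p * (Math.log(p) / Math.log(2.0))  // -p * log2(p)
    }
    return e
}
```

A homogeneous list gives an entropy of 0, while a perfectly balanced two-class list gives 1 bit.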
Next, we need to calculate E_feature for a specified feature. We feed a data object to the method and it returns the entropy for that feature.
I'll explain the method so as to build an intuition of what's going on. labels carries the labels for all samples, while featureValues fetches the whole column named featureColumnName from the dataset. We loop through each of the distinct featureValues ( sorted, with no repetition of values! ). Inside the loop, we calculate the entropy of the subset of labels belonging to the current featureValue; the numCount variable refers to the number of samples which take that featureValue together with a particular label. Each subset entropy is weighted by the fraction of samples carrying that featureValue, and these weighted entropies are summed up into the feature's total entropy, E_feature.
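Put together, the feature-entropy calculation can be sketched as follows. Here the table is represented directly as a HashMap<String, ArrayList<String>>, and the function and parameter names are assumptions.

```kotlin
// Weighted entropy of the label column after splitting on one feature:
// E_feature = Σ over attribute values v of (|S_v| / |S|) * entropy(S_v).
// Names are illustrative assumptions.
fun featureEntropy(
    data: HashMap<String, ArrayList<String>>,
    featureColumnName: String,
    labelColumnName: String
): Double {
    val labels = data[labelColumnName]!!
    val featureValues = data[featureColumnName]!!
    var total = 0.0
    // Loop over the distinct attribute values of this feature.
    for (featureValue in featureValues.distinct()) {
        // Labels of the samples that carry this attribute value.
        val subsetLabels = labels.filterIndexed { i, _ -> featureValues[i] == featureValue }
        // Shannon entropy of this subset.
        var e = 0.0
        for (count in subsetLabels.groupingBy { it }.eachCount().values) {
            val p = count.toDouble() / subsetLabels.size
            e += -p * (Math.log(p) / Math.log(2.0))
        }
        // Weight by the fraction of samples with this value and accumulate.
        total += (subsetLabels.size.toDouble() / labels.size) * e
    }
    return total
}
```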
Next, we need to find the feature which gives us the highest Information Gain. The method is given a data object; for each feature, we subtract its featureEntropy from the root entropy to calculate the IG score and store the scores in an array. We then use our argmax function to get the index of the greatest IG score and finally return the corresponding feature name.
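The selection step can be sketched like this. The entropy helpers are repeated in compact form so the snippet is self-contained, and all names are assumptions.

```kotlin
// Picks the feature with the highest Information Gain:
// IG(feature) = E_label - E_feature. Names are illustrative assumptions.

fun shannonEntropy(labels: List<String>): Double {
    var e = 0.0
    for (count in labels.groupingBy { it }.eachCount().values) {
        val p = count.toDouble() / labels.size
        e += -p * (Math.log(p) / Math.log(2.0))
    }
    return e
}

// Weighted average entropy of the labels after splitting on one feature.
fun weightedFeatureEntropy(data: Map<String, List<String>>, feature: String, labelCol: String): Double {
    val labels = data[labelCol]!!
    val values = data[feature]!!
    var total = 0.0
    for (v in values.distinct()) {
        val subset = labels.filterIndexed { i, _ -> values[i] == v }
        total += (subset.size.toDouble() / labels.size) * shannonEntropy(subset)
    }
    return total
}

fun getHighestIGFeatureName(data: Map<String, List<String>>, labelCol: String): String {
    val rootEntropy = shannonEntropy(data[labelCol]!!)
    val featureNames = data.keys.filter { it != labelCol }
    // IG score for every candidate feature.
    val igScores = featureNames.map { rootEntropy - weightedFeatureEntropy(data, it, labelCol) }
    // The argmax step: index of the greatest IG score.
    var maxIndex = 0
    for (i in igScores.indices) {
        if (igScores[i] > igScores[maxIndex]) maxIndex = i
    }
    return featureNames[maxIndex]
}
```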
We also need a method which returns a HashMap containing a sub-table, i.e. the rows of the dataset filtered by a given attribute value of a feature.
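A sketch of such a sub-table helper, under the assumption that it keeps the rows matching one attribute value and drops the feature's own column ( names are assumptions ):

```kotlin
// Returns the rows whose `featureColumnName` equals `attribute`, with the
// splitting column itself removed. Names are illustrative assumptions.
fun getSubTable(
    data: HashMap<String, ArrayList<String>>,
    featureColumnName: String,
    attribute: String
): HashMap<String, ArrayList<String>> {
    val featureValues = data[featureColumnName]!!
    val subTable = HashMap<String, ArrayList<String>>()
    for ((column, values) in data) {
        if (column == featureColumnName) continue  // drop the splitting column
        val filtered = ArrayList<String>()
        for (i in values.indices) {
            if (featureValues[i] == attribute) filtered.add(values[i])
        }
        subTable[column] = filtered
    }
    return subTable
}
```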
Constructing the Tree Recursively
This is the most important part of our algorithm. We define a method createTree which is called recursively until we are left with homogeneous datasets, i.e. datasets with an entropy of 0.
Let us understand the method:
- When the method is called initially, inputTree is null and data contains the whole dataset ( no subsets ).
- We find the highestIGFeatureName and fetch its distinct attribute values. Since inputTree is null, we create a new branch p in our tree for this feature.
- We then iterate through each of the fetched attributes and get a sub-table for that attribute. Since we aren't yet left with homogeneous datasets, the method is called recursively on each sub-table.
- After a number of recursions, the length of the counts array becomes 1. This denotes that the samples present in subTable[ LABEL_COLUMN_NAME ] are all of the same kind, i.e. their entropy is 0. Here we break the recursion and assign p[ attribute ] the value of that single label.
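The recursive construction can be sketched as below. For readability, this sketch builds and returns the map directly instead of threading an inputTree parameter, and repeats the helpers in compact form so it runs standalone; all names are assumptions, and the sketch assumes every branch eventually becomes homogeneous before the features run out.

```kotlin
// A self-contained sketch of the recursive tree construction described above.
const val LABEL_COLUMN_NAME = "label"

fun shannonEntropy(labels: List<String>): Double {
    var e = 0.0
    for (c in labels.groupingBy { it }.eachCount().values) {
        val p = c.toDouble() / labels.size
        e += -p * (Math.log(p) / Math.log(2.0))
    }
    return e
}

// Feature with the highest Information Gain on this (sub-)table.
fun highestIGFeatureName(data: Map<String, ArrayList<String>>): String {
    val labels = data[LABEL_COLUMN_NAME]!!
    var bestFeature = ""
    var bestIG = Double.NEGATIVE_INFINITY
    for (feature in data.keys.filter { it != LABEL_COLUMN_NAME }) {
        val values = data[feature]!!
        var featureEntropy = 0.0
        for (v in values.distinct()) {
            val subset = labels.filterIndexed { i, _ -> values[i] == v }
            featureEntropy += (subset.size.toDouble() / labels.size) * shannonEntropy(subset)
        }
        val ig = shannonEntropy(labels) - featureEntropy
        if (ig > bestIG) { bestIG = ig; bestFeature = feature }
    }
    return bestFeature
}

// Rows matching one attribute value, with the splitting column removed.
fun getSubTable(data: Map<String, ArrayList<String>>, feature: String, attribute: String): HashMap<String, ArrayList<String>> {
    val featureValues = data[feature]!!
    val out = HashMap<String, ArrayList<String>>()
    for ((column, values) in data) {
        if (column == feature) continue
        out[column] = ArrayList(values.filterIndexed { i, _ -> featureValues[i] == attribute })
    }
    return out
}

fun createTree(data: Map<String, ArrayList<String>>): HashMap<String, Any> {
    val feature = highestIGFeatureName(data)
    val branches = HashMap<String, Any>()
    val tree = hashMapOf<String, Any>(feature to branches)
    for (attribute in data[feature]!!.distinct()) {
        val subTable = getSubTable(data, feature, attribute)
        val counts = subTable[LABEL_COLUMN_NAME]!!.distinct()
        if (counts.size == 1) {
            // Homogeneous subset (entropy 0): store the label, stop recursing.
            branches[attribute] = counts[0]
        } else {
            // Mixed subset: recurse on the sub-table.
            branches[attribute] = createTree(subTable)
        }
    }
    return tree
}
```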
The method returns an object of type HashMap<String,Any> which represents our tree.
We employ a similar recursive method to predict the label for a given sample. This method takes in the tree which we produced earlier and returns a String label. It calls itself recursively as long as the object p[ value ] is a HashMap; once it isn't, we have reached the end of a branch, where we find our prediction ( the label ).
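The prediction step can be sketched like this, assuming the tree layout sketched above: a node maps a feature name to a map from attribute values to either a label String or a nested subtree. Names are assumptions, and the sketch assumes every feature value in the sample appears in the tree.

```kotlin
// Walks the tree recursively until the branch value is no longer a HashMap,
// at which point it is the predicted label. Names are illustrative assumptions.
fun predict(tree: HashMap<String, Any>, sample: Map<String, String>): String {
    // Each node has a single key: the feature this node splits on.
    val feature = tree.keys.first()
    val branches = tree[feature] as HashMap<String, Any>
    val value = branches[sample.getValue(feature)]
    return if (value is HashMap<*, *>) {
        // Still a subtree: keep descending.
        @Suppress("UNCHECKED_CAST")
        predict(value as HashMap<String, Any>, sample)
    } else {
        // End of a branch: this is the label.
        value as String
    }
}
```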
Try running the algorithm in both Kotlin and Python. Compare the entropies at each step and you will find that they are equal, which indicates that our algorithm is working fine. Thanks for reading, and Happy Machine Learning 😃!