Decision Trees are one of the most loved 😘 classification algorithms in the world of Machine Learning. They are used for both regression and classification. The most fundamental idea behind a decision tree is to first find a root node which divides our dataset into homogeneous subsets, and to repeat this until we are left with samples belonging to the same class ( 100% homogeneity ).

With Python packages like Scikit-Learn, they are easy to build and run in a couple of lines of code,

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.DataFrame(iris.target, columns=["Species"])

tree = DecisionTreeClassifier(max_depth=2)
tree.fit(X, y)
```

This makes our life easier, for those of us working with Python 😎 at least. But decision trees have a number of practical applications, and thereby they need to be implemented on various platforms, like Android! In this story, we'll be implementing a decision tree ( for classification ) in Kotlin ( or Java ).

The Python implementation of this decision tree is adapted from the stories below by Rakend Dubba, and I highly recommend them.

The sample data is also inherited from the stories above.

The code is available on GitHub with a sample APK file -> https://github.com/shubham0204/Decision_Tree_Android

First, as we all know, Kotlin doesn't have a nice package like Pandas 🐼 to hold our dataset. We can create a simple `DataFrame` class which holds our data internally in the form of a `HashMap<String,ArrayList<String>>`, where the key represents the name of a column and the `ArrayList<String>` holds the data for each sample in the form of a `String`. For simplicity, we keep all our features as well as labels in the form of Strings.

Later on, we’ll parse data from this class while building our tree.
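A minimal sketch of such a class could look like this ( the method names here are illustrative; the repo's own class may differ ):

```kotlin
// A bare-bones DataFrame: column name -> list of String values.
class DataFrame {

    private val columns = HashMap<String, ArrayList<String>>()

    // Append one sample's value to the given column.
    fun addEntry(columnName: String, value: String) {
        columns.getOrPut(columnName) { ArrayList() }.add(value)
    }

    // Fetch a whole column by its name.
    fun getColumn(columnName: String): ArrayList<String> =
        columns[columnName] ?: ArrayList()

    // Names of all columns ( features + label ).
    fun getColumnNames(): Set<String> = columns.keys
}
```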

Next, we build some internal methods. These methods facilitate calculations and are not available out of the box in Kotlin. The first one is `argmax()`, which returns the index of the greatest value present in an array. The second one, `getFreq()`, returns the count or frequency of each element present in an array. The third one, `logbase2()`, returns the logarithm with base 2 of the given number.
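Here are hedged sketches of these three helpers ( the names come from the story; the bodies are my own ):

```kotlin
import kotlin.math.ln

// Index of the greatest value present in the array.
fun argmax(values: FloatArray): Int {
    var maxIndex = 0
    for (i in values.indices) {
        if (values[i] > values[maxIndex]) {
            maxIndex = i
        }
    }
    return maxIndex
}

// Count ( frequency ) of each distinct element in the list.
fun getFreq(values: List<String>): Map<String, Int> =
    values.groupingBy { it }.eachCount()

// Logarithm with base 2 of the given number.
fun logbase2(x: Double): Double = ln(x) / ln(2.0)
```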

Now, in order to get the total entropy at the root node ( labels ), we use Shannon's Entropy,
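For $C$ classes, where $p_i$ is the fraction of samples belonging to class $i$:

$$E_{label} = - \sum_{i=1}^{C} p_i \, \log_2( p_i )$$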

The above entropy will be used for the calculation of the Information Gain,
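$$IG( feature ) = E_{label} - E_{feature}$$

Here, $E_{feature}$ is the weighted entropy that remains after splitting the dataset on that feature; the feature with the highest IG makes the best split.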

Next comes the Kotlin method for calculating the root entropy E_label.
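A minimal sketch of that method, assuming the `DataFrame` class and helpers above ( `LABEL_COLUMN_NAME` is a constant holding the label column's name; "Species" matches the iris example ):

```kotlin
// The label column's name; "Species" follows the iris example above.
const val LABEL_COLUMN_NAME = "Species"

// Shannon entropy of the label column ( E_label ).
fun labelEntropy(data: DataFrame): Double {
    val labels = data.getColumn(LABEL_COLUMN_NAME)
    var entropy = 0.0
    for ((_, count) in getFreq(labels)) {
        val p = count.toDouble() / labels.size
        entropy += -p * logbase2(p)
    }
    return entropy
}
```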

Next, we need to calculate E_feature for a specified feature. We would feed a `featureName` and `data` to the method, and it will return the entropy for that feature. A sketch of it follows, after which I'll explain it so as to give an intuition of what's going on.
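A sketch of such a method, again a reconstruction rather than the repo's exact code:

```kotlin
// Weighted entropy E_feature for a given feature column.
fun featureEntropy(data: DataFrame, featureColumnName: String): Double {
    val labels = data.getColumn(LABEL_COLUMN_NAME)
    val featureValues = data.getColumn(featureColumnName)
    var featureEntropy = 0.0
    // Loop over the distinct ( sorted ) values of the feature.
    for (featureValue in featureValues.distinct().sorted()) {
        val valueCount = featureValues.count { it == featureValue }
        var entropy = 0.0
        for (label in labels.distinct()) {
            // numCount = samples having both this featureValue and this label.
            val numCount = featureValues.indices.count {
                featureValues[it] == featureValue && labels[it] == label
            }
            if (numCount > 0) {
                val p = numCount.toDouble() / valueCount
                entropy += -p * logbase2(p)
            }
        }
        // Weight by the fraction of samples carrying this feature value.
        featureEntropy += (valueCount.toDouble() / labels.size) * entropy
    }
    return featureEntropy
}
```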

The `labels` variable carries the labels of all samples. `featureValues` gets the whole column named `featureColumnName` from the dataset. We loop through each of the `featureValues` ( sorted, with no repetition of values! ). We define a variable inside the loop named `entropy`. This `entropy` is calculated by,
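In symbols, for a feature value $v$ carried by $N_v$ samples:

$$entropy_v = - \sum_{label} \frac{numCount}{N_v} \, \log_2\left( \frac{numCount}{N_v} \right)$$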

The `numCount` variable refers to the number of samples which carry both that `featureValue` and the corresponding `label`. These entropies, weighted by the fraction of samples carrying each `featureValue`, are summed up into another variable, `featureEntropy`.

Next, we need to find the feature which gives us the highest value of Information Gain. The method is given a `data` object; it uses the `labelEntropy` and the per-feature `featureEntropy` to calculate the IG score of every feature and stores the scores in an array. We use our `argmax` function to get the index of the greatest IG score, and we finally return the corresponding feature name.
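A sketch of such a method, reusing the pieces above:

```kotlin
// Name of the feature which yields the highest Information Gain.
fun highestIGFeatureName(data: DataFrame): String {
    val featureNames = data.getColumnNames().filter { it != LABEL_COLUMN_NAME }
    val igScores = FloatArray(featureNames.size)
    val rootEntropy = labelEntropy(data)
    for ((i, featureName) in featureNames.withIndex()) {
        // IG( feature ) = E_label - E_feature
        igScores[i] = (rootEntropy - featureEntropy(data, featureName)).toFloat()
    }
    // argmax gives the index of the greatest IG score.
    return featureNames[argmax(igScores)]
}
```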

Another short method returns a `HashMap` which contains a sub-table selected by `featureName` and `featureValue`, i.e., the subset of samples carrying a given `featureValue`.
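A possible shape for that method:

```kotlin
// Sub-table: all samples whose `featureName` column equals `featureValue`,
// returned as a HashMap of column name -> column values.
fun getSubTable(
    data: DataFrame,
    featureName: String,
    featureValue: String
): HashMap<String, ArrayList<String>> {
    val subTable = HashMap<String, ArrayList<String>>()
    val featureColumn = data.getColumn(featureName)
    for (i in featureColumn.indices) {
        if (featureColumn[i] == featureValue) {
            for (columnName in data.getColumnNames()) {
                subTable.getOrPut(columnName) { ArrayList() }
                    .add(data.getColumn(columnName)[i])
            }
        }
    }
    return subTable
}
```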

## Constructing the Tree Recursively

This is the most important part of our algorithm. Now, we define a method `createTree` which will be called recursively till we are left with homogeneous datasets, i.e., with an entropy of 0. Let us understand how this method works ( a sketch of it follows the list below ).

- When the method is called initially, `inputTree` is null and `data` contains the whole dataset ( no subsets ).
- We find the `highestIGFeatureName` and fetch the distinct attributes ( values ) of that feature. Since our `inputTree` is null, we create a branch in our tree.
- The attributes which we fetched will now be used. We iterate through each of them and get a sub-table for that `attribute`. Following this, we create a node `p` in our tree. Since we aren't left with homogeneous datasets yet, the method is called recursively.
- After a number of recursions, the length of the `counts` array will become 1. This denotes that the samples present in `subTable[ LABEL_COLUMN_NAME ]` are of the same kind, i.e., their entropy is 0. Here, we break the recursion. Also, we assign `p[ attribute ]` the value of `clValue[0]`.
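A hedged sketch of `createTree` matching the behaviour described above ( `toDataFrame` is a small helper I've added so we can recurse on a sub-table ):

```kotlin
// Wrap a sub-table back into a DataFrame so we can recurse on it.
fun toDataFrame(table: HashMap<String, ArrayList<String>>): DataFrame {
    val df = DataFrame()
    for ((columnName, values) in table) {
        for (value in values) {
            df.addEntry(columnName, value)
        }
    }
    return df
}

// Recursively build the tree. A node maps a feature name to a branch `p`;
// `p` maps each attribute ( feature value ) to a label ( String ) or to a
// deeper HashMap.
fun createTree(data: DataFrame, inputTree: HashMap<String, Any>?): HashMap<String, Any> {
    val featureName = highestIGFeatureName(data)
    val attributes = data.getColumn(featureName).distinct()

    // On the first call, inputTree is null -> start a fresh tree.
    val tree: HashMap<String, Any> = inputTree ?: HashMap()
    val p = HashMap<String, Any>()
    tree[featureName] = p

    for (attribute in attributes) {
        val subTable = getSubTable(data, featureName, attribute)
        val labels = subTable[LABEL_COLUMN_NAME] ?: ArrayList()
        val clValue = labels.distinct()
        val counts = getFreq(labels)
        if (counts.size == 1) {
            // Homogeneous subset ( entropy 0 ) -> leaf; break the recursion.
            p[attribute] = clValue[0]
        } else {
            // Mixed subset -> recurse on the sub-table.
            p[attribute] = createTree(toDataFrame(subTable), null)
        }
    }
    return tree
}
```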

The above method will return an object of type `HashMap<String,Any>`, which represents our tree.

We employ a similar method to predict a label for a given sample. This method takes in the tree ( which `createTree` produced above ) and returns a `String` label. This method is recursive too.
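A sketch of this prediction method:

```kotlin
// Recursively descend the tree until we hit a String label.
@Suppress("UNCHECKED_CAST")
fun predict(sample: HashMap<String, String>, tree: HashMap<String, Any>): String {
    // Each node holds exactly one feature name at its root.
    val featureName = tree.keys.first()
    val value = sample[featureName]!!
    val p = tree[featureName] as HashMap<String, Any>
    val next = p[value]
    return if (next is HashMap<*, *>) {
        // p[ value ] is still a HashMap -> keep descending the branch.
        predict(sample, next as HashMap<String, Any>)
    } else {
        // End of a branch -> this is our prediction ( label ).
        next as String
    }
}
```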

The method is called recursively until the object `p[ value ]` is no longer a `HashMap`. This means that we have reached the end of a branch, where we find our prediction ( label ).

Credit: BecomingHuman. By: Shubham Panchal