Random Forests

Introduction

Now that we have an idea of what decision trees are and how exactly they work, we can go a step further and try to improve our decision tree models by introducing a very simple but very effective extension of decision trees, popularly known as the “Random Forest”. To understand decision trees in detail, you can refer to the following three links:

Coming back to this article, a random forest is basically a collection of decision trees that together give the final output. Like decision trees, random forest is a supervised learning algorithm that can be used for both regression and classification problems. To get a prediction from a random forest, we use the output of each of the trees, which we commonly call “votes”; the final output is the one with the most votes.

Algorithm for Random Forest

The random forest algorithm is similar to the decision tree algorithm (the CART model); the only difference is that we build multiple decision trees at random. To understand this, let’s take an example where we have 5000 observations with 10 predictors. The random forest algorithm builds multiple decision trees from the training data, each using a random sample of observations and a random subset of predictors. For example, the first tree might take 1000 random observations with only 7 random predictors; the next tree might take 3000 random observations with another randomly selected set of predictors, these numbers themselves being chosen at random. The algorithm continues to create as many of these trees as we want, and each tree then makes its own prediction.

For regression, the predicted value is simply the mean of the predictions from each of the trees, while for classification we use the concept of votes: each tree in the forest predicts a class for the observation, in other words it casts a vote for one of the classes, and the class with the most votes is the final output.
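
To make the voting idea concrete, here is a minimal toy sketch (not the randomForest package itself; the simulated data, column names, and object names are invented purely for illustration). It grows a handful of rpart trees, each on a random subset of rows and predictors, and combines their class predictions by majority vote.

## Toy illustration: a "forest" of rpart trees built on random subsets,
## combined by majority vote (simulated data, for illustration only)
library(rpart)

set.seed(1)
toy   <- data.frame(x1 = rnorm(200), x2 = rnorm(200), x3 = rnorm(200))
toy$y <- factor(ifelse(toy$x1 + toy$x2 + rnorm(200, sd = 0.5) > 0, "yes", "no"))

grow_tree <- function(df) {
  rows  <- sample(nrow(df), size = 150, replace = TRUE)  # random observations
  preds <- sample(c("x1", "x2", "x3"), size = 2)         # random predictors
  rpart(y ~ ., data = df[rows, c(preds, "y")])
}

trees <- lapply(1:5, function(i) grow_tree(toy))

## Each tree casts a "vote"; the class with the most votes is the prediction
votes    <- sapply(trees, function(t) as.character(predict(t, toy, type = "class")))
majority <- apply(votes, 1, function(v) names(which.max(table(v))))
head(majority)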

How each tree grows in a random forest:

  • Suppose we have a dataset with ‘N’ observations and ‘P’ predictors; each tree then chooses ‘n’ observations and ‘p’ predictors at random, where n < N and p < P.
  • If there are ‘M’ input variables, a number m << M is specified such that at each node, ‘m’ variables are selected at random out of the ‘M’ and the best split on these ‘m’ is used to split the node. The value of ‘m’ is held constant while the forest grows (see the sketch after this list).
  • There is no pruning in a random forest; each tree is allowed to grow to its maximum possible depth.
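
These choices map directly onto arguments of randomForest() in the R package used later in this post: ‘sampsize’ and ‘replace’ control the n observations each tree sees, ‘mtry’ is the number m of predictors tried at each split, and trees are grown unpruned by default. Below is a minimal sketch on R’s built-in iris data; the values are purely illustrative.

## Illustrative only: iris has N = 150 observations and M = 4 predictors
library(randomForest)

set.seed(1)
rf_sketch <- randomForest(Species ~ .,
                          data     = iris,
                          ntree    = 100,   # number of trees in the forest
                          mtry     = 2,     # m of the M = 4 predictors tried at each split
                          sampsize = 100,   # n of the N = 150 observations per tree
                          replace  = TRUE)  # bootstrap sampling with replacement
rf_sketch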

Error rate in Random Forest

The error rate for the random forest algorithm depends mostly on two things:

  • The correlation between the trees in the forest: the higher the correlation between any two trees, the higher the forest error rate.
  • The strength of each individual tree: a tree with a low error rate is a strong classifier, and increasing the strength of the individual trees lowers the forest error rate (see the sketch after this list).
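
Reducing ‘m’ (mtry) lowers the correlation between trees but also weakens each individual tree, so the forest error rate is a trade-off between the two. One way to explore this trade-off is tuneRF() from the randomForest package, sketched below on the built-in iris data; the parameter values are illustrative, not recommendations.

## Search over mtry: smaller mtry -> less correlated but weaker trees,
## larger mtry -> stronger but more correlated trees. tuneRF() reports
## the out-of-bag (OOB) error for each mtry it tries.
library(randomForest)

set.seed(1)
tuneRF(x = iris[, 1:4], y = iris$Species,
       mtryStart  = 2,     # starting number of predictors per split
       ntreeTry   = 100,   # trees grown for each candidate mtry
       stepFactor = 2,     # multiply/divide mtry by this factor at each step
       improve    = 0.01)  # minimum relative OOB improvement to keep searching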

Random Forest features

  • It generally gives high accuracy compared with many other commonly used algorithms.
  • It can handle large datasets efficiently.
  • It takes all of the predictor variables into account (predictors are selected at random for each tree).
  • It reports which variables in the dataset are important (variable importance).
  • We can visualize the forest error rate as the forest grows and decide on the size of the forest accordingly.

Random Forest example in R

In the following example I’ve used the UCI credit card default dataset, which can be downloaded here: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients

# Loading the required libraries
library(caTools)       # sample.split() for the train/test split
library(caret)         # varImp() and confusionMatrix()
library(e1071)         # support functions used by caret
library(randomForest)  # the random forest implementation

Loading the dataset and removing rows where the target value is missing

x <- read.csv("E:\\Abhay MBA Docs\\TERM 5\\Subjects\\Analytics Practicum\\UCI_Credit_Card.csv", sep = ",")

## Remove rows that do not have target variable values
final <- x[!(is.na(x$default.payment.next.month)),]

Splitting the dataset in the ratio 75:25 (train:test)

set.seed(88)
split <- sample.split(final$default.payment.next.month, SplitRatio = 0.75)

dresstrain <- subset(final, split == TRUE)
dresstest <- subset(final, split == FALSE)

Let’s convert the outcome variable to a factor and then build a random forest with 100 trees.

dresstrain$default.payment.next.month <- as.factor(dresstrain$default.payment.next.month)
dresstest$default.payment.next.month <- as.factor(dresstest$default.payment.next.month)

rf = randomForest(default.payment.next.month~.,  
                   ntree = 100,
                   data = dresstrain)
## Plot the error rate of the random forest as we increase the number of trees
plot(rf) 

Forest Error rate

We see that beyond roughly 37-38 trees, the error rate stabilizes. So in the next iteration we can reduce the number of trees to 40, as the error rate will be about the same for a random forest with 100 trees as for one with 40 trees.
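
As a sketch of that next iteration (the cutoff of 40 trees is simply read off the plot above):

## Refit the forest with 40 trees, since the error curve has already flattened out
rf_small <- randomForest(default.payment.next.month ~ .,
                         ntree = 40,
                         data  = dresstrain)
rf_small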

## Let's see the importance of variables in the model in a tabular form
varImp(rf)
##             Overall
## ID        511.20584
## LIMIT_BAL 372.63264
## SEX        74.76610
## EDUCATION 132.16961
## MARRIAGE   91.27863
## AGE       401.43206
## PAY_0     717.82297
## PAY_2     355.23851
## PAY_3     224.47560
## PAY_4     163.23384
## PAY_5     143.28872
## PAY_6     164.07542
## BILL_AMT1 427.82303
## BILL_AMT2 387.65018
## BILL_AMT3 367.98343
## BILL_AMT4 361.55032
## BILL_AMT5 349.59847
## BILL_AMT6 353.40670
## PAY_AMT1  370.41753
## PAY_AMT2  343.84676
## PAY_AMT3  329.83085
## PAY_AMT4  302.63583
## PAY_AMT5  310.36254
## PAY_AMT6  329.43405

We can also plot these in a graph using the code below:

## Important variables according to the model
varImpPlot(rf,  
           sort = T,
           n.var=25,
           main="Variable Importance")

Variable importance graph
Let us now predict the values in the test set and create a confusion matrix to check the model’s performance.

predicted.response <- predict(rf, dresstest)


confusionMatrix(data=predicted.response,  
                reference=dresstest$default.payment.next.month)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 5497 1025
##          1  344  634
##                                           
##                Accuracy : 0.8175          
##                  95% CI : (0.8085, 0.8262)
##     No Information Rate : 0.7788          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.379           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9411          
##             Specificity : 0.3822          
##          Pos Pred Value : 0.8428          
##          Neg Pred Value : 0.6483          
##              Prevalence : 0.7788          
##          Detection Rate : 0.7329          
##    Detection Prevalence : 0.8696          
##       Balanced Accuracy : 0.6616          
##                                           
##        'Positive' Class : 0               
## 

We can try to improve the model further by building a random forest with fewer trees and keeping only the most important variables.
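
As a rough sketch of that idea (picking, say, the top 10 variables by importance and 40 trees; these numbers are illustrative, not tuned):

## Pick the most important predictors from the fitted forest and refit a
## smaller forest on just those columns
imp      <- importance(rf)                                  # MeanDecreaseGini per variable
top_vars <- names(sort(imp[, 1], decreasing = TRUE))[1:10]  # top 10 predictors

rf_reduced <- randomForest(default.payment.next.month ~ .,
                           data  = dresstrain[, c(top_vars, "default.payment.next.month")],
                           ntree = 40)

## Evaluate on the test set as before
pred_reduced <- predict(rf_reduced, dresstest)
confusionMatrix(data = pred_reduced,
                reference = dresstest$default.payment.next.month)

Hope that helps!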
