Loading the required packages
require(ISLR)
## Loading required package: ISLR
require(tree)
## Loading required package: tree
require(ggplot2)
## Loading required package: ggplot2
require(dplyr)
## Loading required package: dplyr
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
Loading the dataset - Carseats
attach(Carseats)
We will plot the distribution of Sales
ggplot(Carseats, aes(x = Sales)) +
geom_histogram(position = "dodge", binwidth = 2) +
theme_classic()
We will create a binary response variable High based on Sales; our goal is to build a model that predicts this variable
Carseats$High <- ifelse(Carseats$Sales > 8, "Yes", "No")
Carseats$High <- as.factor(Carseats$High)
Decision Trees
We will now fit a tree model on the dataset to model the response variable High
tree.fit <- tree(High ~ .-Sales, data = Carseats, y = TRUE)
summary(tree.fit)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats, y = TRUE)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
table(tree.fit$y, Carseats$High)
##
## No Yes
## No 236 0
## Yes 0 164
To get a detailed summary of the tree, we can simply print the tree model object
(tree.fit)
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Plotting the graphical representation of the tree
plot(tree.fit)
text(tree.fit, pretty = 0,cex = 0.6)
***
Estimating the test error rate
We will split the dataset into training and test sets. We will fit the model on the training set and then evaluate its performance on the test set
set.seed(1011)
train_idx <- sample(1:nrow(Carseats), size = 250)
train.tree <- tree(High ~. - Sales, data = Carseats[train_idx,])
plot(train.tree); text(train.tree, pretty = 0, cex = 0.6)
summary(train.tree)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats[train_idx, ])
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "CompPrice" "Advertising"
## [6] "Education" "Income" "US"
## Number of terminal nodes: 23
## Residual mean deviance: 0.3498 = 79.4 / 227
## Misclassification error rate: 0.088 = 22 / 250
Predict the response variable on the test dataset
pred <- predict(train.tree, Carseats[-train_idx,], type = "class")
#We want the class label, so using type = "class"
table(Carseats$High[-train_idx], pred)
## pred
## No Yes
## No 58 27
## Yes 20 45
Misclassification rate
(27+20)/(58+27+20+45)
## [1] 0.3133333
Tree Pruning
We will now use cross-validation to determine the optimal size (complexity) of the tree
#We want to prune the tree that was fully grown using the Training set
cv.tree.model <- cv.tree(train.tree, FUN = prune.misclass)
cv.tree.model
## $size
## [1] 23 17 16 14 10 8 6 5 4 2 1
##
## $dev
## [1] 62 62 60 61 69 81 79 81 81 81 100
##
## $k
## [1] -Inf 0.0 1.0 1.5 2.0 3.0 3.5 5.0 6.0 7.0 27.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
The output above gives, for each tree size, the cross-validated error ($dev, which here is the number of misclassifications because we used FUN = prune.misclass) and the corresponding cost-complexity parameter ($k).
The error drops slightly as the tree is pruned from 23 down to about 16 terminal nodes, and then increases as the tree becomes more and more simple
plot(cv.tree.model)
Based on the above plot, we will select the tree of size 16, i.e. the size corresponding to the minimum cross-validated misclassification error
cv.result <- as.data.frame(cbind(cv.tree.model$size, cv.tree.model$dev))
colnames(cv.result) <- c("size", "dev")
tree_size <- cv.result$size[cv.result$dev == min(cv.result$dev)]
We will now prune the tree grown on the training set down to the selected size of 16 terminal nodes
prune.tree.train <- prune.misclass(train.tree, best = tree_size) # best = number of terminal nodes
plot(prune.tree.train); text(prune.tree.train, pretty = 0, cex = 0.6)
Evaluate it on the test dataset
prune.tree.pred <- predict(prune.tree.train, Carseats[-train_idx,], type = "class")
table(prune.tree.pred, Carseats$High[-train_idx])
##
## prune.tree.pred No Yes
## No 59 18
## Yes 26 47
Misclassification rate
((18+26)/(59+18+26+47))
## [1] 0.2933333
We observe that the misclassification rate of the pruned tree (0.293) is lower than that of the full-grown tree (0.313).
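As a quick check, the two error rates can also be computed directly from the prediction vectors created above, rather than from the confusion-matrix counts:
mean(pred != Carseats$High[-train_idx])             # unpruned tree: 0.3133
mean(prune.tree.pred != Carseats$High[-train_idx])  # pruned tree: 0.2933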
Random Forest and Boosting
We will now illustrate the application of Random Forests and Boosting. The packages used are randomForest and gbm. Here we will use the Boston housing dataset available in the MASS package
require(randomForest)
## Loading required package: randomForest
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
require(gbm)
## Loading required package: gbm
## Loaded gbm 2.1.5
require(MASS)
## Loading required package: MASS
##
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
##
## select
Random Forest
The concept of Random Forest is that we grow many full-size trees (each having high variance and low bias) on bootstrapped samples and then average their predictions, thereby reducing the variance
set.seed(101)
dim(Boston)
## [1] 506 14
train_idx <- sample(nrow(Boston), 300)
We will now fit a random forest to model the response variable medv, the median home value
rf.boston <- randomForest(medv ~ ., data = Boston, subset = train_idx)
rf.boston
##
## Call:
## randomForest(formula = medv ~ ., data = Boston, subset = train_idx)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 4
##
## Mean of squared residuals: 12.68651
## % Var explained: 83.45
From the above summary, we see that 500 bushy trees were built on the training subset. The out-of-bag (OOB) mean squared residual is also displayed; it serves as an honest, approximately unbiased estimate of the prediction error.
The number of variables tried at each split (mtry) is 4 out of the 13 predictor variables.
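A small check, not part of the original workflow: for a regression random forest, calling predict() without newdata returns the out-of-bag predictions, so the reported OOB mean squared residual can be reproduced by hand.
oob.pred <- predict(rf.boston)                        # OOB predictions (no newdata supplied)
mean((Boston[names(oob.pred), "medv"] - oob.pred)^2)  # should match rf.boston$mse[500], the OOB MSE printed above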
Selecting the optimal value of mtry
The only real tuning parameter in a random forest is mtry, i.e. the number of predictors available at each split
oob.err <- double(13)   # OOB MSE for each value of mtry
test.err <- double(13)  # test-set MSE for each value of mtry
for (mtry in 1:13) {
  rf.fit <- randomForest(medv ~ ., data = Boston, subset = train_idx,
                         mtry = mtry, ntree = 400)
  oob.err[mtry] <- rf.fit$mse[400]                    # OOB MSE after 400 trees
  test.fit <- predict(rf.fit, Boston[-train_idx,])
  test.err[mtry] <- with(Boston[-train_idx,], mean((medv - test.fit)^2))
  cat(mtry, " ")
}
## 1 2 3 4 5 6 7 8 9 10 11 12 13
matplot(1:mtry, cbind(test.err, oob.err), pch = 19, col = c("red","blue"), type = "b", ylab = "MSE")
legend("topright", legend = c("Test", "OOB"), pch = 19, col = c("red", "blue"))
We can see that the test error is at its minimum around mtry = 6.
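Rather than eyeballing the plot, the minimizing values can be read off directly (the exact minimizer may vary slightly with the random seed):
which.min(test.err)  # mtry value with the lowest test MSE
which.min(oob.err)   # mtry value with the lowest OOB MSE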
Boosting
Unlike random forests, boosting grows trees sequentially: each tree is fit to a modified version of the training set (the residuals left by the current ensemble) rather than to an independent bootstrapped sample
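To make the idea concrete, here is a toy, hand-rolled sketch of boosting with squared-error loss, using the tree package loaded earlier. The shrinkage lambda, the number of trees B, and the mean initialization are illustrative choices; a real implementation such as gbm also limits the depth of each tree (interaction.depth), which is omitted here for simplicity.
lambda <- 0.01                                       # shrinkage (learning rate)
B <- 500                                             # number of trees in the toy ensemble
train.dat <- Boston[train_idx, ]
f.hat <- rep(mean(train.dat$medv), nrow(train.dat))  # start from the mean prediction
r <- train.dat$medv - f.hat                          # initial residuals
for (b in 1:B) {
  dat.b <- data.frame(train.dat[, setdiff(names(train.dat), "medv")], r = r)
  fit.b <- tree(r ~ ., data = dat.b)                 # fit a regression tree to the current residuals
  pred.b <- predict(fit.b, dat.b)
  f.hat <- f.hat + lambda * pred.b                   # add a shrunken copy of the new tree
  r <- r - lambda * pred.b                           # update the residuals
}
mean((train.dat$medv - f.hat)^2)                     # training MSE of the hand-rolled ensemble
gbm carries out this same residual-fitting loop far more efficiently and with shallow trees, which is what we fit next.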
require(gbm)
boost.boston <- gbm(medv ~.,
data = Boston[train_idx,],
distribution = "gaussian",
n.trees = 10000,
shrinkage = 0.01,
interaction.depth = 4)
summary(boost.boston)
The two most important variables are lstat and rm
We can view the partial dependence plot for these 2 variables.
plot(boost.boston, i = "lstat")
This plot shows that, roughly, the higher the proportion of lower-status residents in a suburb, the lower the median home value.
plot(boost.boston, i = "rm")
This plot is quite intuitive: the higher the average number of rooms, the higher the home value.
There are three main tuning parameters for boosting: the number of trees, the shrinkage (learning rate), and the interaction depth.
We can use cross-validation to determine the optimal number of trees as well as the shrinkage parameter; a brief sketch of this appears at the end of this section.
Here, we will look at the test performance as a function of the number of trees.
n.trees <- seq(from = 100, to = 10000, by = 100)
test.pred <- predict(boost.boston, newdata = Boston[-train_idx,], n.trees = n.trees)
dim(test.pred)
## [1] 206 100
#For each of the 206 observations, 100 different prediction values are produced
test.err <- with(Boston[-train_idx,], apply((test.pred - medv)^2, 2, mean ))
#test.err contains 100 MSE values, one for each value of n.trees
plot(n.trees, test.err,
pch = 19, type = "b",
ylab = "MSE", xlab = "Number of trees (n.trees)",
main = "Boosting test error vs number of trees in the boosted ensemble")
The plot shows the test error as a function of the number of trees. The error seems to level off beyond n.trees = 2000 and increases slightly near n.trees = 9000, which may indicate mild overfitting.
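As noted earlier, cross-validation can also be used to choose the number of trees (and, by refitting over a small grid, the shrinkage). A minimal sketch using gbm's built-in cross-validation, assuming 5 folds (this refits the model several times, so it takes noticeably longer):
set.seed(101)
cv.boost <- gbm(medv ~ ., data = Boston[train_idx,],
distribution = "gaussian", n.trees = 10000,
shrinkage = 0.01, interaction.depth = 4,
cv.folds = 5)
best.n.trees <- gbm.perf(cv.boost, method = "cv") # CV estimate of the optimal number of trees
best.n.trees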