Neural Network Vignette

Jeremy Werner

2019-04-07

Neural networks are all the rage these days. After reading a book on how to use gradient descent to fit a simple neural network via backpropagation, I decided to code it up by hand with the ultimate goal of having it classify some MNIST digits.

Example 1: Iris

library(ArtisanalMachineLearning)
data(iris)

Preprocess the data into a binary response dataset. Also change the response to be {-1, 1} due to the nature of the tanh function.

iris = iris[1:100,]
iris$Species = as.numeric(iris$Species) - 2
iris$Species[iris$Species == 0] = 1
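
As a quick sanity check on why {-1, 1} is the natural coding here, the sketch below shows that tanh squashes any input into (-1, 1) and repeats the recoding on a fresh copy of the built-in data. The names `iris_binary` and `labels` are made up for this illustration and are not used elsewhere in the vignette.

```r
# tanh maps any real number into the open interval (-1, 1)
inputs = c(-10, -1, 0, 1, 10)
squashed = tanh(inputs)
round(squashed, 4)

# Same recoding as above, applied to an untouched copy of the data
iris_binary = datasets::iris[1:100, ]
labels = as.numeric(iris_binary$Species) - 2   # setosa -> -1, versicolor -> 0
labels[labels == 0] = 1                        # versicolor -> 1

# Cross-tabulate species against the recoded response
table(droplevels(iris_binary$Species), labels)
```

After the recoding, setosa is the -1 class and versicolor is the 1 class, with 50 rows each.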

Split the data into training and validation sets.

validation_set = rbind(iris[1:10,], iris[91:100,])
validation_set$Species = NULL
validation_response = c(iris$Species[1:10], iris$Species[91:100])

data = iris[11:90,]
data$Species = NULL
response = iris$Species[11:90]

Train a network.

network = aml_neural_network(sizes=c(4, 15, 15, 4, 1),
                             learning_rate=.01, 
                             data=data, 
                             response=response, 
                             epochs=100, 
                             verbose=FALSE)
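
The internals of aml_neural_network are not shown in this vignette. As a rough illustration of what fitting a tanh network with gradient descent and backpropagation involves, here is a minimal sketch with one hidden layer, a squared-error loss, and full-batch updates. The hidden width, step size, epoch count, seed, and feature scaling are all assumptions picked for the example; this is not the package's implementation.

```r
# Minimal one-hidden-layer tanh network; NOT the ArtisanalMachineLearning code.
set.seed(1)  # assumed seed, only so the sketch is repeatable

# Use datasets::iris so the recoded copy in the workspace is left alone
features = scale(as.matrix(datasets::iris[1:100, 1:4]))   # standardised inputs
target = ifelse(datasets::iris$Species[1:100] == "setosa", -1, 1)

n_hidden = 5       # assumed hidden layer width
step_size = 0.1    # assumed learning rate
w1 = matrix(rnorm(4 * n_hidden, sd = 0.1), nrow = 4)   # input -> hidden weights
b1 = rep(0, n_hidden)
w2 = matrix(rnorm(n_hidden, sd = 0.1), ncol = 1)       # hidden -> output weights
b2 = 0

for (epoch in 1:500) {
    # Forward pass
    hidden = tanh(sweep(features %*% w1, 2, b1, "+"))
    output = tanh(hidden %*% w2 + b2)

    # Backward pass for loss = mean((output - target)^2) / 2,
    # using d tanh(z) / dz = 1 - tanh(z)^2
    delta_output = (output - target) * (1 - output^2)
    delta_hidden = (delta_output %*% t(w2)) * (1 - hidden^2)

    # Gradient descent step, with gradients averaged over the batch
    w2 = w2 - step_size * crossprod(hidden, delta_output) / nrow(features)
    b2 = b2 - step_size * mean(delta_output)
    w1 = w1 - step_size * crossprod(features, delta_hidden) / nrow(features)
    b1 = b1 - step_size * colMeans(delta_hidden)
}

# Share of rows whose output falls on the correct side of 0
sketch_accuracy = mean((output < 0) == (target < 0))
sketch_accuracy
```

The two delta lines are the whole of backpropagation for this architecture: the output error is pushed back through w2 and scaled by the tanh derivative at each layer.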

Inspect training and validation accuracy by naively classifying all predictions less than 0 to the -1 level and 1 otherwise.

preds = predict(network, data)
training_accuracy = sum((preds < 0) == (response < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 1"
validation = predict(network, validation_set)
validation_accuracy = sum((validation < 0) == (validation_response < 0)) / length(validation)
print(paste("Validation accuracy:", validation_accuracy))
## [1] "Validation accuracy: 1"
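
The cutoff rule used above can be wrapped in a small helper. This is a sketch on made-up numbers; `threshold_accuracy` is a hypothetical name and not a function from ArtisanalMachineLearning.

```r
# Hypothetical helper: classify at a cutoff and report the share of correct calls
threshold_accuracy = function(predictions, truth, cutoff = 0) {
    predicted_class = ifelse(predictions < cutoff, -1, 1)
    mean(predicted_class == truth)
}

# Made-up predictions and labels, purely for illustration
toy_predictions = c(-0.9, -0.2, 0.1, 0.8, -0.4)
toy_truth = c(-1, -1, 1, 1, 1)

# The last prediction falls on the wrong side of 0, so 4 of 5 are correct
toy_accuracy = threshold_accuracy(toy_predictions, toy_truth)
toy_accuracy
```

Note that the rule ignores how far a prediction sits from 0, so a prediction of -0.01 counts the same as one of -0.99.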

Conclusion

Either I have written the smartest algorithm of all time, or this is a very small, highly separable data set and the network picked up on that immediately (see the k-means vignette for proof of this).

Example 2: MNIST

Let’s read in the MNIST handwritten digits data and define a few helper functions. Code was largely robbed from here.

library(readr)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(tidyr)
library(ggplot2)
library(viridis)

process_digit_data <- function(data_row){
    data_row %>%
        gather(pixel, value) %>%
        tidyr::extract(pixel, "pixel", "(\\d+)", convert = TRUE) %>%
        mutate(pixel = pixel - 2,
               x = pixel %% 28,
               y = 28 - pixel %/% 28)
}

plot_digit <- function(data_row){
    image_df = process_digit_data(data_row)
    image = ggplot(image_df, aes(x = x, y = y, fill = value)) + 
                geom_raster() + 
                coord_equal() +
                scale_fill_viridis()
    image
}

mnist_raw <- read_csv("https://pjreddie.com/media/files/mnist_train.csv", 
                      col_names = FALSE,
                      col_types = cols())

response = unlist(mnist_raw[,1])

data = mnist_raw
data[,1] = NULL
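
The index arithmetic in process_digit_data is easier to follow on a few concrete values. The sketch below assumes the pixel columns are named X2 through X785 (the label having been X1), which is what read_csv produces with col_names = FALSE; the vector names are made up for the example.

```r
# Column numbers taken from names such as "X2", "X29", "X30", "X785"
column_numbers = c(2, 29, 30, 785)

pixel = column_numbers - 2    # zero-based pixel index: 0, 27, 28, 783
x = pixel %% 28               # column within the 28 x 28 grid: 0, 27, 0, 27
y = 28 - pixel %/% 28         # row, flipped so the first image row plots on top

data.frame(column_numbers, pixel, x, y)
```

The first 28 pixels share y = 28 and the last pixel lands at x = 27, y = 1, so the image is drawn upright rather than mirrored vertically.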

Restrict the data to the digits 0 and 7 and transform the response to {-1, 1} due to the nature of the tanh function. Also split the result into training and test sets.

digits_to_keep = c(0,7)
data_sub = data[response %in% digits_to_keep,]
response_sub = response[response %in% digits_to_keep]
response_transformed = ifelse(response_sub == 0, -1, 1)

training_indicators = sample(1:length(response_transformed), 10000)
test_indicator = !(1:nrow(data_sub) %in% training_indicators)

training_data = data_sub[training_indicators,]
training_response = response_transformed[training_indicators]

test_data = data_sub[test_indicator,]
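
The split above is not seeded, so the rows drawn (and the numbers reported below) change from run to run. A minimal sketch of the same index-and-mask pattern with a seed, on a made-up size of 20 rows, looks like this; the seed value and variable names are assumptions for the illustration.

```r
set.seed(42)   # assumed seed; the vignette itself does not set one

n_rows = 20
train_rows = sample(seq_len(n_rows), 12)          # row numbers used for training
test_mask = !(seq_len(n_rows) %in% train_rows)    # TRUE for every remaining row

# Both pieces must be indexed consistently: numbers for training, the mask for test
toy_response = rep(c(-1, 1), length.out = n_rows)
train_response = toy_response[train_rows]
test_response = toy_response[test_mask]

c(train = length(train_response), test = length(test_response))
```

Because the mask is the complement of the sampled row numbers, no row can appear in both sets.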

If you’re not familiar with MNIST, it’s a data set of handwritten digit images. Here are a few examples.

plot_digit(training_data[sample(1:nrow(training_data), 1),])

plot_digit(training_data[sample(1:nrow(training_data), 1),])

plot_digit(training_data[sample(1:nrow(training_data), 1),])

plot_digit(training_data[sample(1:nrow(training_data), 1),])

plot_digit(training_data[sample(1:nrow(training_data), 1),])

The average 0 image and the average 7 image in the training data look like the following.

# Plot average 0 image
averages_frame = as.data.frame(t(apply((training_data[training_response == -1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)

# Plot average 7 image
averages_frame = as.data.frame(t(apply((training_data[training_response == 1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)

Now, let’s train a small neural network and inspect it. Warning! This takes some time to run.

small_network = aml_neural_network(sizes=c(784, 50, 25, 1),
                                   learning_rate=.01,
                                   data=training_data,
                                   response=training_response,
                                   epochs=10,
                                   verbose=FALSE)

Accuracy rates using a naive classification split at 0

# Training accuracy (share of correct classifications, not an error rate)
preds = predict(small_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 0.8011"
# Test accuracy
test_preds = predict(small_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
## [1] "Testing accuracy: 0.797531992687386"
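
A single rate hides which digit is being misclassified. A confusion table splits the result by true class; the sketch below uses made-up predictions rather than output from the network, and all names in it are hypothetical.

```r
# Made-up labels and predictions: -1 stands for the digit 0, 1 for the digit 7
toy_truth = c(-1, -1, -1, 1, 1, 1, 1, 1)
toy_preds = c(-0.7, 0.2, -0.1, 0.9, 0.4, -0.3, 0.6, 0.8)

# Apply the same cutoff at 0 and label both vectors by digit
actual = ifelse(toy_truth < 0, "0", "7")
predicted = ifelse(toy_preds < 0, "0", "7")

# Rows are the true digit, columns the predicted digit
confusion = table(actual = actual, predicted = predicted)
confusion

# Correct calls sit on the diagonal
toy_rate = sum(diag(confusion)) / sum(confusion)
toy_rate
```

With these toy values one 0 is called a 7 and one 7 is called a 0, so 6 of the 8 calls are correct.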

Train a big neural network and inspect it. Warning! This takes a long time to run.

big_network = aml_neural_network(sizes=c(784, 100, 100, 50, 25, 1),
                                 learning_rate=.01,
                                 data=training_data,
                                 response=training_response,
                                 epochs=100,
                                 verbose=FALSE)
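
To get a feel for why the run time grows, it helps to count how many parameters each sizes vector implies. The sketch below assumes fully connected layers with one bias per non-input unit; whether aml_neural_network actually uses bias terms is not stated in this vignette, and `count_parameters` is a hypothetical helper.

```r
# Hypothetical helper: weights and biases implied by a vector of layer sizes,
# assuming fully connected layers with one bias per non-input unit
count_parameters = function(sizes) {
    inputs = head(sizes, -1)     # units feeding each layer
    outputs = tail(sizes, -1)    # units in each non-input layer
    weights = sum(inputs * outputs)
    biases = sum(outputs)
    c(weights = weights, biases = biases, total = weights + biases)
}

iris_count = count_parameters(c(4, 15, 15, 4, 1))
small_count = count_parameters(c(784, 50, 25, 1))
big_count = count_parameters(c(784, 100, 100, 50, 25, 1))

rbind(iris = iris_count, small = small_count, big = big_count)
```

Under that assumption the iris network has 384 parameters, the small MNIST network 40,551, and the big one 94,951, almost all of them in the first layer that reads the 784 pixels.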

Accuracy rates using a naive classification split at 0

# Training accuracy (share of correct classifications, not an error rate)
preds = predict(big_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 0.7282"
# Test accuracy
test_preds = predict(big_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
## [1] "Testing accuracy: 0.744515539305302"

Investigation of Classification Extremes

Four equally uncertain random images from training data

median_obs = which(preds == median(preds)) # Many observations are equal to the median
plot_digit(training_data[sample(median_obs, 1),])

plot_digit(training_data[sample(median_obs, 1),])

plot_digit(training_data[sample(median_obs, 1),])

plot_digit(training_data[sample(median_obs, 1),])

Four equally uncertain random images from testing data

median_obs = which(test_preds == median(test_preds)) # Many observations are equal to the median
plot_digit(test_data[sample(median_obs, 1),])

plot_digit(test_data[sample(median_obs, 1),])

plot_digit(test_data[sample(median_obs, 1),])

plot_digit(test_data[sample(median_obs, 1),])
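
Matching on the exact median only works when several predictions tie at that value, and sample(x, 1) quietly draws from 1:x when x is a single number. One hedged alternative is to rank predictions by their distance from the 0 cutoff. The sketch below uses made-up predictions, and `pick_one` is a hypothetical helper.

```r
# Made-up predictions in (-1, 1); not output from the networks above
toy_preds = c(-0.95, 0.02, 0.60, -0.01, 0.88, 0.30)

# Indices of the two predictions closest to the 0 cutoff (least certain)
least_certain = order(abs(toy_preds))[1:2]

# Indices of the most confident prediction in each direction
most_negative = which.min(toy_preds)
most_positive = which.max(toy_preds)

# Draw one index safely, even when there is only a single candidate
pick_one = function(candidates) candidates[sample.int(length(candidates), 1)]
picked = pick_one(least_certain)

list(least_certain = least_certain,
     most_negative = most_negative,
     most_positive = most_positive)
```

Here the least certain predictions are the fourth and second entries, while the first and fifth are the most confident calls for -1 and 1.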

Most certain image classified as 0 in training set

plot_digit(training_data[which.min(preds),])

Most certain image classified as 7 in training set

plot_digit(training_data[which.max(preds),])

Most certain image classified as 0 in testing set

plot_digit(test_data[which.min(test_preds),])

Most certain image classified as 7 in testing set

plot_digit(test_data[which.max(test_preds),])

Conclusion

Image processing with neural networks is very hard. This very quick analysis is clearly picking up on some signal, but you can see the network struggle with digits that were written thicker than others, digits that were not centered in the grid, or a 7 written with a cross through the middle. Training larger networks and/or training for more epochs might help, although the big network above scored lower than the small one (roughly 0.73 versus 0.80 on the training set), so that kind of intensive tuning looks like a poor use of time. More sophisticated preprocessing and/or other types of neural networks, such as convolutional neural networks, are needed to tackle this problem properly.