Neural networks are all the rage these days. After reading a book on how to use gradient descent to fit a simple neural network via backpropagation, I decided to code it up by hand with the ultimate goal of having it classify some MNIST digits.
library(ArtisanalMachineLearning)
data(iris)
Preprocess the data into a binary response dataset. Also change the response to be {-1, 1} due to the nature of the tanh function.
iris = iris[1:100,]
iris$Species = as.numeric(iris$Species) - 2
iris$Species[iris$Species == 0] = 1
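To make the reason for the {-1, 1} coding concrete, here is a minimal sketch in base R. The vectors `z` and `species` are made-up illustrations, not part of the analysis: tanh squashes every input into the interval between -1 and 1, so targets coded as -1 and 1 sit at the two ends of what the output unit can produce.

```r
# tanh maps any real number into the open interval (-1, 1)
z = c(-5, -1, 0, 1, 5)          # hypothetical pre-activation values
activations = tanh(z)
print(activations)

# Recode a two-level factor so the first level is -1 and the second is 1
species = factor(c("setosa", "versicolor", "setosa"))   # toy labels
labels = ifelse(as.numeric(species) == 1, -1, 1)
print(labels)
```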
Split the data into training and validation sets.
validation_set = rbind(iris[1:10,], iris[91:100,])
validation_set$Species = NULL
validation_response = c(iris$Species[1:10], iris$Species[91:100])
data = iris[11:90,]
data$Species = NULL
response = iris$Species[11:90]
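The split above takes the first and last ten rows as the validation set. As an alternative sketch (not what this vignette does), the validation rows could be drawn at random within each species; the seed and the names `iris_binary`, `rows_by_species` and `training_set` are arbitrary choices for illustration.

```r
set.seed(42)  # arbitrary seed, only so the split is repeatable
data(iris)
iris_binary = iris[1:100, ]
iris_binary$Species = droplevels(iris_binary$Species)  # drop the unused third level

# Draw 10 validation rows at random from each of the two species
rows_by_species = split(seq_len(nrow(iris_binary)), iris_binary$Species)
validation_rows = unlist(lapply(rows_by_species, function(rows) sample(rows, 10)))

validation_set = iris_binary[validation_rows, ]
training_set = iris_binary[-validation_rows, ]
print(table(validation_set$Species))
```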
Train a network.
network = aml_neural_network(sizes=c(4, 15, 15, 4, 1),
learning_rate=.01,
data=data,
response=response,
epochs=100,
verbose=FALSE)
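The internals of `aml_neural_network` are not shown in this vignette. For readers who want a feel for what fitting a tanh network by gradient descent and backpropagation involves, here is a minimal sketch under stated assumptions: one hidden layer, squared-error loss, full-batch updates, and toy data. Every name in it (`params`, `forward`, `mean_squared_error`) is hypothetical and it should not be read as the package's implementation.

```r
set.seed(1)  # arbitrary seed so the toy run is repeatable

# Toy data: two Gaussian blobs labelled -1 and 1 (hypothetical, not iris)
n = 50
x = rbind(matrix(rnorm(n * 2, mean = -1), ncol = 2),
          matrix(rnorm(n * 2, mean = 1), ncol = 2))
y = matrix(c(rep(-1, n), rep(1, n)), ncol = 1)

# One hidden layer of 3 tanh units and a single tanh output unit
params = list(W1 = matrix(rnorm(2 * 3, sd = 0.5), nrow = 2),
              b1 = rep(0, 3),
              W2 = matrix(rnorm(3, sd = 0.5), ncol = 1),
              b2 = 0)

forward = function(params, x) {
  hidden = tanh(sweep(x %*% params$W1, 2, params$b1, "+"))
  output = tanh(hidden %*% params$W2 + params$b2)
  list(hidden = hidden, output = output)
}

mean_squared_error = function(params, x, y) {
  mean((forward(params, x)$output - y)^2) / 2
}

initial_loss = mean_squared_error(params, x, y)
learning_rate = 0.05
for (epoch in 1:200) {
  pass = forward(params, x)
  # Backpropagation: the derivative of tanh(z) is 1 - tanh(z)^2
  delta_output = (pass$output - y) * (1 - pass$output^2)
  delta_hidden = (delta_output %*% t(params$W2)) * (1 - pass$hidden^2)
  # Gradient descent step on the loss averaged over all rows
  params$W2 = params$W2 - learning_rate * t(pass$hidden) %*% delta_output / nrow(x)
  params$b2 = params$b2 - learning_rate * mean(delta_output)
  params$W1 = params$W1 - learning_rate * t(x) %*% delta_hidden / nrow(x)
  params$b1 = params$b1 - learning_rate * colMeans(delta_hidden)
}
final_loss = mean_squared_error(params, x, y)
print(c(initial = initial_loss, final = final_loss))
```

The two `delta_` lines are the backward pass: the output error is multiplied by the tanh derivative and then pushed back through `W2` to the hidden layer before any weights are updated.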
Inspect training and validation accuracy by naively classifying all predictions less than 0 to the -1 level and 1 otherwise.
preds = predict(network, data)
training_accuracy = sum((preds < 0) == (response < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 1"
validation = predict(network, validation_set)
validation_accuracy = sum((validation < 0) == (validation_response < 0)) / length(validation)
print(paste("Validation accuracy:", validation_accuracy))
## [1] "Validation accuracy: 1"
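The accuracy calculation used here can be wrapped in a small helper. This is a sketch with a hypothetical name, `sign_accuracy`, and made-up inputs; it assumes predictions and responses are plain numeric vectors of the same length.

```r
# Share of predictions that fall on the same side of 0 as the true response
sign_accuracy = function(predictions, truth) {
  mean((predictions < 0) == (truth < 0))
}

toy_predictions = c(-0.9, 0.2, 0.7, -0.1)   # made-up network outputs
toy_truth = c(-1, 1, -1, -1)                # made-up true labels
print(sign_accuracy(toy_predictions, toy_truth))

# A confusion table shows which class the mistakes come from
print(table(predicted = ifelse(toy_predictions < 0, -1, 1), actual = toy_truth))
```

Using `mean()` on the logical comparison gives the same number as `sum(...) / length(...)`.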
Either I have written the smartest algorithm of all time, or this is a very small, highly separable data set and the network picked up on that immediately (see the k-means vignette for proof of this).
Let’s read in the MNIST handwritten digits data and define a few helper functions. Code was largely robbed from here.
library(readr)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tidyr)
library(ggplot2)
library(viridis)
process_digit_data <- function(data_row){
data_row %>%
gather(pixel, value) %>%
tidyr::extract(pixel, "pixel", "(\\d+)", convert = TRUE) %>%
mutate(pixel = pixel - 2,
x = pixel %% 28,
y = 28 - pixel %/% 28)
}
plot_digit <- function(data_row){
image_df = process_digit_data(data_row)
image = ggplot(image_df, aes(x = x, y = y, fill = value)) +
geom_raster() +
coord_equal() +
scale_fill_viridis()
image
}
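The index arithmetic inside `process_digit_data` is easier to follow in isolation. The sketch below uses a hypothetical helper, `pixel_to_xy`, and assumes the pixel index has already been shifted to start at 0 (which is what `pixel - 2` does for columns named `X2` to `X785`). Pixels are stored row by row, so the remainder gives the column and the integer quotient gives the row, flipped so that the first image row is drawn at the top.

```r
# Map a zero-based pixel index to plotting coordinates on a 28 x 28 grid
pixel_to_xy = function(pixel, width = 28) {
  data.frame(pixel = pixel,
             x = pixel %% width,            # column within the image row
             y = width - pixel %/% width)   # flip so row 0 is at the top
}

corners = pixel_to_xy(c(0, 27, 28, 783))
print(corners)
```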
mnist_raw <- read_csv("https://pjreddie.com/media/files/mnist_train.csv",
col_names = FALSE,
col_types = cols())
response = unlist(mnist_raw[,1])
data = mnist_raw
data[,1] = NULL
Preprocess the data to keep only the digits 0 and 7, and transform the response to be {-1, 1} due to the nature of the tanh function. Also split into training and test sets.
digits_to_keep = c(0,7)
data_sub = data[response %in% digits_to_keep,]
response_sub = response[response %in% digits_to_keep]
response_transformed = sapply(response_sub, function(x){ifelse(x == 0, -1, 1)})
training_indicators = sample(1:length(response_transformed), 10000)
test_indicator = !(1:nrow(data_sub) %in% training_indicators)
training_data = data_sub[training_indicators,]
training_response = response_transformed[training_indicators]
test_data = data_sub[test_indicator,]
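The same filter, recode and split steps can be tried on a toy vector. This is a sketch with made-up labels and an arbitrary seed; it uses the vectorised `ifelse` directly and `setdiff` for the test rows, which is one possible way to write it rather than the way it is written above.

```r
set.seed(123)  # arbitrary seed so the sampled rows are repeatable

toy_response = c(0, 7, 3, 7, 0, 1, 7, 0, 7, 0)   # made-up digit labels
keep = toy_response %in% c(0, 7)
toy_sub = toy_response[keep]

# ifelse is vectorised, so no sapply is needed for the recode
toy_transformed = ifelse(toy_sub == 0, -1, 1)

# Sample training rows, then take every remaining row as the test set
toy_training_rows = sample(seq_along(toy_transformed), 5)
toy_test_rows = setdiff(seq_along(toy_transformed), toy_training_rows)
print(toy_transformed)
```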
If you’re not familiar with the MNIST data set, it’s a dataset of handwritten digit images. Here are a few examples.
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
The average of all 0 and 7 images looks like the following.
# Plot average 0 image
averages_frame = as.data.frame(t(apply((training_data[training_response == -1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)
# Plot average 7 image
averages_frame = as.data.frame(t(apply((training_data[training_response == 1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)
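For reference, the per-pixel averages can also be computed with `colMeans`, which gives the same one-row data frame as `apply(..., MARGIN = 2, FUN = mean)`. The sketch below checks this on a tiny made-up data frame; `toy_pixels` and `toy_labels` are hypothetical stand-ins for `training_data` and `training_response`.

```r
toy_pixels = data.frame(X2 = c(0, 255, 10), X3 = c(5, 5, 20))  # made-up pixels
toy_labels = c(-1, 1, -1)                                      # made-up labels

# Column means over the rows labelled -1, computed two ways
avg_apply = as.data.frame(t(apply(toy_pixels[toy_labels == -1, ], MARGIN = 2, FUN = mean)))
avg_colmeans = as.data.frame(t(colMeans(toy_pixels[toy_labels == -1, ])))

print(avg_colmeans)
print(all.equal(avg_apply, avg_colmeans))
```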
Now, let’s train a small neural network and inspect it. Warning! This takes some time to run.
small_network = aml_neural_network(c(784, 50, 25, 1), learning_rate = .01, data=training_data, response=training_response, epochs=10, verbose = FALSE)
Accuracy rates using a naive classification split at 0
# Training accuracy (share of predictions on the correct side of 0)
preds = predict(small_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 0.8011"
# Test accuracy
test_preds = predict(small_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
## [1] "Testing accuracy: 0.797531992687386"
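An accuracy figure is easier to judge next to a baseline. A minimal sketch, using a hypothetical helper `majority_baseline` and made-up labels: it returns the accuracy obtained by always predicting the most common class, which is the number any classifier should beat.

```r
# Accuracy of always predicting the most frequent class
majority_baseline = function(truth) {
  max(table(truth)) / length(truth)
}

toy_truth = c(-1, -1, 1, 1, 1)   # made-up labels: two of one class, three of the other
print(majority_baseline(toy_truth))
```

On the real data this could be called as `majority_baseline(training_response)`.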
Train a big neural network and inspect it. Warning! This takes a long time to run.
big_network = aml_neural_network(c(784, 100, 100, 50, 25, 1), learning_rate = .01, data=training_data, response=training_response, epochs=100, verbose = FALSE)
Accuracy rates using a naive classification split at 0
# Training accuracy (share of predictions on the correct side of 0)
preds = predict(big_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
## [1] "Training accuracy: 0.7282"
# Test accuracy
test_preds = predict(big_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
## [1] "Testing accuracy: 0.744515539305302"
median_obs = which(preds == median(preds)) # Many observations are equal to the median
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
median_obs = which(test_preds == median(test_preds)) # Many observations are equal to the median
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(training_data[which.min(preds),])
plot_digit(training_data[which.max(preds),])
plot_digit(test_data[which.min(test_preds),])
plot_digit(test_data[which.max(test_preds),])
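The code comments above note that many predictions are exactly equal to the median. One way to put a number on that, sketched here with made-up predictions (`toy_preds` is hypothetical), is the share of predictions tied with the median.

```r
toy_preds = c(-0.99, 0.4, 0.4, 0.4, 0.4, 0.97, 0.4)   # made-up network outputs

# Share of predictions that are exactly equal to the median prediction
tied_share = mean(toy_preds == median(toy_preds))
print(tied_share)

# Tabulating the values shows how few distinct outputs there are
print(table(toy_preds))
```

A large tied share means the network returns the same output for many different images.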
Image processing with neural networks is very hard. This very quick analysis is clearly picking up on some signal, but you can see the network struggle with digits that were written thicker than others, digits that were not centered in the grid, or a 7 written with a cross through the middle. Training larger networks or training for more epochs might help, although the larger network above actually scored lower than the small one, and that kind of intense tuning is not a good use of time here. More sophisticated preprocessing techniques and/or other types of neural networks, such as convolutional neural networks, are necessary to tackle this problem appropriately.
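As one small example of preprocessing, here is a sketch of rescaling pixel intensities from 0 to 255 down to 0 to 1. This step is not part of the analysis above, and whether it would improve these particular networks is untested. The motivation is that tanh flattens out for large inputs, so raw intensities in the hundreds can push units toward -1 or 1. The helper `scale_pixels` and the toy data are hypothetical.

```r
# Rescale raw pixel intensities to the range [0, 1]
scale_pixels = function(pixel_data, max_value = 255) {
  as.data.frame(pixel_data) / max_value
}

toy_pixels = data.frame(X2 = c(0, 255), X3 = c(51, 102))   # made-up pixel values
scaled = scale_pixels(toy_pixels)
print(scaled)

# tanh is numerically indistinguishable from 1 at the raw scale
print(tanh(c(1, 255)))
```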