Text classification in R is fun. While there are many possible approaches, this blog post proposes a Keras (with TensorFlow backend) workflow based on a vocabulary of lemmatized tokens.
The proposed workflow is particularly advantageous for text in languages other than English with complicated inflected forms - remember the ROMANES EUNT DOMUS scene from the Life of Brian?
The tokenisation and lemmatisation are performed using the R package udpipe. The main classification work is done via keras, with a little help from caret, and with dplyr pipes used to hold everything together.
library(tidyverse) # mainly dplyr
library(udpipe) # tokenization & lemmatization
library(caret) # used to split the dataset to train / test
library(keras) # because Keras :)
As a “problem” to solve I will train my model to determine the authorship of tweets, using the timelines of two popular accounts.
I will be using a dataset of 1,000 tweets, half each by Hadley Wickham and Marie Kondo. It contains lots and lots of tidiness, and some sparking of joy.
The dataset was collected with the help of the rtweet library and can be downloaded by running the code quoted below.
It is also necessary to download the udpipe model. I am storing it in tempdir(), but in a real-world situation it would make more sense to store it in a permanent place. The models tend to be large.
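One possible pattern for such a permanent cache is sketched below; the ~/udpipe_models path is an arbitrary choice of mine, not a udpipe convention.

```r
# sketch: keep the model in a permanent folder and download it only once
model_dir <- "~/udpipe_models"
dir.create(model_dir, showWarnings = FALSE, recursive = TRUE)
existing <- list.files(model_dir, pattern = "\\.udpipe$", full.names = TRUE)
if (length(existing) == 0) {
  # first run: download and load
  udtarget <- udpipe_download_model(language = "english", model_dir = model_dir)
  udmodel  <- udpipe_load_model(file = udtarget$file_model)
} else {
  # subsequent runs: reuse the cached file
  udmodel  <- udpipe_load_model(file = existing[1])
}
```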
# directory for source data files
network_path <- 'http://www.jla-data.net/ENG/2019-01-25-vocabulary-based-text-classification_files/'
tf_tweets <- tempfile(fileext = ".csv") # create a temporary csv file
download.file(paste0(network_path, 'tidy_tweets.csv'), tf_tweets, quiet = T)
tweets <- read.csv(tf_tweets, stringsAsFactors = F) # read the tweet data in
# download current udpipe model for English
udtarget <- udpipe_download_model(language = "english",
model_dir = tempdir())
# load the model
udmodel <- udpipe_load_model(file = udtarget$file_model)
table(tweets$name) # verify structure of downloaded tweets
##
## hadleywickham MarieKondo
## 500 500
The first step is applying the udpipe_annotate() function to the text of the tweets. I store the id of each tweet as the document id, enabling me to join the tweet data back to individual words and lemmas later.
words <- udpipe_annotate(udmodel, x = tweets$text, doc_id = tweets$id) %>%
as.data.frame() %>%
select(id = doc_id, token, lemma, upos, sentence_id) %>%
mutate(id = as.numeric(id))
# a sneak peek at the words object
glimpse(words)
## Rows: 24,179
## Columns: 5
## $ id <dbl> 1.431262e+18, 1.431262e+18, 1.431262e+18, 1.431262e+18, 1.…
## $ token <chr> "@_TimTaylor_", "And", "board_url", "(", ")", "will", "war…
## $ lemma <chr> "@_TimTaylor_", "and", "board_url", "(", ")", "will", "war…
## $ upos <chr> "SYM", "CCONJ", "NOUN", "PUNCT", "PUNCT", "AUX", "VERB", "…
## $ sentence_id <int> 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2…
The second step is constructing the vocabulary. It is a simple data frame containing two fields: lemma and id. It will be used to translate the lemmas to integers, to be fed to the keras model.
vocabulary <- words %>%
count(lemma) %>%
ungroup() %>%
arrange(desc(n)) %>%
filter(n >=3) %>% # little predictive value in rare words
mutate(id_slovo = row_number()) %>% # unique id per lemma
select(lemma, id_slovo)
# a sneak peek at the vocabulary object
glimpse(vocabulary)
## Rows: 988
## Columns: 2
## $ lemma <chr> ".", "the", "be", "to", "you", ",", "a", ":", "!", "I", "and"…
## $ id_slovo <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18…
The third step is preparing the keras input matrix. It starts as a data frame (later converted to a matrix) with 150 numerical columns, padded with zeroes. I have yet to see a 150-word tweet, so the size leaves some reserve.
This is a rather complicated pipe, and while it does the work required it looks like a candidate for refactoring (I have a vague feeling that I am not properly using the fill argument of the spread() function).
# 150 zeroes for each tweet id for padding
vata <- expand.grid(id = unique(words$id),
word_num = 1:150,
id_slovo = 0)
word_matrix <- words %>% # words
# filtering join! words not in vocabulary are discarded
inner_join(vocabulary, by = c('lemma' = 'lemma')) %>%
select(id, lemma, id_slovo) %>%
group_by(id) %>%
mutate(word_num = row_number()) %>% # position of word within tweet
ungroup() %>%
select(id, word_num, id_slovo) %>% # relevant columns
rbind(vata) %>% # bind the 150 zeroes per tweet
group_by(id, word_num) %>%
mutate(id_slovo = max(id_slovo)) %>% # will include duplicates
ungroup() %>%
unique() %>% # remove duplicates
spread(word_num, id_slovo) # spread to matrix format
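For the record, one possible refactoring uses tidyr's complete() and pivot_wider() instead of the `vata` helper and spread(); this is a sketch I have not benchmarked against the pipe above, and it assumes tidyr >= 1.0.

```r
# sketch: complete() generates the full id x 1:150 grid with zero fills,
# making the `vata` helper data frame unnecessary
word_matrix_alt <- words %>%
  inner_join(vocabulary, by = "lemma") %>%
  group_by(id) %>%
  mutate(word_num = row_number()) %>% # position of word within tweet
  ungroup() %>%
  filter(word_num <= 150) %>% # truncate any (hypothetical) overlong tweet
  select(id, word_num, id_slovo) %>%
  complete(id, word_num = 1:150, fill = list(id_slovo = 0)) %>%
  pivot_wider(names_from = word_num, values_from = id_slovo)
```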
keras_input <- tweets %>%
select(id, name, text) %>%
inner_join(word_matrix, by = c('id' = 'id'))
# this will be input for keras
glimpse(keras_input)
## Rows: 998
## Columns: 153
## $ id <dbl> 1.431262e+18, 1.267958e+18, 1.273371e+18, 1.415105e+18, 1.383026…
## $ name <chr> "hadleywickham", "hadleywickham", "hadleywickham", "hadleywickha…
## $ text <chr> "@_TimTaylor_ And board_url() will warn if you use it with a non…
## $ `1` <dbl> 11, 91, 17, 13, 13, 25, 10, 371, 13, 4, 23, 20, 125, 766, 54, 13…
## $ `2` <dbl> 20, 9, 3, 8, 8, 508, 218, 23, 8, 47, 5, 10, 155, 36, 872, 8, 295…
## $ `3` <dbl> 21, 0, 2, 83, 50, 30, 5, 673, 2, 142, 281, 30, 63, 10, 16, 2, 10…
## $ `4` <dbl> 57, 0, 20, 3, 23, 7, 390, 0, 177, 2, 112, 469, 71, 112, 350, 177…
## $ `5` <dbl> 39, 0, 14, 58, 30, 166, 2, 0, 12, 343, 61, 4, 7, 412, 578, 12, 2…
## $ `6` <dbl> 5, 0, 734, 27, 7, 18, 476, 0, 33, 6, 179, 25, 365, 31, 8, 33, 4,…
## $ `7` <dbl> 42, 0, 14, 304, 89, 96, 4, 0, 89, 5, 0, 28, 19, 0, 266, 189, 484…
## $ `8` <dbl> 17, 0, 21, 9, 151, 5, 223, 0, 14, 229, 0, 814, 31, 0, 35, 12, 5,…
## $ `9` <dbl> 28, 0, 981, 255, 64, 63, 170, 0, 739, 92, 0, 32, 0, 0, 2, 44, 22…
## $ `10` <dbl> 7, 0, 22, 413, 56, 42, 436, 0, 74, 4, 0, 554, 0, 0, 14, 14, 321,…
## $ `11` <dbl> 899, 0, 10, 28, 40, 31, 0, 0, 8, 145, 0, 180, 0, 0, 645, 724, 20…
## $ `12` <dbl> 0, 0, 87, 7, 82, 20, 0, 0, 6, 4, 0, 18, 0, 0, 14, 14, 114, 186, …
## $ `13` <dbl> 0, 0, 119, 655, 20, 16, 0, 0, 6, 7, 0, 17, 0, 0, 224, 3, 679, 1,…
## $ `14` <dbl> 0, 0, 53, 27, 220, 951, 0, 0, 11, 6, 0, 4, 0, 0, 8, 55, 42, 27, …
## $ `15` <dbl> 0, 0, 2, 622, 922, 101, 0, 0, 403, 48, 0, 393, 0, 0, 266, 9, 6, …
## $ `16` <dbl> 0, 0, 702, 457, 21, 203, 0, 0, 16, 39, 0, 1, 0, 0, 8, 128, 5, 6,…
## $ `17` <dbl> 0, 0, 74, 4, 65, 21, 0, 0, 83, 17, 0, 17, 0, 0, 266, 0, 207, 137…
## $ `18` <dbl> 0, 0, 179, 11, 11, 16, 0, 0, 14, 3, 0, 3, 0, 0, 359, 0, 63, 3, 0…
## $ `19` <dbl> 0, 0, 0, 449, 10, 83, 0, 0, 3, 323, 0, 328, 0, 0, 964, 0, 75, 7,…
## $ `20` <dbl> 0, 0, 0, 1, 3, 6, 0, 0, 58, 5, 0, 4, 0, 0, 16, 0, 511, 1, 0, 20,…
## $ `21` <dbl> 0, 0, 0, 0, 69, 20, 0, 0, 237, 25, 0, 43, 0, 0, 8, 0, 1, 0, 0, 2…
## $ `22` <dbl> 0, 0, 0, 0, 17, 21, 0, 0, 4, 16, 0, 18, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ `23` <dbl> 0, 0, 0, 0, 16, 3, 0, 0, 1, 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 171, 0…
## $ `24` <dbl> 0, 0, 0, 0, 11, 926, 0, 0, 0, 130, 0, 385, 0, 0, 0, 0, 0, 0, 0, …
## $ `25` <dbl> 0, 0, 0, 0, 17, 20, 0, 0, 0, 12, 0, 28, 0, 0, 0, 0, 0, 0, 0, 7, …
## $ `26` <dbl> 0, 0, 0, 0, 3, 118, 0, 0, 0, 5, 0, 50, 0, 0, 0, 0, 0, 0, 0, 217,…
## $ `27` <dbl> 0, 0, 0, 0, 1, 6, 0, 0, 0, 315, 0, 168, 0, 0, 0, 0, 0, 0, 0, 144…
## $ `28` <dbl> 0, 0, 0, 0, 0, 150, 0, 0, 0, 6, 0, 11, 0, 0, 0, 0, 0, 0, 0, 4, 0…
## $ `29` <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 265, 0, 527, 0, 0, 0, 0, 0, 0, 0, 43…
## $ `30` <dbl> 0, 0, 0, 0, 0, 891, 0, 0, 0, 6, 0, 21, 0, 0, 0, 0, 0, 0, 0, 11, …
## $ `31` <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 292, 0…
## $ `32` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 315, …
## $ `33` <dbl> 0, 0, 0, 0, 0, 186, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0…
## $ `34` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, …
## $ `35` <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,…
## $ `36` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, …
## $ `37` <dbl> 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 53, 0,…
## $ `38` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, …
## $ `39` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 380, 0…
## $ `40` <dbl> 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 668, 0,…
## $ `41` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0,…
## $ `42` <dbl> 0, 0, 0, 0, 0, 926, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0…
## $ `43` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 0,…
## $ `44` <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 43, 0…
## $ `45` <dbl> 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0…
## $ `46` <dbl> 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 83, 0…
## $ `47` <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `48` <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `49` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `50` <dbl> 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `51` <dbl> 0, 0, 0, 0, 0, 213, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `52` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `53` <dbl> 0, 0, 0, 0, 0, 186, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `54` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `55` <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `56` <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `57` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `58` <dbl> 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `59` <dbl> 0, 0, 0, 0, 0, 213, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `60` <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `61` <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `62` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `63` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `64` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `65` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `66` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `67` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `68` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `69` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `70` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `71` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `72` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `73` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `74` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `75` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `76` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `77` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `78` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `79` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `80` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `81` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `82` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `83` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `84` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `85` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `86` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `87` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `88` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `89` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `90` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `91` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `92` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `93` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `94` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `95` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `96` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `97` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `98` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `99` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `100` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `101` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `102` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `103` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `104` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `105` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `106` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `107` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `108` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `109` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `110` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `111` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `112` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `113` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `114` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `115` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `116` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `117` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `118` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `119` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `120` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `121` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `122` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `123` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `124` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `125` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `126` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `127` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `128` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `129` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `130` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `131` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `132` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `133` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `134` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `135` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `136` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `137` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `138` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `139` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `140` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `141` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `142` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `143` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `144` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `145` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `146` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `147` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `148` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `149` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `150` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
Once the full keras input is prepared, I split it into training and testing parts.
set.seed(42) # Zaphod Beeblebrox advises to trust no other!
idx <- createDataPartition(keras_input$name, p = .8, list = F, times = 1) # 80 / 20 split
train_data <- keras_input[idx, ] # train dataset
test_data <- keras_input[-idx, ] # verification
With the training dataset ready, a couple of technical steps are necessary: converting the categorical author field to a number, and converting the whole input to a numerical matrix. The vocabulary size variable will be used to size the keras input.
train_data <- train_data %>%
mutate(hadley = ifelse(name == 'hadleywickham', 1,0)) %>% # binary output
select(-id, -name, -text)
x_train <- data.matrix(train_data %>% select(-hadley)) # everything except target
y_train <- data.matrix(train_data %>% select(hadley)) # target, and target only
vocab_size <- vocabulary %>% # count unique word ids
pull(id_slovo) %>%
unique() %>%
length() + 1 # one extra for the zero padding
Once all the inputs are ready I define and train my model.
The model uses an embedding layer sized to the vocabulary and one bidirectional LSTM layer. As LSTMs are notoriously prone to overfitting, a dropout layer is necessary.
The final layer is a single output node, representing the probability that a particular tweet is Hadley's.
model <- keras_model_sequential()
model %>%
layer_embedding(input_dim = vocab_size, output_dim = 256) %>%
bidirectional(layer_lstm(units = 128)) %>%
layer_dropout(rate = 0.5) %>%
layer_dense(units = 1, activation = 'sigmoid') # 1 = Hadley, 0 = Marie
model %>%
compile(optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("accuracy"))
history <- model %>% # fit the model (this will take a while...)
fit(x_train,
y_train,
epochs = 25,
batch_size = nrow(train_data)/5,
validation_split = 1/5)
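The object returned by fit() keeps the per-epoch metrics, and the keras R package ships a plot method for it, which is handy for spotting overfitting:

```r
# training / validation loss and accuracy per epoch (ggplot2-based)
plot(history)
```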
summary(model)
## Model: "sequential"
## ________________________________________________________________________________
## Layer (type) Output Shape Param #
## ================================================================================
## embedding (Embedding) (None, None, 256) 253184
## bidirectional (Bidirectional) (None, 256) 394240
## dropout (Dropout) (None, 256) 0
## dense (Dense) (None, 1) 257
## ================================================================================
## Total params: 647,681
## Trainable params: 647,681
## Non-trainable params: 0
## ________________________________________________________________________________
With the model trained I verify its effectiveness on the testing dataset, using the same pre-processing as for the training data.
pred_data <- test_data %>% # expected results
select(id, name, text)
test_data <- test_data %>%
mutate(hadley = ifelse(name == 'hadleywickham', 1,0)) %>% # binary output
select(-id, -name, -text) # no cheating!
x_pred <- data.matrix(test_data %>% select(-hadley)) # keras needs matrix
pred <- model %>% # let keras sweat...
predict(x_pred)
verifikace <- pred_data %>% # correct results ...
cbind(pred) # ... joined with predictions
verifikace <- verifikace %>%
mutate(name_pred = ifelse(pred > 0.5, 'hadleywickham', 'MarieKondo'))
conf_mtx <- table(verifikace$name, verifikace$name_pred)
print(paste0('Correctly predicted ',
sum(diag(conf_mtx)), ' of ',
sum(conf_mtx), ' tweets, which means ',
round(100 * sum(diag(conf_mtx))/sum(conf_mtx), 2),
'% of the total.'))
## [1] "Correctly predicted 180 of 199 tweets, which means 90.45% of the total."
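Since caret is already loaded, its confusionMatrix() function offers a richer summary than the manual table; a sketch (note that caret expects factors with matching levels):

```r
# reports accuracy, kappa, sensitivity, specificity etc. in one go
confusionMatrix(
  data      = factor(verifikace$name_pred, levels = c('hadleywickham', 'MarieKondo')),
  reference = factor(verifikace$name,      levels = c('hadleywickham', 'MarieKondo'))
)
```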
Having accuracy over 90% in a toy example such as this is not bad :)
It can be further improved in a number of ways - the two most obvious being including more training data (500 tweets per author is not that many) or adding more layers to create a more complex model.
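To illustrate the second route, a sketch of a deeper variant stacking two bidirectional LSTM layers; the layer sizes are untuned guesses of mine, not recommendations:

```r
model_deeper <- keras_model_sequential()
model_deeper %>%
  layer_embedding(input_dim = vocab_size, output_dim = 256) %>%
  # return_sequences = TRUE passes the full sequence to the next LSTM layer
  bidirectional(layer_lstm(units = 128, return_sequences = TRUE)) %>%
  layer_dropout(rate = 0.5) %>%
  bidirectional(layer_lstm(units = 64)) %>%
  layer_dropout(rate = 0.5) %>%
  layer_dense(units = 1, activation = 'sigmoid')
```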