Text classification in R is fun. While there are many possible approaches, this blog post proposes a Keras (with TensorFlow backend) workflow based on a vocabulary of lemmatized tokens.

The proposed workflow is particularly advantageous for text in languages other than English with complicated inflected forms; remember the ROMANES EUNT DOMUS scene from Life of Brian?
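As a toy illustration of why this matters, the base R sketch below maps a few Latin word forms (including the centurion's corrected "Romani ite domum") to their dictionary forms. The lemma table is hand-made and hypothetical: udpipe predicts lemmas with a trained model rather than looking them up in a fixed table. The sketch only shows how lemmatization shrinks the number of distinct tokens a model has to learn.

```r
# Hand-made, hypothetical lemma table: inflected Latin form -> dictionary form
lemma_table <- c(Romani = "Romanus", Romanus = "Romanus", Romanorum = "Romanus",
                 ite = "eo", domum = "domus", domus = "domus", domo = "domus")

# A handful of tokens, as a tokenizer might emit them
tokens <- c("Romani", "ite", "domum", "Romanus", "domus", "Romanorum", "domo")

# Replace every token by its lemma (unname drops the lookup names)
lemmas <- unname(lemma_table[tokens])

length(unique(tokens))  # 7 distinct surface forms
length(unique(lemmas))  # 3 distinct lemmas: "Romanus", "eo", "domus"
```

In a heavily inflected language the same effect applies across the whole corpus, so more of the rare surface forms survive a minimum-frequency filter once they are counted under their shared lemma.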

Tokenisation and lemmatisation are performed using the R package udpipe. The main classification work is done via keras, with light help from caret, and with dplyr pipes holding everything together.

library(tidyverse) # mainly dplyr
library(udpipe)    # tokenization & lemmatization
library(caret)     # used to split the dataset to train / test
library(keras)     # because Keras :)

As a “problem” to solve, I will train my model to determine the authorship of tweets, using the timelines of two popular accounts.

I will be using a dataset of 1000 tweets, half by Hadley Wickham and half by Marie Kondo. It contains lots and lots of tidiness, and some sparkling joy.

The dataset was collected with the help of the rtweet package and can be downloaded by running the code below.

It is also necessary to download the udpipe model. I am storing it in tempdir(), but in a real-world situation it makes more sense to store it in a permanent location, as the models tend to be large.

# directory for source data files
network_path <- 'http://www.jla-data.net/ENG/2019-01-25-vocabulary-based-text-classification_files/'

tf_tweets <- tempfile(fileext = ".csv") # create a temporary csv file
download.file(paste0(network_path, 'tidy_tweets.csv'), tf_tweets, quiet = TRUE)
tweets <- read.csv(tf_tweets, stringsAsFactors = FALSE) # read the tweet data in

# download current udpipe model for English
udtarget <- udpipe_download_model(language = "english",
                                  model_dir = tempdir())
# load the model
udmodel <- udpipe_load_model(file = udtarget$file_model) 

table(tweets$name) # verify structure of downloaded tweets
## 
## hadleywickham    MarieKondo 
##           500           500

The first step is applying the udpipe_annotate() function to the text of the tweets. I am storing the id of each tweet as the document id, which lets me join the tweet data to individual words and lemmas later.

words <- udpipe_annotate(udmodel, x = tweets$text, doc_id = tweets$id) %>% 
  as.data.frame() %>%
  select(id = doc_id, token, lemma, upos, sentence_id) %>%
  mutate(id = as.numeric(id))

# a sneak peek at the words object
glimpse(words)
## Rows: 24,179
## Columns: 5
## $ id          <dbl> 1.431262e+18, 1.431262e+18, 1.431262e+18, 1.431262e+18, 1.…
## $ token       <chr> "@_TimTaylor_", "And", "board_url", "(", ")", "will", "war…
## $ lemma       <chr> "@_TimTaylor_", "and", "board_url", "(", ")", "will", "war…
## $ upos        <chr> "SYM", "CCONJ", "NOUN", "PUNCT", "PUNCT", "AUX", "VERB", "…
## $ sentence_id <int> 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2…
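A possible pitfall at this step: tweet ids are 19-digit integers, well above 2^53, the limit up to which R's double type represents every integer exactly. Converting them with as.numeric() can therefore merge neighbouring ids, which would break the later joins on id. The base R sketch below uses two made-up ids that differ by one. Keeping the id as character throughout, for example by reading it with read.csv(..., colClasses = c(id = "character")) and dropping the as.numeric() conversion, is one way to avoid the problem; this is a suggestion, not how the results in this post were produced.

```r
# Two hypothetical 19-digit tweet ids that differ by one
id_a <- "1431262033847275520"
id_b <- "1431262033847275521"

# As doubles they collapse onto the same value: at this magnitude the
# spacing between neighbouring representable doubles is 256, not 1
as.numeric(id_a) == as.numeric(id_b)   # TRUE

# As character strings they stay distinct
id_a == id_b                           # FALSE
```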

The second step is constructing the vocabulary: a simple data frame with two fields, the lemma and its integer id (id_slovo; slovo is Czech for word). It is used to translate the lemmas into integers that can be fed to the keras model.

vocabulary <- words %>%
  count(lemma) %>%
  ungroup() %>%
  arrange(desc(n)) %>%
  filter(n >=3) %>%  # little predictive value in rare words
  mutate(id_slovo = row_number()) %>% # unique id per lemma
  select(lemma, id_slovo)

# a sneak peek at the vocabulary object
glimpse(vocabulary)
## Rows: 988
## Columns: 2
## $ lemma    <chr> ".", "the", "be", "to", "you", ",", "a", ":", "!", "I", "and"…
## $ id_slovo <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18…
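For readers who prefer base R, the same vocabulary logic (count lemmas, drop those seen fewer than three times, number the rest by descending frequency) can be sketched with table(). The lemmas below are a made-up toy input, not the tweet data, chosen so that the kept lemmas have no tied counts and the ids are unambiguous.

```r
# Made-up stand-in for the lemma column of the udpipe output
lemmas <- c("the", "be", "the", "tidy", "be", "the", "be", "joy", "be")

# Count occurrences of each lemma, most frequent first
counts <- sort(table(lemmas), decreasing = TRUE)

# Keep only lemmas seen at least 3 times, mirroring filter(n >= 3)
counts <- counts[counts >= 3]

# Assign consecutive integer ids by frequency rank, mirroring row_number()
vocab <- data.frame(lemma = names(counts),
                    id_slovo = seq_along(counts),
                    stringsAsFactors = FALSE)
print(vocab)
```

The ids start at 1 on purpose: 0 stays free to serve as the padding value in the next step.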

The third step is preparing the keras input matrix. It starts as a data frame (later converted to a matrix) with 150 numerical columns, one per word position, padded with zeroes. I have yet to see a 150-word tweet, so the size has some reserve.

This is a rather complicated pipe, and while it does the required work, it looks like a candidate for refactoring (I have a vague feeling that I am not making proper use of the fill argument of the spread() function).

# 150 zeroes for each tweet id for padding
vata <- expand.grid(id = unique(words$id),
                    word_num = 1:150,
                    id_slovo = 0)

word_matrix <- words %>% # words
  # filtering join! words not in vocabulary are discarded
  inner_join(vocabulary, by = c('lemma' = 'lemma')) %>% 
  select(id, lemma, id_slovo) %>%
  group_by(id) %>%
  mutate(word_num = row_number()) %>% # 
  ungroup() %>%
  select(id, word_num, id_slovo) %>% # relevant columns
  rbind(vata) %>% # bind the 150 zeroes per tweet
  group_by(id, word_num) %>%
  mutate(id_slovo = max(id_slovo)) %>% # will include duplicites
  ungroup() %>%
  unique() %>% # remove duplicites
  spread(word_num, id_slovo) # spread to matrix format
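On the refactoring question: spread() does accept fill = 0, which puts zeroes into missing cells, but it only creates columns for word positions that actually occur, so on its own it would not guarantee all 150 columns; that is the job of the vata rows. One possible alternative, sketched below in base R on a made-up long table, is to start from an all-zero matrix of fixed width and write each known value into its (tweet, position) cell with matrix indexing. No duplicates arise, so no max() or unique() step is needed. This is an untested suggestion, not the code used for the results in this post.

```r
# Made-up long-format input: one row per (tweet, word position) with its vocabulary id
long <- data.frame(id       = c("t1", "t1", "t1", "t2", "t2"),
                   word_num = c(1L, 2L, 3L, 1L, 2L),
                   id_slovo = c(11L, 20L, 21L, 91L, 9L),
                   stringsAsFactors = FALSE)

maxlen <- 150                   # fixed width, as in the post
tweet_ids <- unique(long$id)

# Every cell starts as 0, so anything not overwritten below is padding
wide <- matrix(0L, nrow = length(tweet_ids), ncol = maxlen,
               dimnames = list(tweet_ids, NULL))

# Drop positions beyond maxlen, then fill each (row, column) pair in one step
keep <- long$word_num <= maxlen
rows <- match(long$id[keep], tweet_ids)
wide[cbind(rows, long$word_num[keep])] <- long$id_slovo[keep]

wide[, 1:4]
```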

keras_input <- tweets %>%
  select(id, name, text) %>%
  inner_join(word_matrix, by = c('id' = 'id'))

# this will be input for keras
glimpse(keras_input)
## Rows: 998
## Columns: 153
## $ id    <dbl> 1.431262e+18, 1.267958e+18, 1.273371e+18, 1.415105e+18, 1.383026…
## $ name  <chr> "hadleywickham", "hadleywickham", "hadleywickham", "hadleywickha…
## $ text  <chr> "@_TimTaylor_ And board_url() will warn if you use it with a non…
## $ `1`   <dbl> 11, 91, 17, 13, 13, 25, 10, 371, 13, 4, 23, 20, 125, 766, 54, 13…
## $ `2`   <dbl> 20, 9, 3, 8, 8, 508, 218, 23, 8, 47, 5, 10, 155, 36, 872, 8, 295…
## $ `3`   <dbl> 21, 0, 2, 83, 50, 30, 5, 673, 2, 142, 281, 30, 63, 10, 16, 2, 10…
## $ `4`   <dbl> 57, 0, 20, 3, 23, 7, 390, 0, 177, 2, 112, 469, 71, 112, 350, 177…
## $ `5`   <dbl> 39, 0, 14, 58, 30, 166, 2, 0, 12, 343, 61, 4, 7, 412, 578, 12, 2…
## $ `6`   <dbl> 5, 0, 734, 27, 7, 18, 476, 0, 33, 6, 179, 25, 365, 31, 8, 33, 4,…
## $ `7`   <dbl> 42, 0, 14, 304, 89, 96, 4, 0, 89, 5, 0, 28, 19, 0, 266, 189, 484…
## $ `8`   <dbl> 17, 0, 21, 9, 151, 5, 223, 0, 14, 229, 0, 814, 31, 0, 35, 12, 5,…
## $ `9`   <dbl> 28, 0, 981, 255, 64, 63, 170, 0, 739, 92, 0, 32, 0, 0, 2, 44, 22…
## $ `10`  <dbl> 7, 0, 22, 413, 56, 42, 436, 0, 74, 4, 0, 554, 0, 0, 14, 14, 321,…
## $ `11`  <dbl> 899, 0, 10, 28, 40, 31, 0, 0, 8, 145, 0, 180, 0, 0, 645, 724, 20…
## $ `12`  <dbl> 0, 0, 87, 7, 82, 20, 0, 0, 6, 4, 0, 18, 0, 0, 14, 14, 114, 186, …
## $ `13`  <dbl> 0, 0, 119, 655, 20, 16, 0, 0, 6, 7, 0, 17, 0, 0, 224, 3, 679, 1,…
## $ `14`  <dbl> 0, 0, 53, 27, 220, 951, 0, 0, 11, 6, 0, 4, 0, 0, 8, 55, 42, 27, …
## $ `15`  <dbl> 0, 0, 2, 622, 922, 101, 0, 0, 403, 48, 0, 393, 0, 0, 266, 9, 6, …
## $ `16`  <dbl> 0, 0, 702, 457, 21, 203, 0, 0, 16, 39, 0, 1, 0, 0, 8, 128, 5, 6,…
## $ `17`  <dbl> 0, 0, 74, 4, 65, 21, 0, 0, 83, 17, 0, 17, 0, 0, 266, 0, 207, 137…
## $ `18`  <dbl> 0, 0, 179, 11, 11, 16, 0, 0, 14, 3, 0, 3, 0, 0, 359, 0, 63, 3, 0…
## $ `19`  <dbl> 0, 0, 0, 449, 10, 83, 0, 0, 3, 323, 0, 328, 0, 0, 964, 0, 75, 7,…
## $ `20`  <dbl> 0, 0, 0, 1, 3, 6, 0, 0, 58, 5, 0, 4, 0, 0, 16, 0, 511, 1, 0, 20,…
## $ `21`  <dbl> 0, 0, 0, 0, 69, 20, 0, 0, 237, 25, 0, 43, 0, 0, 8, 0, 1, 0, 0, 2…
## $ `22`  <dbl> 0, 0, 0, 0, 17, 21, 0, 0, 4, 16, 0, 18, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ `23`  <dbl> 0, 0, 0, 0, 16, 3, 0, 0, 1, 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 171, 0…
## $ `24`  <dbl> 0, 0, 0, 0, 11, 926, 0, 0, 0, 130, 0, 385, 0, 0, 0, 0, 0, 0, 0, …
## $ `25`  <dbl> 0, 0, 0, 0, 17, 20, 0, 0, 0, 12, 0, 28, 0, 0, 0, 0, 0, 0, 0, 7, …
## $ `26`  <dbl> 0, 0, 0, 0, 3, 118, 0, 0, 0, 5, 0, 50, 0, 0, 0, 0, 0, 0, 0, 217,…
## $ `27`  <dbl> 0, 0, 0, 0, 1, 6, 0, 0, 0, 315, 0, 168, 0, 0, 0, 0, 0, 0, 0, 144…
## $ `28`  <dbl> 0, 0, 0, 0, 0, 150, 0, 0, 0, 6, 0, 11, 0, 0, 0, 0, 0, 0, 0, 4, 0…
## $ `29`  <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 265, 0, 527, 0, 0, 0, 0, 0, 0, 0, 43…
## $ `30`  <dbl> 0, 0, 0, 0, 0, 891, 0, 0, 0, 6, 0, 21, 0, 0, 0, 0, 0, 0, 0, 11, …
## $ `31`  <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 292, 0…
## $ `32`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 315, …
## $ `33`  <dbl> 0, 0, 0, 0, 0, 186, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0…
## $ `34`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, …
## $ `35`  <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,…
## $ `36`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, …
## $ `37`  <dbl> 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 53, 0,…
## $ `38`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, …
## $ `39`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 380, 0…
## $ `40`  <dbl> 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 668, 0,…
## $ `41`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0,…
## $ `42`  <dbl> 0, 0, 0, 0, 0, 926, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0…
## $ `43`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 0,…
## $ `44`  <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 43, 0…
## $ `45`  <dbl> 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0…
## $ `46`  <dbl> 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 83, 0…
## $ `47`  <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `48`  <dbl> 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `49`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `50`  <dbl> 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `51`  <dbl> 0, 0, 0, 0, 0, 213, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `52`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `53`  <dbl> 0, 0, 0, 0, 0, 186, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `54`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `55`  <dbl> 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `56`  <dbl> 0, 0, 0, 0, 0, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `57`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `58`  <dbl> 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `59`  <dbl> 0, 0, 0, 0, 0, 213, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `60`  <dbl> 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `61`  <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `62`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `63`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `64`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `65`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `66`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `67`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `68`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `69`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `70`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `71`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `72`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `73`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `74`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `75`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `76`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `77`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `78`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `79`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `80`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `81`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `82`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `83`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `84`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `85`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `86`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `87`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `88`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `89`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `90`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `91`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `92`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `93`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `94`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `95`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `96`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `97`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `98`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `99`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `100` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `101` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `102` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `103` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `104` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `105` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `106` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `107` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `108` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `109` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `110` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `111` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `112` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `113` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `114` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `115` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `116` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `117` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `118` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `119` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `120` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `121` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `122` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `123` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `124` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `125` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `126` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `127` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `128` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `129` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `130` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `131` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `132` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `133` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `134` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `135` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `136` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `137` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `138` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `139` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `140` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `141` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `142` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `143` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `144` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `145` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `146` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `147` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `148` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `149` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `150` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…

Once the full keras input is prepared, I split it into training and testing parts.

set.seed(42) # Zaphod Beeblebrox advises to trust no other!

idx <- createDataPartition(keras_input$name, p = .8, list = FALSE, times = 1) # 80 / 20 split, sampled within each author

train_data <- keras_input[idx, ] # train dataset
test_data <- keras_input[-idx, ] # verification
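createDataPartition() samples within each level of the outcome, here the author name, so both authors keep roughly their share in the training and testing parts. A rough base-R equivalent of such a stratified 80/20 split, shown on a made-up label vector, could look like this; the rounding rule (rounding down) is an assumption and may differ from caret's.

```r
set.seed(42)

# Made-up outcome vector: 10 tweets per author
labels <- rep(c("hadleywickham", "MarieKondo"), each = 10)

# For each author, sample a share p of that author's row indices (rounded down)
stratified_split <- function(y, p = 0.8) {
  idx <- unlist(lapply(split(seq_along(y), y), function(rows) {
    rows[sample.int(length(rows), floor(p * length(rows)))]
  }))
  sort(unname(idx))  # training row indices, in original row order
}

idx <- stratified_split(labels, p = 0.8)
table(labels[idx])  # 8 rows per author
```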

With the training dataset ready, a couple of technical steps are necessary: converting the categorical author field to a number, and converting the whole input to a numerical matrix. The vocabulary size variable will be used to size the input of the keras embedding layer.

train_data <- train_data %>%
  mutate(hadley = ifelse(name == 'hadleywickham', 1,0)) %>% # binary output
  select(-id, -name, -text)

x_train <- data.matrix(train_data %>% select(-hadley)) # everything except target
y_train <- data.matrix(train_data %>% select(hadley)) # target, and target only

vocab_size <- vocabulary %>% # count unique word ids
  pull(id_slovo) %>% 
  unique() %>%
  length() + 1 # one extra for the zero padding

Once all the inputs are ready, I define and train my model.

The model uses an embedding layer sized to the vocabulary, mapping each lemma id to a 256-dimensional vector, followed by one bidirectional LSTM layer with 128 units in each direction. As LSTMs are notoriously prone to overfitting, a dropout layer is added after it.

The final layer is a single output node, representing the probability of a particular tweet being Hadley’s.

model <- keras_model_sequential() 

model %>% 
  layer_embedding(input_dim = vocab_size, output_dim = 256) %>%
  bidirectional(layer_lstm(units = 128)) %>%
  layer_dropout(rate = 0.5) %>% 
  layer_dense(units = 1, activation = 'sigmoid') # 1 = Hadley, 0 = Marie

model %>% 
  compile(optimizer = "rmsprop",
          loss = "binary_crossentropy",
          metrics = c("accuracy"))

history <- model %>%  # fit the model (this will take a while...)
  fit(x_train, 
      y_train, 
      epochs = 25, 
      batch_size = floor(nrow(train_data) / 5), # integer batch size, a fifth of the rows
      validation_split = 1/5)
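A detail worth knowing about validation_split: Keras takes the validation samples from the last rows of x and y, before any shuffling, and the shuffle argument of fit() only reorders the remaining training rows. createDataPartition() returns row indices in ascending order, so if the tweets are stored grouped by author (the first rows of keras_input shown above all belong to hadleywickham), the held-out fifth may come mostly or entirely from one author. Shuffling the rows once before fit() is a cheap safeguard. The base-R sketch below uses made-up labels and a stand-in feature matrix to show the effect.

```r
set.seed(42)

# Made-up training labels stored grouped by author
y <- rep(c(1, 0), each = 10)             # 1 = Hadley, 0 = Marie
x <- matrix(seq_len(20 * 3), nrow = 20)  # stand-in feature matrix, one row per tweet

# Roughly what validation_split = 1/5 holds out: the last fifth of the rows
last_fifth <- function(n) seq.int(n - n %/% 5 + 1, n)
table(y[last_fifth(length(y))])          # only class 0 ends up in validation

# Shuffle rows once, applying the same permutation to x and y
perm <- sample.int(nrow(x))
x_shuffled <- x[perm, , drop = FALSE]
y_shuffled <- y[perm]
```

Applied to this post, that would mean drawing perm <- sample.int(nrow(x_train)) and indexing both x_train and y_train (with drop = FALSE) by it before calling fit().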

summary(model)
## Model: "sequential"
## ________________________________________________________________________________
##  Layer (type)                       Output Shape                    Param #     
## ================================================================================
##  embedding (Embedding)              (None, None, 256)               253184      
##  bidirectional (Bidirectional)      (None, 256)                     394240      
##  dropout (Dropout)                  (None, 256)                     0           
##  dense (Dense)                      (None, 1)                       257         
## ================================================================================
## Total params: 647,681
## Trainable params: 647,681
## Non-trainable params: 0
## ________________________________________________________________________________
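The parameter counts in the summary can be checked by hand, which is a useful sanity test that the layers are wired as intended. The sketch below uses the standard formulas: an embedding layer has input_dim × output_dim weights; a Keras LSTM layer has four gates, each with units × (input + units) weights plus units biases; the bidirectional wrapper doubles that; and a dense layer has inputs × units weights plus units biases. The vocabulary size of 989 (988 lemmas plus one for the zero padding) comes from the glimpse of the vocabulary above.

```r
vocab_size <- 988 + 1   # 988 lemmas in the vocabulary, plus index 0 for padding
emb_dim    <- 256       # output_dim of layer_embedding()
units      <- 128       # units of layer_lstm(), per direction

embedding_params <- vocab_size * emb_dim

# Four gates (input, forget, cell, output), each with input weights,
# recurrent weights and one bias vector
lstm_params          <- 4 * (units * (emb_dim + units) + units)
bidirectional_params <- 2 * lstm_params  # forward and backward copies

dense_params <- (2 * units) * 1 + 1      # 256 inputs -> 1 output, plus bias

# Should match the summary above: 253184, 394240, 257 and 647681 in total
c(embedding     = embedding_params,
  bidirectional = bidirectional_params,
  dense         = dense_params,
  total         = embedding_params + bidirectional_params + dense_params)
```

The dropout layer has no weights, which is why it shows 0 parameters.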

With the model trained, I verify its effectiveness on the testing dataset, using the same pre-processing as for the training data.

pred_data <- test_data %>% # expected results
  select(id, name, text)

test_data <- test_data %>%
  mutate(hadley = ifelse(name == 'hadleywickham', 1,0)) %>% # binary output
  select(-id, -name, -text) # no cheating!

x_pred <- data.matrix(test_data %>% select(-hadley)) # keras needs matrix

pred <- model %>% # let keras sweat...
  predict(x_pred)

verifikace <- pred_data %>% # correct results ...
  cbind(pred) # ... joined with predictions

verifikace <- verifikace %>%
  mutate(name_pred = ifelse(pred > 0.5, 'hadleywickham', 'MarieKondo'))

conf_mtx <- table(verifikace$name, verifikace$name_pred)

print(paste0('Correctly predicted ',
             sum(diag(conf_mtx)), ' of ',
             sum(conf_mtx), ' tweets, which means ', 
             round(100 * sum(diag(conf_mtx))/sum(conf_mtx), 2), 
             '% of the total.'))
## [1] "Correctly predicted 180 of 199 tweets, which means 90.45% of the total."
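Overall accuracy hides which author the model confuses more often. From the same kind of confusion matrix, per-author precision and recall follow directly. The sketch below uses a small made-up set of true and predicted labels rather than the actual test results.

```r
# Made-up true and predicted authors for six tweets
truth <- c("hadleywickham", "hadleywickham", "hadleywickham",
           "MarieKondo", "MarieKondo", "MarieKondo")
pred  <- c("hadleywickham", "hadleywickham", "MarieKondo",
           "MarieKondo", "MarieKondo", "MarieKondo")

# Rows = true author, columns = predicted author (same layout as conf_mtx);
# explicit factor levels keep the matrix 2 x 2 even if an author is never predicted
lv <- c("hadleywickham", "MarieKondo")
conf <- table(factor(truth, levels = lv), factor(pred, levels = lv))

accuracy  <- sum(diag(conf)) / sum(conf)
precision <- diag(conf) / colSums(conf)  # of tweets predicted as X, share truly by X
recall    <- diag(conf) / rowSums(conf)  # of tweets truly by X, share predicted as X
```

With a plain table() on character vectors, an author who is never predicted would simply have no column, and sum(diag()) could then pair the wrong cells.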

An accuracy of over 90% in a toy example such as this is not bad :)

It could be improved further in a number of ways; the two most obvious are including more training data (500 tweets per author is not that many) or adding more layers to create a more complex model.