LDA inference

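The goal of LDA inference is to infer the posterior probability of the topic assignment of every word in every document. With collapsed Gibbs sampling, this probability is computed one word at a time, conditioning on the topic assignments of all other words. In practice, this means using how often each topic occurs in the current document and how often the current word is assigned to each topic across the corpus. Writing \(z_i\) for the topic of word \(i\) and \(z_{\neg i}\) for the topics of all other words, we have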

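\[\begin{align} P(z_i|z_{\neg i},w,\alpha,\beta)&=\frac{P(z_i,z_{\neg i},w|\alpha,\beta)}{P(z_{\neg i},w|\alpha,\beta)}\\ &\propto P(z_i,z_{\neg i},w|\alpha,\beta)\\ &=P(z,w|\alpha,\beta), \end{align}\]

where the proportionality holds because the denominator does not depend on \(z_i\). The joint distribution is obtained by integrating out the document-topic proportions \(\theta\) and the topic-word distributions \(\phi\):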

\[\begin{align} P(z,w|\alpha,\beta)&=\int\int P(z|\theta)P(\theta|\alpha)P(w|\phi_z)P(\phi|\beta)d\theta d\phi\\ &=\int P(z|\theta)P(\theta|\alpha)d\theta \int P(w|\phi_z)P(\phi|\beta)d\phi. \end{align}\]

Because \(\theta\) and \(\phi\) are independent a priori, the right-hand side separates into two integrals: the marginal likelihood of the topic assignments \(z\) and the marginal likelihood of the words \(w\) given \(z\). Both have closed forms because the Dirichlet prior is conjugate to the multinomial. For the topic assignments,

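\[\begin{align} \int P(z|\theta)P(\theta|\alpha)d\theta&=\prod_d\int \prod_i \theta_{d,z_{d,i}} \frac{1}{B(\alpha)}\prod_k \theta_{d,k}^{\alpha_k-1} d\theta_d\\ &=\prod_d\frac{1}{B(\alpha)}\int \prod_k \theta_{d,k}^{n_{d,k}+\alpha_k-1}d\theta_d\\ &=\prod_d\frac{B(n_{d,.}+\alpha)}{B(\alpha)}, \end{align}\] where \(n_{d,k}\) is the number of words in document \(d\) assigned to topic \(k\), and \(B(\alpha)=\prod_k\Gamma(\alpha_k)/\Gamma(\sum_k\alpha_k)\) is the multivariate Beta function, i.e., the normalizing constant of the Dirichlet distribution. The last step uses \(\int\prod_k\theta_k^{a_k-1}d\theta=B(a)\).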
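The ratio \(B(n_{d,.}+\alpha)/B(\alpha)\) is the probability of one particular sequence of topic assignments in a document. As a quick numerical sanity check, the R sketch below evaluates it on the log scale with lgamma(), which avoids overflowing \(\Gamma(\cdot)\) for realistic counts. It then compares the result with the sequential Pólya-urn probability, in which each new assignment to topic \(k\) has probability \((n_k+\alpha_k)/(\sum_{k'}n_{k'}+\sum_{k'}\alpha_{k'})\) given the previous ones. The helper names lmbeta, log_dm_marginal and log_urn are our own, not from a package.

```r
# Log of the multivariate Beta function: B(a) = prod(gamma(a)) / gamma(sum(a))
lmbeta <- function(a) sum(lgamma(a)) - lgamma(sum(a))

# log[ B(n + alpha) / B(alpha) ] for a vector of topic counts n
log_dm_marginal <- function(n, alpha) lmbeta(n + alpha) - lmbeta(alpha)

# The same probability built up one assignment at a time (Polya urn):
# topic k is drawn with probability (count_k + alpha_k) / (total + sum(alpha))
log_urn <- function(z, alpha) {
  counts <- rep(0, length(alpha))
  logp <- 0
  for (zi in z) {
    logp <- logp + log((counts[zi] + alpha[zi]) / (sum(counts) + sum(alpha)))
    counts[zi] <- counts[zi] + 1
  }
  logp
}

alpha <- c(1, 1)
z <- c(1, 1, 2)                          # topics of the three words in a document
n <- tabulate(z, nbins = length(alpha))  # topic counts: 2, 1
exp(log_dm_marginal(n, alpha))           # 0.08333333 (= 1/12)
exp(log_urn(z, alpha))                   # 0.08333333 (= 1/12)
```

Because the result depends on \(z\) only through the counts \(n\), any reordering of the same assignments gives the same value. This exchangeability is what allows the sampler to work with counts alone.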

Similarly, the marginal likelihood to generate words can be derived as

\[\begin{align} \int P(w|\phi_z)P(\phi|\beta)d\phi&=\int \prod_d \prod_i \phi_{z_{d,i},w_{d,i}}\prod_k \frac{1}{B(\beta)}\prod_w \phi_{k,w}^{\beta_w-1}d\phi\\ &=\prod_k\frac{1}{B(\beta)}\int \prod_w \phi_{k,w}^{n_{k,w}+\beta_w-1} d\phi_k\\ &=\prod_k\frac{B(n_{k,.}+\beta)}{B(\beta)}, \end{align}\] where \(n_{k,w}\) is the number of times word \(w\) is assigned to topic \(k\) across all documents.

Therefore,

\[\begin{align} P(z,w|\alpha,\beta)=\prod_d \frac{B(n_{d,.}+\alpha)}{B(\alpha)}\prod_k\frac{B(n_{k,.}+\beta)}{B(\beta)}, \end{align}\] where \(n_{d,.}=(n_{d,1},\dots,n_{d,K})\) is the vector of topic counts in document \(d\) (the number of words in \(d\) assigned to each of the \(K\) topics), whereas \(n_{k,.}=(n_{k,1},\dots,n_{k,V})\) is the vector of counts of each of the \(V\) vocabulary words assigned to topic \(k\) across all documents.
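A minimal R sketch of this formula follows. It assumes hyperparameter vectors alpha (length \(K\)) and beta (length \(V\)) and count matrices laid out as documents by topics and topics by words. The names lmbeta, count_matrix and log_joint are hypothetical helpers for illustration, and the toy corpus is made up.

```r
# Log of the multivariate Beta function B(a) = prod(gamma(a)) / gamma(sum(a))
lmbeta <- function(a) sum(lgamma(a)) - lgamma(sum(a))

# Build an nr x nc matrix of co-occurrence counts from two index vectors
count_matrix <- function(rows, cols, nr, nc) {
  m <- matrix(0, nr, nc)
  for (i in seq_along(rows)) m[rows[i], cols[i]] <- m[rows[i], cols[i]] + 1
  m
}

# log P(z, w | alpha, beta) from the document-topic counts ndk (D x K)
# and the topic-word counts nkw (K x V)
log_joint <- function(ndk, nkw, alpha, beta) {
  doc_part   <- sum(apply(ndk, 1, function(n) lmbeta(n + alpha) - lmbeta(alpha)))
  topic_part <- sum(apply(nkw, 1, function(n) lmbeta(n + beta) - lmbeta(beta)))
  doc_part + topic_part
}

# Toy corpus: two documents, two topics, three word types
doc <- c(1, 1, 1, 2, 2)
w   <- c(1, 2, 1, 3, 3)
z   <- c(1, 1, 1, 2, 2)
ndk <- count_matrix(doc, z, 2, 2)
nkw <- count_matrix(z, w, 2, 3)
exp(log_joint(ndk, nkw, alpha = rep(1, 2), beta = rep(1, 3)))   # 1/2160, about 0.000463
```

Tracking this log joint across sweeps of a Gibbs sampler can be a convenient way to see whether the chain has settled.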

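Now consider word \(i\), which is word \(w=w_i\) of document \(d\), and a candidate assignment \(z_i=k\). Since \(P(z_{\neg i},w)=P(w_i|z_{\neg i},w_{\neg i})P(z_{\neg i},w_{\neg i})\) and the first factor does not depend on \(z_i\), we can divide the joint by \(P(z_{\neg i},w_{\neg i})\) instead. All factors except those of document \(d\) and topic \(k\) then cancel, and the identity \(\Gamma(x+1)=x\Gamma(x)\) reduces the remaining ratios of Beta functions to ratios of counts:

\[\begin{align} P(z_i=k|z_{\neg i},w)&=\frac{P(z,w)}{P(z_{\neg i},w)}\propto\frac{P(z,w)}{P(z_{\neg i},w_{\neg i})}\\ &=\frac{B(n_{d,.}+\alpha)}{B(n_{d,.}^{\neg i}+\alpha)}\frac{B(n_{k,.}+\beta)}{B(n_{k,.}^{\neg i}+\beta)}\\ &\propto (n_{d,k}^{\neg i}+\alpha_k)\frac{n_{k,w}^{\neg i}+\beta_w}{\sum_{w'}(n_{k,w'}^{\neg i}+\beta_{w'})}, \end{align}\]

where the superscript \(\neg i\) marks counts computed with word \(i\) removed. The document factor \(\sum_{k'}(n_{d,k'}^{\neg i}+\alpha_{k'})\) in the denominator is the same for every \(k\) and has been dropped.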
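To see that this cancellation is exact, the sketch below computes the full conditional of one token in two ways on a made-up toy corpus. Method (a) evaluates the joint \(P(z,w|\alpha,\beta)\) for every candidate topic and normalizes. Method (b) uses the closed-form count expression. The corpus, the hyperparameters (alpha = 0.5, beta = 0.1) and the helper names are illustrative assumptions.

```r
lmbeta <- function(a) sum(lgamma(a)) - lgamma(sum(a))
count_matrix <- function(rows, cols, nr, nc) {
  m <- matrix(0, nr, nc)
  for (i in seq_along(rows)) m[rows[i], cols[i]] <- m[rows[i], cols[i]] + 1
  m
}
# log P(z, w | alpha, beta) for a corpus given as token-level vectors
log_joint <- function(doc, w, z, K, V, alpha, beta) {
  ndk <- count_matrix(doc, z, max(doc), K)
  nkw <- count_matrix(z, w, K, V)
  sum(apply(ndk, 1, function(n) lmbeta(n + alpha) - lmbeta(alpha))) +
    sum(apply(nkw, 1, function(n) lmbeta(n + beta) - lmbeta(beta)))
}

# Toy corpus (hypothetical): document ids, word ids, current topic assignments
doc <- c(1, 1, 1, 2, 2, 2)
w   <- c(1, 2, 1, 3, 3, 2)
z   <- c(1, 2, 1, 2, 2, 1)
K <- 2; V <- 3
alpha <- rep(0.5, K); beta <- rep(0.1, V)
i <- 2   # the token whose topic is being resampled

# (a) Brute force: evaluate the joint for every value of z_i, then normalize
lj <- sapply(1:K, function(k) { zk <- z; zk[i] <- k; log_joint(doc, w, zk, K, V, alpha, beta) })
p_brute <- exp(lj - max(lj))
p_brute <- p_brute / sum(p_brute)

# (b) Closed form, using counts with token i removed
ndk <- count_matrix(doc[-i], z[-i], max(doc), K)
nkw <- count_matrix(z[-i], w[-i], K, V)
p_closed <- (ndk[doc[i], ] + alpha) * (nkw[, w[i]] + beta[w[i]]) /
  (rowSums(nkw) + sum(beta))
p_closed <- p_closed / sum(p_closed)

rbind(p_brute, p_closed)   # the two rows agree
```

The brute-force route recomputes the whole joint for every candidate topic, whereas the closed form reads only a few entries of the count tables. That difference is why count-based samplers, like gibbsLda below, use the closed form.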

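The corresponding point estimates of \(\phi\) and \(\theta\) are

\[\begin{align} \hat\phi_{k,w}&=\frac{n_{k,w}+\beta_w}{\sum_{w'}(n_{k,w'}+\beta_{w'})}\\ \hat\theta_{d,k}&=\frac{n_{d,k}+\alpha_k}{\sum_{k'}(n_{d,k'}+\alpha_{k'})}. \end{align}\]

Therefore, the posterior probability of each topic for word \(i\) is proportional to \(\hat\theta_{d,k}\hat\phi_{k,w}\), with both estimates computed from the counts that exclude word \(i\).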
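A minimal R sketch of these estimates on small made-up count matrices follows. The names ndk, nkw, theta_hat and phi_hat are illustrative, and alpha = beta = 1 is assumed, matching the rest of this example. Because R recycles a vector down the columns of a matrix, dividing a \(D\times K\) matrix by a length-\(D\) vector of row totals normalizes each row. The last two lines compute the topic weights \(\hat\theta_{d,k}\hat\phi_{k,w}\) for a new occurrence of word 2 in document 1. For a new token, the counts that exclude it are simply the current counts.

```r
# Made-up counts: documents x topics and topics x words
ndk <- matrix(c(8, 2,
                1, 9), nrow = 2, byrow = TRUE)
nkw <- matrix(c(5, 4, 0,
                0, 1, 9), nrow = 2, byrow = TRUE)
alpha <- 1   # symmetric Dirichlet parameters (assumed)
beta  <- 1

# theta_hat[d, k] = (n_dk + alpha) / (sum over k of n_dk + K * alpha)
theta_hat <- (ndk + alpha) / (rowSums(ndk) + ncol(ndk) * alpha)
# phi_hat[k, w] = (n_kw + beta) / (sum over w of n_kw + V * beta)
phi_hat <- (nkw + beta) / (rowSums(nkw) + ncol(nkw) * beta)

rowSums(theta_hat)   # 1 1
rowSums(phi_hat)     # 1 1

# Topic weights for a new occurrence of word 2 in document 1
p <- theta_hat[1, ] * phi_hat[, 2]
p / sum(p)
```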

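We use collapsed Gibbs sampling to do LDA inference. Each sweep visits every word, removes it from the counts, samples a new topic from the full conditional above, and adds it back. The implementation below prepares the data in R and runs the sampler in C++ through Rcpp.

An example implementation of the Gibbs sampling algorithm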

# LDA for inference
library(dirmult)
library(MCMCpack)
## Loading required package: coda
## Loading required package: MASS
## ##
## ## Markov Chain Monte Carlo Package (MCMCpack)
## ## Copyright (C) 2003-2022 Andrew D. Martin, Kevin M. Quinn, and Jong Hee Park
## ##
## ## Support provided by the U.S. National Science Foundation
## ## (Grants SES-0350646 and SES-0350613)
## ##
## 
## Attaching package: 'MCMCpack'
## The following object is masked from 'package:dirmult':
## 
##     rdirichlet
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.3.6     ✔ purrr   0.3.4
## ✔ tibble  3.1.7     ✔ dplyr   1.0.9
## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
## ✔ readr   2.1.2     ✔ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ✖ dplyr::select() masks MASS::select()
library(Rcpp)
library(lsa)
## Loading required package: SnowballC
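# Helper functions for simulating data and initializing the sampler
# get_topic(): draw a topic uniformly at random from 1..k
get_topic<-function(k)which(rmultinom(1,1,rep(1/k,k))==1)
# get_word(): draw a topic from theta, then a word from that topic's row of phi;
# returns c(word, topic)
get_word<-function(theta,phi){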
  topic<-which(rmultinom(1,1,theta)==1)
  new_word<-which(rmultinom(1,1,phi[topic,])==1)
  return(c(new_word,topic))
}
cppFunction(
'List gibbsLda(NumericVector topic, NumericVector doc_id, NumericVector word,
               NumericMatrix n_doc_topic_count,NumericMatrix n_topic_term_count,
               NumericVector n_topic_sum, NumericVector n_doc_word_count){
  // Dirichlet hyperparameters; keep in sync with alphas and beta on the R side
  int alpha=1;
  int beta=1;
  int cs_topic,cs_doc,cs_word,new_topic;
  int n_topics=max(topic)+1;
  int vocab_length=n_topic_term_count.ncol();
  double p_sum=0,num_doc,denom_doc,denom_term,num_term;
  NumericVector p_new(n_topics);
  IntegerVector topic_sample(n_topics);
  
  for(int iter=0;iter<100;iter++){
     for(int j=0;j<word.size();++j){
         // Copy the current token's topic, document and word ids into local variables
        cs_topic=topic[j];
        cs_doc=doc_id[j];
        cs_word=word[j];
        // Decrement counts
        n_doc_topic_count(cs_doc,cs_topic)=n_doc_topic_count(cs_doc,cs_topic)-1;
        n_topic_term_count(cs_topic,cs_word)=n_topic_term_count(cs_topic,cs_word)-1;
        n_topic_sum[cs_topic]=n_topic_sum[cs_topic]-1;
         // Compute the (unnormalized) full conditional probability of each topic
        for(int tpc=0;tpc<n_topics;tpc++){
           // word cs_word topic tpc + beta
           num_term=n_topic_term_count(tpc,cs_word)+beta;
           // sum of all word counts w/ topic tpc + vocab length*beta
           denom_term=n_topic_sum[tpc]+vocab_length*beta;
           // count of topic tpc in cs_doc + alpha
           num_doc=n_doc_topic_count(cs_doc,tpc)+alpha;
           // total word count in cs_doc + n_topics*alpha (identical for every topic, so it cancels when normalizing)
           denom_doc=n_doc_word_count[cs_doc]+n_topics*alpha;
           p_new[tpc]=(num_term/denom_term)*(num_doc/denom_doc);
        }
           // normalize the posteriors
        p_sum=std::accumulate(p_new.begin(),p_new.end(),0.0);
        for(int tpc=0;tpc<n_topics;tpc++){
          p_new[tpc]=p_new[tpc]/p_sum;
        }
        // sample new topic based on posterior distribution
        R::rmultinom(1, p_new.begin(), n_topics, topic_sample.begin());
        
        for(int tpc=0;tpc<n_topics;tpc++){
          if(topic_sample[tpc]==1){
              new_topic=tpc;
          }
        }
        // print(new_topic)
        // update counts
        n_doc_topic_count(cs_doc,new_topic)=n_doc_topic_count(cs_doc,new_topic)+1;
        n_topic_term_count(new_topic,cs_word)=n_topic_term_count(new_topic,cs_word)+1;
        n_topic_sum[new_topic]=n_topic_sum[new_topic]+1;
        // update current state
        topic[j]=new_topic;
     }
  }
return List::create(n_topic_term_count,n_doc_topic_count);

}')
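To check the Rcpp sampler on small inputs, the same sweep can be sketched in plain R. This is a hedged sketch, not the author's code. gibbs_sweep is a hypothetical helper, and it assumes 1-based ids, scalar (symmetric) alpha and beta, and count arrays consistent with the current topic vector. Because sample.int() accepts unnormalized weights, the explicit normalization step of the C++ version is not needed.

```r
# One sweep of collapsed Gibbs sampling in plain R (1-based ids).
# doc, word, topic: one entry per token
# ndk (D x K), nkw (K x V), nk (length K): counts consistent with `topic`
gibbs_sweep <- function(doc, word, topic, ndk, nkw, nk, alpha, beta) {
  K <- ncol(ndk)
  V <- ncol(nkw)
  for (j in seq_along(word)) {
    d <- doc[j]; w <- word[j]; k_old <- topic[j]
    # remove token j from the counts
    ndk[d, k_old] <- ndk[d, k_old] - 1
    nkw[k_old, w] <- nkw[k_old, w] - 1
    nk[k_old]     <- nk[k_old] - 1
    # unnormalized full conditional; the document-length denominator is the
    # same for every topic, so it is left out
    p <- (ndk[d, ] + alpha) * (nkw[, w] + beta) / (nk + V * beta)
    k_new <- sample.int(K, 1, prob = p)
    # add token j back under its new topic
    ndk[d, k_new] <- ndk[d, k_new] + 1
    nkw[k_new, w] <- nkw[k_new, w] + 1
    nk[k_new]     <- nk[k_new] + 1
    topic[j] <- k_new
  }
  list(topic = topic, ndk = ndk, nkw = nkw, nk = nk)
}

# Toy corpus: two documents over a vocabulary of four word types
set.seed(1)
doc  <- rep(1:2, each = 5)
word <- c(1, 1, 2, 1, 2, 3, 4, 3, 4, 4)
K <- 2; V <- 4
topic <- sample.int(K, length(word), replace = TRUE)   # random initial topics

# initial counts
ndk <- matrix(0, 2, K); nkw <- matrix(0, K, V); nk <- rep(0, K)
for (j in seq_along(word)) {
  ndk[doc[j], topic[j]]  <- ndk[doc[j], topic[j]] + 1
  nkw[topic[j], word[j]] <- nkw[topic[j], word[j]] + 1
  nk[topic[j]]           <- nk[topic[j]] + 1
}

state <- list(topic = topic, ndk = ndk, nkw = nkw, nk = nk)
for (iter in 1:50) {
  state <- gibbs_sweep(doc, word, state$topic, state$ndk, state$nkw, state$nk,
                       alpha = 1, beta = 1)
}
state$nkw   # topic-word counts after 50 sweeps
```

An interpreted R loop over every token is much slower than the compiled Rcpp loop, so this version is only practical on toy corpora such as this one.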

Next, let's simulate documents and topics. Suppose there are three topics, denoting land, sea, and air animals, respectively:

beta<-1
k<-3 # number of topics
M<-100 # number of documents
alphas<-rep(1,k)
xi<-100 # average document length
N<-rpois(M,xi) # words in each document

# Create animal labels

# whale, spouting whale, fish, tropical fish, octopus
sea_animals<-c('\U1F40B', '\U1F433','\U1F41F', '\U1F420', '\U1F419')
# crab, crocodile, turtle, snake
amphibious<-c('\U1F980', '\U1F40A', '\U1F422', '\U1F40D')
# rooster, turkey, bird, penguin
birds<-c('\U1F413','\U1F983','\U1F426','\U1F427')
# chipmunk, elephant, ox, sheep, camel
land_animals<-c('\U1F43F','\U1F418','\U1F402','\U1F411','\U1F42A')
vocab<-c(sea_animals, amphibious, birds, land_animals) # all animal labels
# unnormalized weight of each animal within a topic:
# 0 - animals that cannot occur in the topic
# 2 - animals shared with another topic
# 4 - animals specific to the topic
shared <- 2
non_shared <- 4
not_present <- 0
land_phi <- c(rep(not_present, length(sea_animals)),
              rep(shared, length(amphibious)),
              rep(non_shared, 2), # turkey and chicken can't fly
              rep(shared, 2), # regular bird and penguin
              rep(non_shared, length(land_animals)))
land_phi <- land_phi/sum(land_phi)
sea_phi <- c(rep(non_shared, length(sea_animals)),
             rep(shared, length(amphibious)),
             rep(not_present, 2), # turkey and chicken don't live at sea
             rep(shared, 2), # regular bird and penguin
             rep(not_present, length(land_animals)))
sea_phi <- sea_phi/sum(sea_phi)
air_phi <- c(rep(not_present, length(sea_animals)),
             rep(not_present, length(amphibious)),
             rep(not_present, 2), # turkey and chicken can't fly 
             non_shared, # regular bird
             not_present, # penguins can't fly
             rep(not_present, length(land_animals)))
air_phi <- air_phi/sum(air_phi)

Now that the three topic-word distributions are defined, we can generate the test data.

# stack the three topic-word distributions into a k x V matrix
phi <- matrix(c(land_phi, sea_phi, air_phi), nrow = k, ncol = length(vocab), 
              byrow = TRUE, dimnames = list(c('land', 'sea', 'air')))
phi
##       [,1]  [,2]  [,3]  [,4]  [,5]   [,6]   [,7]   [,8]   [,9] [,10] [,11]
## land 0.000 0.000 0.000 0.000 0.000 0.0500 0.0500 0.0500 0.0500   0.1   0.1
## sea  0.125 0.125 0.125 0.125 0.125 0.0625 0.0625 0.0625 0.0625   0.0   0.0
## air  0.000 0.000 0.000 0.000 0.000 0.0000 0.0000 0.0000 0.0000   0.0   0.0
##       [,12]  [,13] [,14] [,15] [,16] [,17] [,18]
## land 0.0500 0.0500   0.1   0.1   0.1   0.1   0.1
## sea  0.0625 0.0625   0.0   0.0   0.0   0.0   0.0
## air  1.0000 0.0000   0.0   0.0   0.0   0.0   0.0
theta_samples<-rdirichlet(M, alphas)
thetas<-theta_samples[rep(1:nrow(theta_samples),times=N), ]
new_words<-t(apply(thetas,1,function(x) get_word(x,phi)))

ds <-tibble(doc_id = rep(1:length(N), times = N), 
            word   = new_words[,1],
            topic  = new_words[,2], 
            theta_a = thetas[,1],
            theta_b = thetas[,2],
            theta_c = thetas[,3]
) 
ds
## # A tibble: 9,844 × 6
##    doc_id  word topic theta_a theta_b theta_c
##     <int> <int> <int>   <dbl>   <dbl>   <dbl>
##  1      1    12     3  0.0759  0.0725   0.852
##  2      1    12     3  0.0759  0.0725   0.852
##  3      1    12     3  0.0759  0.0725   0.852
##  4      1    12     3  0.0759  0.0725   0.852
##  5      1     2     2  0.0759  0.0725   0.852
##  6      1    12     3  0.0759  0.0725   0.852
##  7      1    12     3  0.0759  0.0725   0.852
##  8      1    12     3  0.0759  0.0725   0.852
##  9      1    12     3  0.0759  0.0725   0.852
## 10      1    12     3  0.0759  0.0725   0.852
## # … with 9,834 more rows
# Show the first two documents as strings of animals
ds1<-ds %>% filter(doc_id < 3) %>% group_by(doc_id) %>% summarise(
  tokens = paste(vocab[word], collapse = ' ')
)
colnames(ds1)<-c("Document","Animals")
tt<-ds %>% filter(doc_id==1) %>% with(table(word))
c(sum(tt[1:5]),sum(tt[6:9]),sum(tt[10:11]),0)/sum(tt)
## [1] 0.07291667 0.07291667 0.85416667 0.00000000
ds1
## # A tibble: 2 × 2
##   Document Animals                                                              
##      <int> <chr>                                                                
## 1        1 🐦 🐦 🐦 🐦 🐳 🐦 🐦 🐦 🐦 🐦 🐦 🐦 🐦 🐓 🐦 🐦 🐦 🐦 🐦 🐦 🐦 🦃 🐦…
## 2        2 🐦 🐦 🐦 🐦 🐦 🐦 🐦 🐦 🐘 🐟 🐦 🐦 🐊 🐦 🐦 🐍 🐦 🐦 🐦 🐦 🐦 🐦 🐦…
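The per-token apply() call above follows the LDA generative process literally: draw \(\theta_d\sim\text{Dir}(\alpha)\), then for every word draw a topic from \(\theta_d\) and a word from that topic's row of \(\phi\). A base-R sketch of the same process for a single document is below. It relies on the standard construction of a Dirichlet draw as independent Gamma\((\alpha_k,1)\) variables divided by their sum, so MCMCpack is not required. rdirichlet1 and simulate_doc are hypothetical names, and phi_toy is a made-up two-topic matrix.

```r
# Dirichlet draw via normalized Gamma(alpha_k, 1) variables
rdirichlet1 <- function(alpha) {
  g <- rgamma(length(alpha), shape = alpha)
  g / sum(g)
}

# Simulate one document of length n_words from the LDA generative process
simulate_doc <- function(n_words, alpha, phi) {
  theta <- rdirichlet1(alpha)                                    # topic mixture
  topic <- sample.int(nrow(phi), n_words, replace = TRUE, prob = theta)
  word  <- vapply(topic, function(k) sample.int(ncol(phi), 1, prob = phi[k, ]),
                  integer(1))                                    # one word per topic draw
  list(theta = theta, topic = topic, word = word)
}

set.seed(42)
phi_toy <- rbind(c(0.5, 0.5, 0, 0),    # topic 1 uses words 1 and 2
                 c(0, 0, 0.5, 0.5))    # topic 2 uses words 3 and 4
d <- simulate_doc(1000, alpha = c(1, 1), phi = phi_toy)
# empirical topic shares should be close to theta
rbind(theta = d$theta, empirical = tabulate(d$topic, 2) / 1000)
```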

Finally, we run LDA inference on the simulated corpus with the helper functions and the gibbsLda sampler defined above.

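current_state <- ds %>% dplyr::select(doc_id, word, topic)
current_state$topic<-NA
# vocabulary size (named n_vocab rather than t to avoid shadowing base::t())
n_vocab<-length(unique(current_state$word))

# pre-allocate the count arrays; they are recomputed from the initial topics below
# n_doc_topic_count
n_doc_topic_count<-matrix(0,M,k)
# document_topic_sum
n_doc_topic_sum<-rep(0,M)
# topic_term_count
n_topic_term_count<-matrix(0,k,n_vocab)
# topic_term_sum
n_topic_sum<-rep(0,k)
p<-rep(0,k)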

# initialize topics
current_state$topic<-replicate(nrow(current_state),get_topic(k))
# get word, topic, and document counts (used during inference process);
# fill = 0 gives a count of 0, rather than NA, to topics absent from a document

n_doc_topic_count<-current_state %>% group_by(doc_id,topic) %>%
  summarise(count=n()) %>% spread(key=topic,value=count,fill=0) %>% as.matrix()
## `summarise()` has grouped output by 'doc_id'. You can override using the
## `.groups` argument.
n_topic_sum<-current_state %>% group_by(topic) %>% 
  summarise(count=n()) %>% select(count) %>% as.matrix() %>% as.vector()

n_topic_term_count<-current_state %>% group_by(topic,word) %>%
  summarise(count=n()) %>% spread(word,count,fill=0) %>% as.matrix()
## `summarise()` has grouped output by 'topic'. You can override using the
## `.groups` argument.
# convert R's 1-based topic, document and word ids to 0-based C++ indices
lda_counts <- gibbsLda( current_state$topic-1 , current_state$doc_id-1, current_state$word-1,
                        n_doc_topic_count[,-1], n_topic_term_count[,-1], n_topic_sum, N)
# calculate estimates for phi and theta
# phi: smooth each topic's word counts (rows of lda_counts[[1]]) and normalize;
# apply() returns a V x k matrix, so each column is one topic's word distribution
phi_est<-apply(lda_counts[[1]],1,function(x)(x+beta)/(sum(x)+length(vocab)*beta))
rownames(phi_est)<-vocab
colnames(phi)<-vocab
# theta: smooth each document's topic counts (rows of lda_counts[[2]]) and
# normalize across topics, so that each row of theta_est sums to 1
theta_est<-t(apply(lda_counts[[2]],1,function(x)(x+alphas[1])/(sum(x)+k*alphas[1])))
colnames(theta_samples)<-c('land','sea','air')
vector_angles<-cosine(cbind(theta_samples,theta_est))[4:6,1:3]
estimated_topic_names<-apply(vector_angles, 1, function(x)colnames(vector_angles)[which.max(x)])
phi_table<-as_tibble(t(round(phi,2))[,estimated_topic_names])
phi_table<-cbind(phi_table,as_tibble(round(phi_est,2)))
## Warning: The `x` argument of `as_tibble.matrix()` must have unique column names if `.name_repair` is omitted as of tibble 2.0.0.
## Using compatibility `.name_repair`.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
names(phi_table)[4:6]<-paste0(estimated_topic_names," estimated")
theta_table<-cbind(theta_samples,theta_est)
colnames(theta_table)[4:6]<-paste0(estimated_topic_names," estimated")
phi_table
##     sea air land sea estimated air estimated land estimated
## 1  0.12   0 0.00          0.13          0.00           0.00
## 2  0.12   0 0.00          0.11          0.01           0.00
## 3  0.12   0 0.00          0.12          0.00           0.00
## 4  0.12   0 0.00          0.11          0.01           0.00
## 5  0.12   0 0.00          0.12          0.00           0.01
## 6  0.06   0 0.05          0.06          0.00           0.05
## 7  0.06   0 0.05          0.06          0.01           0.04
## 8  0.06   0 0.05          0.06          0.00           0.05
## 9  0.06   0 0.05          0.05          0.01           0.05
## 10 0.00   0 0.10          0.01          0.00           0.10
## 11 0.00   0 0.10          0.00          0.00           0.11
## 12 0.06   1 0.05          0.08          0.94           0.04
## 13 0.06   0 0.05          0.05          0.00           0.05
## 14 0.00   0 0.10          0.00          0.01           0.09
## 15 0.00   0 0.10          0.00          0.01           0.10
## 16 0.00   0 0.10          0.01          0.00           0.10
## 17 0.00   0 0.10          0.01          0.00           0.09
## 18 0.00   0 0.10          0.00          0.00           0.11
head(theta_table)
##           land        sea       air sea estimated air estimated land estimated
## [1,] 0.0759474 0.07246955 0.8515830    0.08318274     0.8223076     0.09450971
## [2,] 0.2396748 0.11424177 0.6460835    0.06767281     0.8093067     0.12302048
## [3,] 0.1362549 0.41844457 0.4453005    0.49611813     0.4339688     0.06991310
## [4,] 0.2186711 0.23045804 0.5508709    0.30135352     0.5906529     0.10799358
## [5,] 0.1635633 0.64812180 0.1883149    0.60983936     0.2156573     0.17450337
## [6,] 0.5109948 0.38828438 0.1007208    0.31737731     0.1484086     0.53421406
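LDA identifies topics only up to a permutation of their labels (label switching). This is why the code above matches estimated topics to the true ones by cosine similarity with lsa::cosine(). The base-R sketch below performs the same matching. cosine_cols is a hypothetical helper, and theta_true is a small made-up matrix whose columns are permuted and stripped of labels to mimic an estimate.

```r
# Cosine similarity between every column of a and every column of b
cosine_cols <- function(a, b) {
  crossprod(a, b) / outer(sqrt(colSums(a^2)), sqrt(colSums(b^2)))
}

# Made-up "true" document-topic proportions (rows: documents)
theta_true <- rbind(c(0.8, 0.1, 0.1),
                    c(0.1, 0.8, 0.1),
                    c(0.1, 0.1, 0.8),
                    c(0.4, 0.4, 0.2))
colnames(theta_true) <- c("land", "sea", "air")

# An "estimate" whose topics come out in a different order, without labels
theta_hat <- unname(theta_true[, c(3, 1, 2)])

sim <- cosine_cols(theta_hat, theta_true)   # rows: estimated, columns: true topics
matched <- colnames(theta_true)[apply(sim, 1, which.max)]
matched   # "air" "land" "sea"
```

With noisy estimates, taking which.max row by row can map two estimated topics to the same true topic. When \(K\) is small, as here with \(K=3\), a strictly one-to-one matching can instead be found by checking all \(K!\) permutations.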