# this implements dp mixture sampling for gaussians with a shared, known
# spherical variance. dp.gibbs works for data of any dimension; the
# plotting helpers assume 2d data.
#

# example.use: simulate 2d data from a CRP mixture, run the gibbs sampler,
# and plot both the data and the inferred decomposition.
#
# requires that you have sourced dp-simulate.R (sample.from.crp,
# gauss.base and gauss.gen are presumably defined there) and loaded
# plyr (ldply) and ggplot2.
# returns the ggplot object for the posterior decomposition.
example.use <- function() {
  ## sample data
  data <- sample.from.crp(1, gauss.base, gauss.gen, n=250)
  x <- as.matrix(ldply(data$data))

  ## plot the original data. a ggplot object is only drawn when printed,
  ## and values computed inside a function are not auto-printed.
  print(plot.mixture(ldply(data$param), data$table, x, plot.centers=FALSE))

  ## run 10 iterations of gibbs sampling
  post <- dp.gibbs(x, 1, c(0, 0), 10000, 1, niter=10)

  ## plot the decomposition (returned, so it is drawn when example.use()
  ## is called at top level)
  plot.crp.post(post, x)
}

# dp.gibbs: collapsed gibbs sampler for a dp mixture of gaussians.
#
# data        numeric matrix, one observation per row
# alpha       dp concentration parameter
# prior.mean  prior mean of the component means (one entry per column)
# prior.var   prior variance of each coordinate of a component mean
# data.var    variance of each coordinate of an observation
# niter       number of sweeps through the data
#
# returns a list with
#   counts  number of data points at each table
#   ss      per-table sufficient statistics: sum, sumsq (sum of squared
#           norms), count, and lhood (as computed by post.gauss)
#   tables  table index of each data point
#   score   score after each sweep: log.partition.prob(counts, alpha)
#           plus the summed per-table lhood, minus score.const
dp.gibbs <- function(data, alpha, prior.mean, prior.var, data.var, niter) {
  n <- dim(data)[1]
  p <- ncol(data)
  score <- numeric(niter)
  # log of alpha (alpha + 1) ... (alpha + n - 1); it does not depend on the
  # configuration, so it is computed once.
  score.const <- sum(log(seq_len(n) - 1 + alpha))

  # initialize the values #1. every point starts unassigned (table 0). in
  # the first sweep, each point's table is sampled conditioned on the
  # tables and data of the points before it. after that, we treat each
  # data point as though it were the last.
  tables <- rep(0, n)
  ss <- list()
  counts <- c()

  ## initialize the values #2. each data point is at its own table.
  ## the sampler initially merges tables together to get to a good
  ## place in the posterior.
  ## tables <- 1:n
  ## ss <- list()
  ## for (i in 1:n)
  ## {
  ##   ss[[i]] <- list(sum=data[i,], sumsq=data[i,] %*% data[i,], count=1)
  ##   ss[[i]]$lhood <- post.gauss(ss[[i]]$sum, ss[[i]]$sumsq,
  ##                               ss[[i]]$count, prior.mean,
  ##                               prior.var, data.var)$lhood
  ## }
  ## counts <- rep(1, n)

  for (iter in seq_len(niter)) {
    for (i in seq_len(n)) {
      # get table assignment and data point
      z <- tables[i]
      x <- data[i, ]
      x.sq <- sum(x * x)

      # remove this data point from its table
      if (z > 0) {
        ss[[z]]$sum <- ss[[z]]$sum - x
        ss[[z]]$sumsq <- ss[[z]]$sumsq - x.sq
        ss[[z]]$count <- ss[[z]]$count - 1
        ss[[z]]$lhood <- post.gauss(ss[[z]]$sum, ss[[z]]$sumsq,
                                    ss[[z]]$count, prior.mean,
                                    prior.var, data.var)$lhood
        counts[z] <- counts[z] - 1

        # if this data point was the only member of its table
        if (counts[z] == 0) {
          # remove the table from the restaurant configuration
          counts <- counts[-z]
          ss <- ss[-z]
          # renumber the tables "beyond" this table; tables[i] itself is
          # overwritten below
          tables[tables > z] <- tables[tables > z] - 1
        }
      }

      # sample a new table for this data point. if no tables exist (the
      # very first point of the first sweep, or n == 1 after the point
      # was removed) then sit down at the first table. counts is NULL in
      # the first case and numeric(0) in the second.
      if (length(counts) == 0) {
        z <- 1
      } else {
        z <- sample.table(x, ss, counts, alpha,
                          prior.mean, prior.var, data.var)
      }

      # add this data point to its sampled table
      tables[i] <- z

      # if we need to make a new table
      if (z > length(counts)) {
        # check that the table is the next table
        stopifnot(z == length(counts) + 1)
        # set up the new counts and sufficient statistics; the zero sum
        # has the dimension of the data
        counts <- c(counts, 0)
        ss[[z]] <- list(sum=numeric(p), sumsq=0, lhood=0, count=0)
      }
      counts[z] <- counts[z] + 1
      ss[[z]]$sum <- ss[[z]]$sum + x
      ss[[z]]$sumsq <- ss[[z]]$sumsq + x.sq
      ss[[z]]$count <- ss[[z]]$count + 1
      ss[[z]]$lhood <- post.gauss(ss[[z]]$sum, ss[[z]]$sumsq,
                                  ss[[z]]$count, prior.mean,
                                  prior.var, data.var)$lhood
    }

    # compute the score
    table.lhood <- vapply(ss, function(s) as.numeric(s$lhood), numeric(1))
    score[iter] <- log.partition.prob(counts, alpha) + sum(table.lhood) -
      score.const
    msg(sprintf("iteration = %d; number of components = %d, score = %g",
                iter, length(counts), score[iter]))
  }

  list(counts=counts, ss=ss, tables=tables, score=score)
}

# note: this is only the numerator of equation 5 in the notes. the
# denominator is constant across iterations. no need to recompute.
# log.partition.prob: log of alpha^K * prod_k (n_k - 1)! for a seating
# arrangement with K tables holding counts[k] = n_k customers each.
log.partition.prob <- function(counts, alpha) {
  length(counts) * log(alpha) + sum(lgamma(counts))
}

# post.gauss: posterior of a component mean and log marginal likelihood
# of the data assigned to it, for the model
#   mu ~ N(m0, t0 * I),   x_j | mu ~ N(mu, s * I),   j = 1, ..., n
#
# sum   = sum of data (vector of length p)
# sumsq = sum of the squared norms of the data
# n     = number of data
# m0    = prior mean (length p; a scalar is recycled to length p)
# t0    = prior variance
# s     = data generating variance
#
# returns list(mean, var, lhood): the posterior mean, the posterior
# variance of each coordinate, and lhood = log p(x_1, ..., x_n), which is
# 0 when n == 0.
post.gauss <- function(sum, sumsq, n, m0, t0, s) {
  p <- length(sum)
  if (length(m0) == 1) m0 <- rep(m0, p)
  # sumsq may arrive as a 1x1 matrix (x %*% x)
  sumsq <- as.numeric(sumsq)

  var <- 1 / (1/t0 + n/s)
  mean <- ((m0/t0) + (sum/s)) * var

  # the normalizer and log-determinant terms scale with the dimension p,
  # but the quadratic terms are already summed over coordinates and so
  # carry a factor of 1/2, not p/2.
  mean.sq <- base::sum(mean * mean)
  m0.sq <- base::sum(m0 * m0)
  lhood <- -(n * p / 2) * log(2 * pi * s) +
    (p / 2) * (log(var) - log(t0)) +
    (mean.sq / var - sumsq / s - m0.sq / t0) / 2

  list(mean=mean, var=var, lhood=lhood)
}

# sample.table: sample the table for a data point given the tables of all
# other points.
#
# point       the data point
# suff.stats  per-table sufficient statistics (sum, sumsq, count, lhood)
#             computed without this point
# counts      number of other points at each table
# alpha, prior.mean, prior.var, data.var  as in post.gauss / the CRP
#
# returns an index in 1..(length(counts) + 1); the last value means "sit
# at a new table".
sample.table <- function(point, suff.stats, counts, alpha,
                         prior.mean, prior.var, data.var) {
  n <- sum(counts)
  point.sq <- sum(point * point)

  # crp prior: existing table k with prob counts[k] / (n + alpha), new
  # table with prob alpha / (n + alpha)
  log.denom <- log(n + alpha)
  log.prior <- c(log(counts) - log.denom, log(alpha) - log.denom)

  # predictive log likelihood under each existing table:
  # log p(table data + point) - log p(table data)
  ll <- vapply(suff.stats, function(ss) {
    with.point <- post.gauss(ss$sum + point, ss$sumsq + point.sq,
                             ss$count + 1, prior.mean, prior.var, data.var)
    as.numeric(with.point$lhood - ss$lhood)
  }, numeric(1))

  # predictive log likelihood under a new (empty) table
  new.lhood <- as.numeric(post.gauss(point, point.sq, 1, prior.mean,
                                     prior.var, data.var)$lhood)

  # combine and normalize in log space to avoid underflow
  log.prob <- log.prior + c(ll, new.lhood)
  prob <- exp(log.prob - log.sum(log.prob))

  # sample the table
  sample.int(length(prob), 1, prob=prob)
}

# --- helper

# log.sum: log(sum(exp(v))), computed stably by shifting by the maximum.
# returns -Inf for empty input or when every entry is -Inf.
log.sum <- function(v) {
  if (length(v) == 0) return(-Inf)
  m <- max(v)
  # -Inf (all terms zero), Inf, or NA: the shift below is undefined
  if (!is.finite(m)) return(m)
  m + log(sum(exp(v - m)))
}

# --- plots

# plot.sample: plot the data colored by table, with each table's
# empirical mean. s is a sampler state such as the list dp.gibbs returns
# (elements ss and tables); for backward compatibility an element named
# suff.stats is used when ss is absent.
plot.sample <- function(s, obs) {
  stats <- if (!is.null(s[["ss"]])) s[["ss"]] else s[["suff.stats"]]
  locs <- do.call(rbind, lapply(stats, function (x) x$sum / x$count))
  plot.mixture(locs, s$tables, obs)
}

# plot.mixture: scatter plot of 2d observations colored by assignment z,
# optionally overlaying the component locations in locs (one per row).
# returns the ggplot object.
plot.mixture <- function(locs, z, obs, plot.centers=TRUE) {
  stopifnot(dim(obs)[2] == 2)
  z <- as.factor(z)
  df1 <- data.frame(x=obs[, 1], y=obs[, 2], z=z)
  df2 <- data.frame(x=locs[, 1], y=locs[, 2])
  p <- ggplot()
  p <- p + geom_point(data=df1, aes(x=x, y=y, colour=z),
                      shape=16, size=2, alpha=0.5)
  if (plot.centers) {
    p <- p + geom_point(data=df2, aes(x=x, y=y), shape=16, size=1)
  }
  # opts() is defunct in ggplot2; theme() replaces it
  p <- p + theme_bw() + theme(legend.position="none")
  p
}

# plot.crp.post: plot a dp.gibbs result, placing each table at its
# posterior mean under the given prior (defaults match the previously
# hard-coded values).
plot.crp.post <- function(post, x, prior.mean=c(0, 0), prior.var=100,
                          data.var=1) {
  means <- do.call(rbind, lapply(post$ss, function (s) {
    post.gauss(s$sum, s$sumsq, s$count, prior.mean, prior.var, data.var)$mean
  }))
  plot.mixture(means, post$tables, x)
}

### alternative initialization: each data at its own table
## tables <- 1:n
## ss <- list()
## for (i in 1:n)
## {
##   ss[[i]] <- list(sum=data[i,], sumsq=data[i,] %*% data[i,], count=1)
## }
## counts <- rep(1, n)