In [1]:
__author__ = "Sida Wang"
__version__ = "COS 495 NLP Spring 2018"

Neural network basics

The prediction of the multiclass linear classifier is based on the score $z_y = w_y \cdot \phi(x)$. While it is easy to learn $w$, the difficulties are hidden in $\phi(x)$, which is fixed once it is designed (by a person). One way to do better is to give $\phi(x)$ learnable parameters as well, for example by adding another linear function and another layer of featurization of $x$, as in $w \cdot \phi(w' \cdot \phi'(x))$.

The network function (forward prop)

The two-layer neural network does exactly that, by computing the vector-valued, vector-input function $z = W_2 f(W_1 x + b_1) + b_2$. For arbitrary $f$, this says nothing at all. For a neural network, $f$ is a simple but non-linear elementwise function such as $f(x) = \max(0,x)$ or $f(x) = \frac1{1+\exp(-x)}$, which significantly restricts the space of functions considered. More layers can be added, for example, $z = W_3 f_3(W_2 f_2(W_1 x + b_1) + b_2)+b_3$. Generally, an $n$-layer feedforward neural network implements the network function $x \mapsto z$, where

$$ \begin{align} a_{1} & \gets x\\ z_{i} & \gets W_i a_{i} + b_i \ \text{for} \ i = 1,\ldots,n\\ a_{i+1} & \gets f_{i+1}(z_{i}) \ \text{for} \ i = 1,\ldots,n-1\\ z & \gets z_{n}. \end{align} $$
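
To make the indexing concrete, here is a minimal sketch (not part of the original notebook) of the general $n$-layer forward pass; network_forward, Ws, bs, and fs are illustrative names for Python lists holding the $W_i$, the $b_i$, and the $n-1$ elementwise nonlinearities.

import numpy as np

def network_forward(x, Ws, bs, fs):
    # Ws, bs: lists of the n weight matrices and bias column vectors
    # fs: list of the n-1 elementwise nonlinearities f_2, ..., f_n
    a = x                                # a_1 <- x
    for i in range(len(Ws)):
        z = np.dot(Ws[i], a) + bs[i]     # z_i <- W_i a_i + b_i
        if i < len(Ws) - 1:
            a = fs[i](z)                 # a_{i+1} <- f_{i+1}(z_i)
    return z                             # z <- z_n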

The learnable parameters of the network are $\mathcal{W} = \{W_i, b_i\}_{i=1}^n$. Given data $x$, the desired prediction $y$, and the output $z$ computed from $x$, a loss function can be applied. For example, the least squares loss, the hinge loss (svm), and the "softmax loss" (i.e. the negative log-likelihood of the data under softmax) are, respectively,

$$ \begin{align} L_\text{ls}(\mathcal{W}, x, y) &= \left\lVert y-z \right\rVert_2^2,\\ L_\text{svm}(\mathcal{W}, x, y) &= \max(0, 1-(z_{y} - \max_{y'\neq y} z_{y'})),\\ L_\text{nll}(\mathcal{W}, x, y) &= -z_y + \log \sum_{y'} \exp(z_{y'}). \end{align} $$
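
For concreteness, here is a small hedged sketch (not from the notebook) of these three losses on a single example; z is assumed to be a 1-D array of class scores, y the index of the correct class, and y_target a vector of the same shape as z for the least squares case.

import numpy as np

def loss_ls(z, y_target):
    # least squares: squared Euclidean distance between the scores and a target vector
    return np.sum((y_target - z) ** 2)

def loss_svm(z, y):
    # multiclass hinge: the true score should beat the best other score by a margin of 1
    z_other = np.max(np.delete(z, y))
    return max(0.0, 1.0 - (z[y] - z_other))

def loss_nll(z, y):
    # negative log-likelihood under softmax, using the log-sum-exp trick for stability
    m = np.max(z)
    return -z[y] + m + np.log(np.sum(np.exp(z - m)))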

Basic backprop

While there is no known guarantee, gradient descent has proven to work well for many tasks. Our task here is to find gradients of the loss with respect to all the parameters $\mathcal{W}$. Backprop is an efficient application of the chain rule, starting from the gradient of the loss w.r.t. the output $z$ and working backwards. For $i = n, n-1, \ldots, 2,$

$$ \begin{align} \frac{d L}{dz_{n}} &= \frac{d L}{dz}\\ \frac{d L}{da_{i}} &= \frac{d L}{dz_{i}} \frac{d z_{i}}{da_{i}} = W_i^T\frac{d L}{dz_{i}}\\ \frac{d L}{dz_{i-1}} &= \frac{d L}{da_{i}} \odot \frac{d a_{i}}{dz_{i-1}} = \frac{d L}{da_{i}} \odot f'_i(z_{i-1}) \\ \end{align} $$

Once these quantities are computed, the gradients with respect to the parameters are easy: for each layer $i$,

$$ \begin{align} \frac{d L}{d W_i} &= \frac{d L}{dz_i} a_i^T,\\ \frac{d L}{d b_i} &= \frac{d L}{dz_i}. \end{align} $$
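
Paired with the forward-pass sketch above, here is a hedged sketch (again not the notebook's code) of the full backward loop; zs and acts are assumed to be the cached lists $[z_1,\ldots,z_n]$ and $[a_1,\ldots,a_n]$, and dfs the derivatives $f_2',\ldots,f_n'$.

import numpy as np

def network_backward(dLdz_out, Ws, zs, acts, dfs):
    # dLdz_out is dL/dz for the network output z = z_n
    grads, dLdz = {}, dLdz_out
    for i in range(len(Ws) - 1, -1, -1):                  # layers n, ..., 1 (0-indexed)
        grads['W%d' % (i + 1)] = np.dot(dLdz, acts[i].T)  # dL/dW_i = dL/dz_i a_i^T
        grads['b%d' % (i + 1)] = dLdz                     # dL/db_i = dL/dz_i
        if i > 0:
            dLda = np.dot(Ws[i].T, dLdz)                  # dL/da_i = W_i^T dL/dz_i
            dLdz = dLda * dfs[i - 1](zs[i - 1])           # dL/dz_{i-1} = dL/da_i (elementwise) f_i'(z_{i-1})
    return grads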

Exercise: Check the dimensions of $a_i, z_i, \frac{d L}{da_{i}}, \frac{d L}{dz_{i-1}}, W_i$ and make sure they are consistent in backprop.

Is this just the chain rule?: It is often claimed that backprop is just the chain rule, in a way that trivializes the problem. While a lot of the written steps use the chain rule, there are a few counterpoints:

  • The chain rule does not in itself specify the order of application. A straightforward application of the basic scalar chain rule is to expand all the terms and then evaluate, which is too expensive. In backprop, good decisions are made about when expressions are evaluated as opposed to expanded, and about which intermediate values need to be kept for efficient computation.

  • With a suitable vector chain rule, we can get to something that looks like backprop quickly. However, the tensor chain rule says $$\frac{d L}{d W_i} = \frac{d L}{dz_i} \frac{d z_i}{d W_i},$$ where $\frac{d z_i}{d W_i}$ is a 3rd order tensor of size $\dim(z_i) \times \dim(W_i)=\dim(z_i)\cdot(\dim(z_i)\cdot\dim(a_{i}))$, and the product is a tensor-vector product that seems to require $\dim(z_i)\cdot\dim(z_i)\cdot\dim(a_{i})$ operations. So you would have to specify how this can be done without actually constructing $\frac{d z_i}{d W_i}$; backprop requires only $\dim(z_i)\cdot\dim(a_{i})$ operations (see the sketch below). For just 1000 dimensions, that makes the difference between practical and impractical.

  • Under our formulation, which expresses the neural network as a function and puts our faith in SGD, computing its derivative is an obvious step. However, neural networks were originally developed to model learning by the brain, and it is a stretch that the brain learns by taking derivatives. From that perspective, backprop is not just this simple algorithm that computes the derivative, but also includes the empirical finding that neural network models can often learn effectively by following the gradient.

It seems that a few ingredients beyond the basic chain rule are needed to get backprop.
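
To make the operation-count point above concrete, here is a tiny hedged illustration (not from the notebook): materializing $\frac{d z_i}{d W_i}$ as a 3rd order tensor and contracting it gives the same $\frac{d L}{d W_i}$ as the outer product that backprop uses, but touches $\dim(z)\cdot\dim(z)\cdot\dim(a)$ numbers instead of $\dim(z)\cdot\dim(a)$.

import numpy as np

dim_z, dim_a = 4, 3
a = np.random.randn(dim_a, 1)           # a_i
dLdz = np.random.randn(dim_z, 1)        # dL/dz_i

# naive route: build dz_i/dW_i explicitly; dz[j]/dW[j', k] = a[k] if j == j' else 0
T = np.zeros((dim_z, dim_z, dim_a))
for j in range(dim_z):
    T[j, j, :] = a[:, 0]
dLdW_tensor = np.einsum('j,jkl->kl', dLdz[:, 0], T)   # dim_z * dim_z * dim_a work

# backprop route: the same matrix as a single outer product, dim_z * dim_a work
dLdW_outer = np.dot(dLdz, a.T)

assert np.allclose(dLdW_tensor, dLdW_outer)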

2-layer neural network in python $(n=2)$

The code, in vectorized form, is no more complex than the math. However, such code can be hard to debug; for example, it will run without error if I forget to add the bias or apply the wrong non-linear function. Gradient checking is an essential tool for ensuring correctness. While many issues and even intentional noise do not stop SGD from working, a systematically wrong gradient will make a difference.

In [2]:
import numpy as np
from numpy.random import randn
from copy import copy
# backpropagation code for least squares in a 2-layer neural network (single hidden layer)
# y_hat = W_2 f(W_1 x + b_1) + b_2, and L(y, y_hat) = (y - y_hat)^2

f2 = lambda x: np.maximum(0, x)
f2grad = lambda x: x > 0

# the loss function
lossfunc = lambda ypred, y: (ypred-y)*(ypred-y)
lossgrad = lambda ypred, y: 2*(ypred-y)

def fprop(x, y, params):
    W1, b1, W2, b2 = [params[key] for key in ('W1', 'b1', 'W2', 'b2')]
    z1 = np.dot(W1, x) + b1
    a2 = f2(z1)
    z2 = np.dot(W2, a2) + b2
    loss = lossfunc(z2, y)
    cache = {'x': x, 'y': y, 'z1': z1, 'a1': x, 'z2': z2, 'a2': a2, 'loss': loss}
    for key in params:
        cache[key] = params[key]
    return cache

def bprop(fprop_cache):
    x, y, z1, a1, z2, a2, loss = [fprop_cache[key] for key in ('x', 'y', 'z1', 'a1', 'z2', 'a2', 'loss')]
    dz2 = lossgrad(z2, y)
    dW2 = np.dot(dz2, a2.T)
    db2 = dz2
    da2 = np.dot(fprop_cache['W2'].T, dz2)
    dz1 = da2 * f2grad(z1)
    dW1 = np.dot(dz1, x.T)
    db1 = dz1
    return {'b1': db1, 'W1': dW1, 'b2': db2, 'W2': dW2}
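
As a usage sketch (not part of the original notebook), plain SGD with the fprop/bprop above could look like the following; the shapes, learning rate, and random data are arbitrary stand-ins.

# illustrative SGD loop using the fprop/bprop defined above
lr = 1e-3
params = {'W1': 0.1 * randn(50, 100), 'b1': np.zeros((50, 1)),
          'W2': 0.1 * randn(1, 50),   'b2': np.zeros((1, 1))}
for step in range(100):
    x, y = randn(100, 1), randn(1, 1)     # stand-in for a real (x, y) pair
    cache = fprop(x, y, params)
    grads = bprop(cache)
    for key in params:
        params[key] = params[key] - lr * grads[key]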

Computation graph and modules

From a software engineering perspective, this implementation exhibits strong coupling between fprop and bprop: a person has to maintain the consistency between fprop and bprop, when there should only be a dependence on the network architecture. This is unnecessary in principle and can quickly get out of hand for complicated networks. See the appendix for example backprop code for a network that is quite simple by modern standards. Imagine you made some changes to fprop, such as rearranging layers, renaming parameters, or adding more layers: you would have to make the corresponding modifications to bprop while tracking the dependency structure in lines of code!

To understand the process more generally, consider a directed acyclic graph (DAG) $(V, E)$ that defines the function to be computed. Each node $v$ in the graph represents a function $\operatorname{in}(v) \mapsto \operatorname{out}(v)$, and each directed edge $(v, v')$ represents that $v'$ is a function of $v$ (among others). This means that (some of) the outputs of $v$ are (some of) the inputs of $v'$. In this case we say that $v'$ is a child of $v$, and we denote the set of all children of $v$ as $C(v) := \{v' \mid (v, v') \in E\}$ and the set of parents of $v$ as $P(v) := \{v' \mid (v', v) \in E\}$.

More precisely, $\operatorname{in}(v) \subseteq \operatorname{out}(P(v))$ and $\operatorname{out}(v) \subseteq \operatorname{in}(C(v))$. Forward prop evaluates each node from parents to children until the output node is reached. For backprop to work, each node has to compute the gradient of its input given the gradient of its output $$ \begin{align} \frac{dL}{d\operatorname{in}(v)} &= \frac{dL}{d\operatorname{out}(v)} \frac{d\operatorname{out}(v)}{d\operatorname{in}(v)}. \end{align} $$

Then the rest is up to the chain rule. Information based only on the network structure, $\frac{d \operatorname{in}(v')}{d\operatorname{out}(v)}$, and the gradients of the children of a node, $\frac{dL}{d \operatorname{in}(v')}$, is required before the gradient of the node can be computed $$ \begin{align} \frac{dL}{d\operatorname{out}(v)} &= \frac{d}{d\operatorname{out}(v)} L(\operatorname{in}(C(v)), \operatorname{in}(V \backslash C(v))) \\ &= \sum_{v' \in C(v)} \frac{dL}{d \operatorname{in}(v')} \frac{d \operatorname{in}(v')}{d\operatorname{out}(v)}. \end{align} $$

Now there is some encapsulation between the inner workings of each node and the overall structure, which means fprop and bprop of the whole graph can be decoupled. Let's call each of these nodes, which might have parameters, and which might themselves be a composition of a bunch of functions, a module. Then we need some modules that are easy to differentiate by themselves, which can then be composed to form complex computation graphs. Here are some examples of simple modules:

$$ \begin{array}{ll} \text{linear:} & \operatorname{in}(v') = W \operatorname{out}(v) + b\\ \text{elementwise:} & \forall j: \operatorname{in}(v')_j = f(\operatorname{out}(v)_j)\\ \text{e.w. product:} & \operatorname{in}(v') = \operatorname{out}(v_1) \odot \operatorname{out}(v_2) \odot \operatorname{out}(v_3) \odot \ldots \\ \text{e.w. sum:} & \operatorname{in}(v') = \operatorname{out}(v_1) + \operatorname{out}(v_2) + \operatorname{out}(v_3) + \ldots\\ \end{array} $$

Given these, all the gradient computations can be done automatically by traversing the DAG from output to input in any of the orderings implied by the DAG.
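
Here is a tiny hedged illustration (not from the notebook) of the sum over children: a node whose output feeds two linear children receives one gradient contribution from each. Wa, Wb, and x0 are arbitrary stand-ins.

# a node with two children: x0 feeds both u = Wa x0 and v = Wb x0; z = u + v; L = sum(z^2)
Wa, Wb = np.random.randn(4, 3), np.random.randn(4, 3)
x0 = np.random.randn(3, 1)
u, v = np.dot(Wa, x0), np.dot(Wb, x0)
z = u + v
dLdz = 2 * z
dLdu, dLdv = dLdz, dLdz                            # the sum node passes the gradient to each parent
dLdx0 = np.dot(Wa.T, dLdu) + np.dot(Wb.T, dLdv)    # x0's gradient is the sum over its two children
# check against the direct gradient of L = sum(((Wa + Wb) x0)^2)
assert np.allclose(dLdx0, 2 * np.dot((Wa + Wb).T, np.dot(Wa + Wb, x0)))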

Minimal modules

More concretely, we would like to write torch style code like:

nn2layer = Sequential(OrderedDict([
            ('L1', Linear(params['W1'], params['b1'])),
            ('Relu1',    Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
            ('L2', Linear(params['W2'], params['b2'])),
            ('Relu2',    Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
           ]))

Then you can swap out units, add more layers, or add more structure without having to consider backprop. Here is my minimal implementation of torch-style modules, which has an explicit backward() in each module; pytorch did away with explicit backward methods, but then the code would have to be much longer.

In [3]:
from collections import OrderedDict

class Module(object):
    def __init__(self):
        self.params = OrderedDict()
        self.grads = OrderedDict()
    def forward(self, *input):
        raise NotImplementedError
    # ideally, one infers backward from forward 
    def backward(self, *input, gradout):
        raise NotImplementedError
        
class Elementwise(Module):
    def __init__(self, f, dfdz):
        self.f, self.dfdz = f, dfdz
        
    def forward(self, input):
        self.input = input
        self.output = self.f(input)
        return self.output
    
    def backward(self, gradout):
        return gradout * self.dfdz(self.input)
        
class Linear(Module):
    def __init__(self, W, b):
        super(Linear, self).__init__()
        self.params['W'] = W
        self.params['b'] = b

    def forward(self, input):
        self.input = input
        W, b = self.params['W'], self.params['b']
        self.output = np.dot(W, input) + b
        return self.output
    
    def backward(self, gradout):
        self.grads['W'] = np.dot(gradout, self.input.T)
        self.grads['b'] = gradout
        return np.dot(self.params['W'].T, gradout)

class Sequential(Module):
    def __init__(self, children):
        self.children = children

    def forward(self, input):
        for child in self.children.values():
            input = child.forward(input)
        return input
    
    def backward(self, gradout):
        for child in reversed(self.children.values()):
            gradout = child.backward(gradout)
        return gradout

class Loss(Module):
    def __init__(self, netfunc, lossfunc, lossderiv):
        self.netfunc = netfunc
        self.lossfunc = lossfunc
        self.lossderiv = lossderiv
    
    def forward(self, input, target):
        self.target = target
        self.pred = self.netfunc.forward(input)
        return self.lossfunc(target, self.pred)
    
    def backward(self):
        deriv = self.lossderiv(self.target, self.pred)
        return self.netfunc.backward(deriv)

Network definition

With these modules, we can define the network.

In [4]:
params = {'L1.W': randn(50,100),
          'L1.b': randn(50,1),
          'L2.W': randn(1, 50),
          'L2.b': randn(1, 1)
         }

x = randn(100,1)
y = randn(1)*5

nn2layer = Sequential(OrderedDict([
            ('L1', Linear(params['L1.W'], params['L1.b'])),
            ('Relu',    Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
            ('L2', Linear(params['L2.W'], params['L2.b']))
           ]))

ls_loss = lambda y, yp: np.sum(np.square(y - yp))
ls_grad = lambda y, yp: 2*(yp - y)
final_loss = Loss(nn2layer, ls_loss, ls_grad)

loss = final_loss.forward(x, y)
final_loss.backward()

print('current loss', loss)
print('current pred/target', nn2layer.forward(x), y)
print('diff', (nn2layer.forward(x)-y)**2)
current loss 81.8046905742
current pred/target [[-12.02074689]] [-2.97615234]
diff [[ 81.80469057]]

Parameter management

One thing I ignored for the sake of really short code is parameter management: if I add a new module, I still have to add its parameters manually. It is not difficult to manage all the parameters inside the modules as well.

In [5]:
def flatten(self):
    all_modules = OrderedDict()
    for key, module in self.children.items():
        if 'children' in module.__dict__:
            # recurse on nested containers (e.g. a Sequential inside a Sequential)
            # and prefix their names with the container's key
            for subkey, submodule in module.flatten().items():
                all_modules[key + '.' + subkey] = submodule
        else:
            all_modules[key] = module
    return all_modules
Sequential.flatten = flatten
    
def collect_params(root):
    params = {}
    grads = {}
    for name, module in root.flatten().items():
        if 'params' in module.__dict__:
            for key, param in module.params.items():
                params[name + '.' + key] = param
                grads[name + '.' + key] = module.grads[key]  
    return params, grads

params, grads = collect_params(nn2layer)
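
With params and grads collected into parallel dictionaries, a generic parameter update is just a dictionary walk. A hedged sketch (the learning rate is an arbitrary choice, and this step is not part of the original notebook):

def sgd_step(params, grads, lr=1e-3):
    # params holds references to the same arrays the modules use,
    # so this in-place update changes the network's parameters
    for key in params:
        params[key] -= lr * grads[key]

# one training step on the module-based network defined above
loss = final_loss.forward(x, y)
final_loss.backward()
params, grads = collect_params(nn2layer)
sgd_step(params, grads)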

Tricks of the trade

Gradient check

Let us test the basic backprop code by comparing with the numerical gradient $$ \frac{d f(W,x,y)}{d W_{ij}} \approx \frac{f(W + \epsilon_{ij},x,y) - f(W - \epsilon_{ij},x,y)}{2 \epsilon}, $$ where $\epsilon_{ij}$ is the matrix of the same size as $W$ with value $\epsilon$ at position $ij$ and 0 everywhere else. We will check both the basic implementation and the module-based one.

In [6]:
def numerical_grad(fprop, x, y, params):
    eps = 1e-6
    ng_cache = {}
    # For every single parameter (W, b)
    for key in params:
        param = params[key]
        # This will be our numerical gradient
        ng = np.zeros(param.shape)
        for j in range(ng.shape[0]):
            for k in range(ng.shape[1]):
                # For every element of parameter matrix, compute gradient of loss wrt
                # that element numerically using finite differences
                add_eps = np.copy(param)
                min_eps = np.copy(param)
                add_eps[j, k] += eps
                min_eps[j, k] -= eps
                add_params = copy(params)
                min_params = copy(params)
                add_params[key] = add_eps
                min_params[key] = min_eps
                ng[j, k] = (np.sum(fprop(x, y, add_params)['loss']) \
                            - np.sum(fprop(x, y, min_params)['loss'])) / (2 * eps)
        ng_cache[key] = ng
    return ng_cache

def check_grad(params, grad1, grad2):
    # Compare numerical gradients to those computed using backpropagation algorithm
    for key in params:
        #print(bprop_grad[key])
        #print(num_grad[key])
        diff = grad1[key].flatten() - grad2[key].flatten()
        sums = grad1[key].flatten() + grad2[key].flatten()
        norm = np.max(np.abs(diff / sums)) 
        if norm < 1e-5:
            print(key, 'pass', norm)
        else:
            print(key, 'fail', norm)
In [7]:
import time
timeformat = '{0}: numdata: {1}\t time: {2:.5e}'
# test gradient check on the basic implementation
num_data, dim_data, num_hid = 1, 200, 300
W1 = np.random.rand(num_hid, dim_data)
b1 = np.random.rand(num_hid, 1)
W2 = np.random.rand(1, num_hid)
b2 = np.random.rand(1, 1)
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10

params = {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}

tic = time.time() ########## BEGIN
fprop_cache = fprop(x, y, params)
bprop_grad = bprop(fprop_cache)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('bprop', num_data, tictoc))


tic = time.time() ########## BEGIN
num_grad = numerical_grad(fprop, x, y, params)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('numerical', num_data, tictoc))

# Compare numerical gradients to those computed using backpropagation algorithm
check_grad(params, num_grad, bprop_grad)
print('loss', fprop_cache['loss'])

#############################################
# BEGIN CHECK MODULE

params = {'L1.W': W1, 'L1.b': b1, 'L2.W': W2, 'L2.b': b2}
nn2layer = Sequential(OrderedDict([
            ('L1', Linear(params['L1.W'], params['L1.b'])),
            ('Relu', Elementwise(lambda x: np.maximum(0, x), lambda x: x > 0)),
            ('L2', Linear(params['L2.W'], params['L2.b']))
           ]))

ls_loss = lambda y, yp: np.sum(np.square(y - yp))
ls_grad = lambda y, yp: 2*(yp - y)
final_loss = Loss(nn2layer, ls_loss, ls_grad)

params = {'L1.W': W1, 'L1.b': b1, 'L2.W': W2, 'L2.b': b2}


def fprop_wrapper(x, y, params):
    _children = nn2layer.children
    _children['L1'].params['W'] = params['L1.W']
    _children['L1'].params['b'] = params['L1.b']
    _children['L2'].params['W'] = params['L2.W']
    _children['L2'].params['b'] = params['L2.b']
    loss = final_loss.forward(x, y)
    return {'loss': loss}


tic = time.time() ########## BEGIN
loss_module = final_loss.forward(x, y)
final_loss.backward()
_children = nn2layer.children
module_grad = {'L1.W': _children['L1'].grads['W'], 'L1.b': _children['L1'].grads['b'],
              'L2.W': _children['L2'].grads['W'], 'L2.b': _children['L2'].grads['b']}
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('backprop-module', num_data, tictoc))

tic = time.time() ########## BEGIN
num_grad = numerical_grad(fprop_wrapper, x, y, params)
toc = time.time() ########## END
tictoc = toc - tic
print(timeformat.format('numerical-module', num_data, tictoc))

print('loss', loss_module)
check_grad(params, num_grad, module_grad)
bprop: numdata: 1	 time: 9.39131e-04
numerical: numdata: 1	 time: 6.25211e+00
W1 fail 0.0066986219126
b1 fail 3.11954283795e-05
W2 pass 8.07172270052e-09
b2 pass 6.43238059927e-08
loss [[ 53693618.33892669]]
backprop-module: numdata: 1	 time: 5.64337e-04
numerical-module: numdata: 1	 time: 6.97652e+00
loss 53693618.3389
L1.W fail 0.0066986219126
L1.b fail 3.11954283795e-05
L2.W pass 8.07172270052e-09
L2.b pass 6.43238059927e-08

Vectorized minibatch

Since matrix multiplications are highly optimized, it is faster to push more work into each one. In fact, the code already works with minibatches of data points in the form of $X = [x_1, x_2, \ldots, x_n]$, where each data vector occupies a column. On my CPU, I got a speedup of 10 times for a fairly small network. Such speedups are expected to be more significant on GPUs.

num_data, dim_data, num_hid = 200, 100, 200
backprop-batch: numdata: 200     time: 8.87632e-04
backprop-loop: numdata: 200  time: 8.80098e-03
numerical: numdata: 200  time: 1.09668e+00

The last row is the time it takes to compute all the gradients numerically, so do not get any ideas.

Exercise: there is one issue in bprop that prevents vectorized minibatching from being completely correct. Spot it and propose a fix.

In [8]:
num_data, dim_data, num_hid = 100, 200, 300
W1 = np.random.rand(num_hid, dim_data)
b1 = np.random.rand(num_hid, 1)
W2 = np.random.rand(1, num_hid)
b2 = np.random.rand(1, 1)
params = {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10

tic = time.time()
x = np.random.rand(dim_data, num_data)
y = np.random.rand(1, num_data) * 10
fprop_cache = fprop(x, y, params)
bprop_grad = bprop(fprop_cache)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('backprop-batch', num_data, tictoc))

tic = time.time()
for i in range(num_data):
    x = np.random.rand(dim_data, 1)
    y = np.random.rand(1, 1) * 10
    fprop_cache = fprop(x, y, params)
    bprop_grad = bprop(fprop_cache)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('backprop-loop', num_data, tictoc))
backprop-batch: numdata: 100	 time: 1.59931e-03
backprop-loop: numdata: 100	 time: 6.58536e-03
In [9]:
tic = time.time()
num_grad = numerical_grad(fprop, x, y, params)
toc = time.time()
tictoc = toc - tic
print(timeformat.format('numerical', num_data, tictoc))
numerical: numdata: 100	 time: 5.74471e+00

Dropout

Initialization

Optimization

Batch normalization

Appendix: Back to the dark ages

If the basic coupled implementation seemed easy enough to work with, here is the backprop code for a slightly more complex network (but very simple by modern standards). The fprop and a few functions called in both parts are here.

# highly coupled and unreadable code from the transforming autoencoder, around 2011
  def backprop(self, target, input, trsf):
    numcases = target.shape[1]
    biasfac = 1

    self.calcoutput(input, trsf)
    self.diff = self.output - target
    self.wu_ho = g.dot(self.diff, self.h2.T * self.pr.T) / numcases
    self.wu_o = self.diff.sum(1)[:,None] / (biasfac * numcases)

    dEdH2 = g.dot(self.w_ho.T, self.diff) * self.pr
    dEdH2in = dEdH2 * (self.h2 - self.shift) * (1- self.h2 + self.shift)
    self.wu_ch = g.dot(dEdH2in, self.ct.T) / numcases
    self.wu_h2 = dEdH2in.sum(1)[:,None] / (biasfac * numcases)

    dEdCin = g.dot(self.w_ch.T, dEdH2in)
    self.bpdEdCin = g.garray(dEdCin)
    #print 'max before invert %f %f' % (dEdCin.max(), dEdCin.min())
    if not self.justtranslate: dEdCin = g.garray(self.applymatf(trsf,dEdCin,transpose=True))

    dEdPout = (self.w_ho[:,:,None] * self.h2[None,:,:] * self.diff[:,None,:])\
        .sum(0).reshape(self.groupcoord, self.sizehid2, numcases)\
        .sum(1)
    dEdPin = self.p * (1-self.p) * dEdPout 

    self.wu_hp = g.dot(dEdPin, self.h1.T) / numcases
    self.wu_p = dEdPin.sum(1)[:,None] / (biasfac * numcases)

    #dEdCin *= (self.maskm * self.pr)
    self.bpdEdCinm = g.garray(dEdCin)
    #print 'max after  invert %f %f' % (dEdCin.max(), dEdCin.min())
    self.wu_hc = g.dot(dEdCin, self.h1.T) / numcases
    self.wu_c = dEdCin.sum(1)[:,None] / (biasfac * numcases)

    dEdH1 = g.dot(self.w_hc.T, dEdCin) + g.dot(self.w_hp.T, dEdPin)
    dEdH1in = dEdH1 * (self.h1 - self.shift) * (1 - self.h1 + self.shift)
    self.wu_vh = g.dot(dEdH1in, input.T) / numcases
    self.wu_h1 = dEdH1in.sum(1)[:,None] / (biasfac * numcases)