Backpropagation vectorization hints


Here, we give a few hints on how to vectorize the Backpropagation step. The hints here specifically build on our earlier description of how to vectorize a neural network.

Assume we have already implemented the vectorized forward propagation steps, so that the matrix-valued z2, a2, z3 and h have already been computed. Here is our unvectorized implementation of backprop:

gradW1 = zeros(size(W1));
gradW2 = zeros(size(W2));
for i=1:m,
  % Error term for the output layer, for one training example.
  delta3 = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
  % Propagate the error back to the hidden layer.
  delta2 = W2'*delta3 .* fprime(z2(:,i));
  % Accumulate this example's contribution to the gradients.
  gradW2 = gradW2 + delta3*a2(:,i)';
  gradW1 = gradW1 + delta2*a1(:,i)';
end;

Assume that we have implemented a version of fprime(z) that accepts matrix-valued inputs. We will now use matrix-valued delta3 and delta2, each with m columns, one column per training example. We want to compute delta3, delta2, gradW2 and gradW1.

Consider the computation for the matrix delta3, which can now be written:

for i=1:m,
  delta3(:,i) = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
end;

Each iteration of the for loop computes one column of delta3. You should be able to find a single line of Matlab that computes the entire matrix delta3 as a function of the matrices y, h and z3. Similarly, you should be able to find a single line of code that computes the entire matrix delta2, as a function of W2, delta3 (which is now a matrix) and z2.
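As one possible sketch of these two lines (this assumes fprime accepts matrix inputs, so all m columns are computed at once; the sizes and sigmoid-based fprime below are illustrative assumptions, not part of the exercise):

```octave
% Self-contained sketch; sizes and fprime are illustrative assumptions.
m = 5;                                 % number of training examples
sigm   = @(z) 1 ./ (1 + exp(-z));
fprime = @(z) sigm(z) .* (1 - sigm(z));
y  = randn(3, m);  h = randn(3, m);  z3 = randn(3, m);
z2 = randn(4, m);  W2 = randn(3, 4);

% Vectorized: each column of delta3 and delta2 is one training example.
delta3 = -(y - h) .* fprime(z3);
delta2 = (W2' * delta3) .* fprime(z2);

% Sanity check against the per-example loop.
d3 = zeros(size(z3));  d2 = zeros(size(z2));
for i = 1:m
  d3(:,i) = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
  d2(:,i) = W2' * d3(:,i) .* fprime(z2(:,i));
end
```

The elementwise operations are identical in both versions; the matrix product W2'*delta3 simply performs all m matrix-vector products in one BLAS call.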

Next, consider the computation for gradW2. We can now write this as:

gradW2 = zeros(size(W2));
for i=1:m,
  gradW2 = gradW2 + delta3(:,i)*a2(:,i)';
end;

You should be able to find a single line of Matlab that replaces this for loop, and computes gradW2 as a function of the matrices delta3 and a2. If you're having trouble, take another look at the Logistic Regression Vectorization Example, which uses a related (but slightly different) vectorization step to get to the final implementation. Using a similar method, you will also be able to compute gradW1 with a single line of code.

When you complete the derivation, you should be able to replace the unvectorized backpropagation code example above with just 4 lines of Matlab/Octave code.
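For reference, here is one self-contained sketch of a four-line vectorized version, checked numerically against the unvectorized loop (the data sizes and the sigmoid-based fprime are illustrative assumptions; a1 denotes the input activations, as in forward propagation):

```octave
% Illustrative sizes and activation function; not part of the exercise.
m = 5;
sigm   = @(z) 1 ./ (1 + exp(-z));
fprime = @(z) sigm(z) .* (1 - sigm(z));
a1 = randn(6, m);  z2 = randn(4, m);  a2 = sigm(z2);
W2 = randn(3, 4);  z3 = randn(3, m);  h = sigm(z3);  y = randn(3, m);

% The 4-line vectorized backpropagation.
delta3 = -(y - h) .* fprime(z3);
delta2 = (W2' * delta3) .* fprime(z2);
gradW2 = delta3 * a2';
gradW1 = delta2 * a1';

% Verify against the unvectorized per-example loop.
G1 = zeros(size(gradW1));  G2 = zeros(size(gradW2));
for i = 1:m
  d3 = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
  d2 = W2' * d3 .* fprime(z2(:,i));
  G2 = G2 + d3 * a2(:,i)';
  G1 = G1 + d2 * a1(:,i)';
end
```

The key identity is that the sum over examples, sum_i delta(:,i)*a(:,i)', is exactly the matrix product delta*a'.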
