Backpropagation vectorization hints
Here we give a few hints on how to vectorize the backpropagation step. These hints build specifically on our earlier description of how to vectorize a neural network.
Assume we have already implemented the vectorized forward propagation steps, so that the matrix-valued z2, a2, z3 and h have already been computed. Here was our unvectorized implementation of backprop:
gradW1 = zeros(size(W1));
gradW2 = zeros(size(W2));
for i=1:m,
  delta3 = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
  delta2 = W2'*delta3 .* fprime(z2(:,i));
  gradW2 = gradW2 + delta3*a2(:,i)';
  gradW1 = gradW1 + delta2*a1(:,i)';
end;

(Here, a1 denotes the matrix of inputs, i.e., the first-layer activations.)
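For readers following along outside Matlab/Octave, the same unvectorized loop can be sketched in NumPy. The layer sizes and random data below are made up purely for illustration, fprime is assumed to be the derivative of the logistic sigmoid, and a1 is assumed to be the matrix of inputs:

```python
import numpy as np

def fprime(z):
    # Assumed activation derivative: derivative of the logistic sigmoid.
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)

rng = np.random.default_rng(0)
n1, n2, n3, m = 4, 5, 3, 6            # illustrative layer sizes, m examples
W1 = rng.standard_normal((n2, n1))
W2 = rng.standard_normal((n3, n2))
a1 = rng.standard_normal((n1, m))     # a1 = matrix of inputs, one column per example
z2 = W1 @ a1
a2 = 1.0 / (1.0 + np.exp(-z2))        # vectorized forward propagation
z3 = W2 @ a2
h = 1.0 / (1.0 + np.exp(-z3))
y = rng.standard_normal((n3, m))      # made-up targets

# Unvectorized backprop: one training example per loop iteration.
gradW1 = np.zeros_like(W1)
gradW2 = np.zeros_like(W2)
for i in range(m):
    delta3 = -(y[:, i] - h[:, i]) * fprime(z3[:, i])
    delta2 = (W2.T @ delta3) * fprime(z2[:, i])
    gradW2 += np.outer(delta3, a2[:, i])   # delta3 * a2(:,i)'
    gradW1 += np.outer(delta2, a1[:, i])   # delta2 * a1(:,i)'
```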
Assume we have implemented a version of fprime(z) that accepts matrix-valued inputs. We will now use matrix-valued delta3 and delta2, each with m columns, one per training example. Our goal is to compute delta3, delta2, gradW2 and gradW1 without looping over the examples.
Consider the computation for the matrix delta3, which can now be written:
for i=1:m,
  delta3(:,i) = -(y(:,i) - h(:,i)) .* fprime(z3(:,i));
end;
Each iteration of the for loop computes one column of delta3. Since the elementwise operations here apply to every column at once, you should be able to find a single line of Matlab that computes the entire matrix delta3 as a function of the matrices y, h and z3. Similarly, you should be able to find a single line of code that computes the entire matrix delta2 as a function of W2, delta3 (which is now a matrix), and z2.
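To see why a single line suffices, here is a NumPy sketch (with made-up shapes and a hypothetical sigmoid-derivative fprime) checking that the whole-matrix expressions agree with the column-by-column loop:

```python
import numpy as np

def fprime(z):
    # Assumed activation derivative: derivative of the logistic sigmoid.
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)

rng = np.random.default_rng(1)
n2, n3, m = 5, 3, 6                  # illustrative sizes
W2 = rng.standard_normal((n3, n2))
z2 = rng.standard_normal((n2, m))
z3 = rng.standard_normal((n3, m))
h  = rng.standard_normal((n3, m))
y  = rng.standard_normal((n3, m))

# Column-at-a-time loop, as in the text.
delta3_loop = np.zeros((n3, m))
delta2_loop = np.zeros((n2, m))
for i in range(m):
    delta3_loop[:, i] = -(y[:, i] - h[:, i]) * fprime(z3[:, i])
    delta2_loop[:, i] = (W2.T @ delta3_loop[:, i]) * fprime(z2[:, i])

# Vectorized: elementwise operations act on whole matrices at once,
# and W2' * delta3 becomes a single matrix-matrix product.
delta3 = -(y - h) * fprime(z3)
delta2 = (W2.T @ delta3) * fprime(z2)
```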
Next, consider the computation for gradW2. We can now write this as:
gradW2 = zeros(size(W2));
for i=1:m,
  gradW2 = gradW2 + delta3(:,i)*a2(:,i)';
end;
You should be able to find a single line of Matlab that replaces this for loop, and computes gradW2 as a function of the matrices delta3 and a2. If you're having trouble, take another look at the Logistic Regression Vectorization Example, which uses a related (but slightly different) vectorization step to get to the final implementation. Using a similar method, you will also be able to compute gradW1 with a single line of code.
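The key identity is that a sum of outer products delta3(:,i)*a2(:,i)' over all i is exactly the single matrix product delta3*a2'. A NumPy sketch with made-up data:

```python
import numpy as np

rng = np.random.default_rng(2)
n2, n3, m = 5, 3, 6                  # illustrative sizes
delta3 = rng.standard_normal((n3, m))
a2 = rng.standard_normal((n2, m))

# Accumulation loop, as in the text: one outer product per example.
gradW2_loop = np.zeros((n3, n2))
for i in range(m):
    gradW2_loop += np.outer(delta3[:, i], a2[:, i])

# Vectorized: the sum of outer products collapses into one matrix product.
gradW2 = delta3 @ a2.T
```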
When you complete the derivation, you should be able to replace the unvectorized backpropagation code example above with just 4 lines of Matlab/Octave code.
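For reference, the four vectorized lines can be sketched in NumPy as below. This is one possible solution under illustrative assumptions (random data, a sigmoid-derivative fprime, and a1 as the matrix of inputs), not necessarily the exact Matlab code you will write:

```python
import numpy as np

def fprime(z):
    # Assumed activation derivative: derivative of the logistic sigmoid.
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)

rng = np.random.default_rng(3)
n1, n2, n3, m = 4, 5, 3, 6            # illustrative layer sizes, m examples
W1 = rng.standard_normal((n2, n1))
W2 = rng.standard_normal((n3, n2))
a1 = rng.standard_normal((n1, m))     # a1 = matrix of inputs
z2 = W1 @ a1
a2 = 1.0 / (1.0 + np.exp(-z2))        # vectorized forward propagation
z3 = W2 @ a2
h = 1.0 / (1.0 + np.exp(-z3))
y = rng.standard_normal((n3, m))      # made-up targets

# The entire backpropagation step in four vectorized lines:
delta3 = -(y - h) * fprime(z3)
delta2 = (W2.T @ delta3) * fprime(z2)
gradW2 = delta3 @ a2.T
gradW1 = delta2 @ a1.T
```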