Derivative of Softmax under Cross Entropy Error
This derivation stumped me for a while, so I want to share it with others who may run into the same problem.
First, a quick overview of how softmax works and what the cross-entropy error function is:
Softmax
Here is the equation for computing \(y_i\), the probability that class \(i\) occurs, given its penalty \(z_i\):
where \(K\) is the number of possible classes and \(i\) is the class we care about.
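\[
y_i = \frac{e^{z_i}}{\sum_{k=1}^{K} e^{z_k}}
\]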
Cross Entropy
The equation below computes the cross-entropy \(C\) over the softmax output:
where \(K\) is the number of possible classes, and \(t_k\) and \(y_k\) are the target and the softmax output for class \(k\), respectively.
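\[
C = -\sum_{k=1}^{K} t_k \log y_k
\]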
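To make the two definitions concrete, here is a minimal NumPy sketch (the helper names `softmax` and `cross_entropy` and the example values are mine, not part of the derivation):

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; the result is unchanged.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def cross_entropy(t, y):
    # C = -sum_k t_k * log(y_k)
    return -np.sum(t * np.log(y))

z = np.array([2.0, 1.0, 0.1])  # penalties z_i for three classes
t = np.array([1.0, 0.0, 0.0])  # one-hot target t_k
y = softmax(z)                 # softmax output y_k
print(y, cross_entropy(t, y))
```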
Derivation
Now we want to compute the derivative of \(C\) with respect to \(z_i\), where \(z_i\) is the penalty of a particular class \(i\). We can expand the derivative like this (simply because \(t_k\) does not depend on \(z_i\), and differentiation distributes over each term of the sum \(\sum\)):
By applying the Chain Rule to \(\log y_k\), we can simplify the above equation:
Assume for now that we already know \(\frac{\partial y_k}{\partial z_i}\) (the details are worked out in the Appendix below):
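\[
\frac{\partial C}{\partial z_i}
= -\,\frac{\partial}{\partial z_i} \sum_{k=1}^{K} t_k \log y_k
= -\sum_{k=1}^{K} t_k \,\frac{\partial \log y_k}{\partial z_i}
\]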
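\[
\frac{\partial C}{\partial z_i}
= -\sum_{k=1}^{K} \frac{t_k}{y_k} \,\frac{\partial y_k}{\partial z_i}
\]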
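\[
\frac{\partial y_k}{\partial z_i} = y_k \,\bigl(1\{k = i\} - y_i\bigr)
\]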
where \(1\{k = i\}\) equals \(1\) if \(k = i\) and \(0\) otherwise.
Plugging that into the previous equation, we get:
Note that \(y_i\) does not depend on the summation index \(k\), so it can be pulled out of the sum, that \(1\{k = i\}\) is non-zero only when \(k = i\), and that the targets sum to one, \(\sum_{k=1}^{K} t_k = 1\) (for example, a one-hot target). Finally we get:
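\[
\frac{\partial C}{\partial z_i}
= -\sum_{k=1}^{K} \frac{t_k}{y_k}\, y_k\,\bigl(1\{k = i\} - y_i\bigr)
= -\sum_{k=1}^{K} t_k\,\bigl(1\{k = i\} - y_i\bigr)
\]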
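\[
\frac{\partial C}{\partial z_i}
= -\sum_{k=1}^{K} t_k\,1\{k = i\} + y_i \sum_{k=1}^{K} t_k
= y_i - t_i
\]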
We are almost done. If you are not interested in how to get \(\frac{\partial y_k}{\partial z_i}\), or you are short on time, you can stop reading here (the remaining part is a bit tedious).
Appendix
Now I describe how to compute \(\frac{\partial y_k}{\partial z_i}\).
Before you start, it may help to forget the original meanings of the notations used above (such as \(k\) and \(i\)) to avoid confusion; in this appendix the total number of classes is written \(J\) and \(j\) is the summation index.
Softmax recap:
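\[
y_k = \frac{e^{z_k}}{\sum_{j=1}^{J} e^{z_j}}
\]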
Because \(k\) can take any of its possible values, we split the computation into two cases:
- \(k = i\)
By applying the Quotient Rule:
where \(\sum\) denotes \(\sum_{j=1}^{J}\).
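\[
\frac{\partial y_k}{\partial z_i}
= \frac{\partial}{\partial z_i}\,\frac{e^{z_i}}{\sum e^{z_j}}
= \frac{e^{z_i}\sum e^{z_j} - e^{z_i}\,e^{z_i}}{\bigl(\sum e^{z_j}\bigr)^{2}}
= \frac{e^{z_i}}{\sum e^{z_j}}\cdot\frac{\sum e^{z_j} - e^{z_i}}{\sum e^{z_j}}
= y_i\,(1 - y_i)
\]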
- \(k \not= i\)
Applying the Quotient Rule again, but notice that the numerator \(e^{z_k}\) does not depend on \(z_i\):
where \(\sum\) again denotes \(\sum_{j=1}^{J}\).
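\[
\frac{\partial y_k}{\partial z_i}
= \frac{\partial}{\partial z_i}\,\frac{e^{z_k}}{\sum e^{z_j}}
= \frac{0\cdot\sum e^{z_j} - e^{z_k}\,e^{z_i}}{\bigl(\sum e^{z_j}\bigr)^{2}}
= -\,\frac{e^{z_k}}{\sum e^{z_j}}\cdot\frac{e^{z_i}}{\sum e^{z_j}}
= -\,y_k\,y_i
\]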
Looking carefully at the two cases \(k = i\) and \(k \not= i\), we can wrap them up into this much more compact form:
where \(1\{k = i\}\) equals \(1\) if \(k = i\) and \(0\) otherwise.
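\[
\frac{\partial y_k}{\partial z_i} = y_k\,\bigl(1\{k = i\} - y_i\bigr)
\]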
You can verify this equation by hand.
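If verifying by hand feels tedious, here is a small self-contained numerical check using finite differences; it is only a sketch assuming NumPy, and the helper names are my own:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def cross_entropy(t, z):
    return -np.sum(t * np.log(softmax(z)))

rng = np.random.default_rng(0)
z = rng.normal(size=5)   # arbitrary penalties
t = np.eye(5)[2]         # one-hot target for class 2
y = softmax(z)
eps = 1e-6

# Closed form dy_k/dz_i = y_k * (1{k == i} - y_i), arranged as a Jacobian matrix.
jac_closed = np.diag(y) - np.outer(y, y)
jac_numeric = np.empty((5, 5))
for i in range(5):
    dz = np.zeros(5)
    dz[i] = eps
    jac_numeric[:, i] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)
print(np.allclose(jac_closed, jac_numeric, atol=1e-8))   # expected: True

# Closed form dC/dz_i = y_i - t_i.
grad_closed = y - t
grad_numeric = np.empty(5)
for i in range(5):
    dz = np.zeros(5)
    dz[i] = eps
    grad_numeric[i] = (cross_entropy(t, z + dz) - cross_entropy(t, z - dz)) / (2 * eps)
print(np.allclose(grad_closed, grad_numeric, atol=1e-8))  # expected: True
```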
That's it. All the derivations are done.