Before we dive in, I would like to point out that all the basic ideas in this article are based on the paper "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Lantao Yu et al.). I have left out details that aren't necessary for understanding the intuition behind the approach and how it works. However, I would definitely recommend that you read the paper as well.
Formulating the problem
Let’s leave the whole idea of neural network based text generators aside for a moment and think of our text generator as an RL agent. What would its states and actions be? One simple formulation would be to define the state s of the agent to be the “text generated so far” and the action a to be “choosing the next word” in the sentence. Therefore, the choice of actions is defined by our vocabulary of words.
The generation proceeds in this manner with the agent choosing an action (word) at each time step. When the agent finally chooses the “end of sentence” action, it reaches the end of an episode and receives a reward which tells it how good its state-action sequence (generated sentence) was. In our case, this reward will be provided by the discriminator network. So, we can view the discriminator as a continuously improving reward function for our agent.
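To make this formulation concrete, here is a minimal sketch of one episode. The uniform policy and the random-score discriminator are hypothetical placeholders (the real versions are the networks described later); the point is just the state/action/reward mechanics:

```python
import random
random.seed(1)

VOCAB = ["the", "blue", "house", "is", "large", "</s>"]

def policy(state):
    """Placeholder policy: uniform over the vocabulary.
    In the real setup this distribution comes from the generator network."""
    return [1.0 / len(VOCAB)] * len(VOCAB)

def discriminator(sentence):
    """Placeholder reward: the real one is a trained network scoring realism."""
    return random.random()

def run_episode(max_len=10):
    state = []                                          # state s: text generated so far
    while len(state) < max_len:
        probs = policy(state)
        word = random.choices(VOCAB, weights=probs)[0]  # action a: choose the next word
        state.append(word)
        if word == "</s>":                              # end-of-sentence action ends the episode
            break
    return state, discriminator(state)                  # reward arrives only at the end

sentence, reward = run_episode()
```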
Now, how does our agent decide which action to pick at a particular state? This is where the concept of a policy comes in.
A policy π(a | s, θ) is a function, parameterized by θ, that outputs a probability distribution over all the actions a given the current state s of the agent. An agent generating text based on a policy samples from the set of actions according to the distribution returned by the policy. (As a side note, this gives us some built-in non-determinism that helps our agent explore the state space, i.e., it doesn't just generate the same high-reward sentence over and over.)
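As a toy illustration of this sampling (the vocabulary and probabilities below are made up), compare sampling from the policy's distribution with always taking the argmax: sampling occasionally picks lower-probability words, which is exactly the exploration mentioned above.

```python
import random
random.seed(0)

vocab = ["cat", "dog", "house"]
probs = [0.6, 0.3, 0.1]   # hypothetical pi(a | s, theta) for the current state s

# Sampling (what the agent does): stochastic, so it explores the state space.
samples = [random.choices(vocab, weights=probs)[0] for _ in range(1000)]

# Greedy decoding (argmax) would always emit the same word.
greedy = vocab[probs.index(max(probs))]
```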
Building a good text generation agent now boils down to the task of finding the optimal policy π*. In other words, we want to find the optimal parameter θ* that maximizes some performance measure J(θ) of our agent (we will define it later).
Policy gradient methods are a particular class of policy optimization methods for finding the optimal parameter θ*, based on computing the gradient of some performance measure J(θ) with respect to θ. So, in order to find the optimal parameter, we perform gradient ascent on J as:

θ_{t+1} = θ_t + α ∇̂J(θ_t)

where α is the step size, and the caret (^) on the gradient signifies that it is an approximation of the gradient.
So, how do we go about approximating the gradient? The answer is provided by the policy gradient theorem. The derivation of the theorem is fairly interesting, but I think getting an intuitive idea will suffice for the purposes of this article.
The policy gradient theorem gives us an analytical expression for the gradient of J with respect to θ:

∇J(θ) ∝ Σ_s μ(s) Σ_a q_π(s, a) ∇π(a | s, θ)

- μ(s) is a measure of how often the agent is likely to visit state s (across a lot of episodes) if it follows policy π. It is normalized by its sum over all states to give us a probability measure. You might be wondering how we could possibly calculate that, but don't worry: it conveniently disappears.
- q_π(s, a) is a measure of the "quality" (it can be thought of as the value) of action a taken from state s.
Now, the RHS of the policy gradient theorem is a summation over all states and all actions. But we want an efficient algorithm that estimates this quantity using just a single sample (S_t, A_t), i.e., the state and action taken at time step t while following the policy π. This is where the REINFORCE algorithm for calculating policy gradients comes in handy.
The REINFORCE algorithm
Let's look at the expression for the gradient again:

∇J(θ) ∝ Σ_s μ(s) Σ_a q_π(s, a) ∇π(a | s, θ)
Notice that μ(s) is essentially the probability of state s occurring while following policy π. So, the outer summation is just an expectation (over states S_t encountered while following π) of the value of the inner summation:

∇J(θ) ∝ E[ Σ_a q_π(S_t, a) ∇π(a | S_t, θ) ]
So why not do the same thing for the inner summation? To do this, we multiply and divide by π(a | S_t, θ):

∇J(θ) ∝ E[ Σ_a π(a | S_t, θ) q_π(S_t, a) ∇π(a | S_t, θ) / π(a | S_t, θ) ]
π(a | S_t, θ) is the probability distribution over the actions at state S_t. So again, the summation looks like an expectation, and we can reduce it to:

∇J(θ) ∝ E[ q_π(S_t, A_t) ∇π(A_t | S_t, θ) / π(A_t | S_t, θ) ]
That's it! We can estimate this form of the gradient using just the one sample (S_t, A_t) that we get at time step t while generating text using the current policy π. Replacing q_π(S_t, A_t) with G_t, we get the final form of the REINFORCE update:

θ_{t+1} = θ_t + α G_t ∇π(A_t | S_t, θ_t) / π(A_t | S_t, θ_t)
Ok, what's G_t? It can be thought of as the cumulative reward that we can expect while following the policy π from time step t until the end of the episode.
Understanding the REINFORCE update
Intuitively, ∇π(A_t | S_t, θ_t) is a vector that tells us the direction of maximum increase of the probability of taking action A_t when we encounter state S_t again. So, it makes sense that we move our parameter θ in that direction.
The amount by which we move in that direction is proportional to G_t, which basically means that if our expected reward is higher, we move by a larger amount in that direction.
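Putting this together, here is a minimal sketch of a single REINFORCE step for a hypothetical softmax policy whose parameters θ are the action logits themselves (everything here, including the action and return values, is made up for illustration). For a softmax policy, the gradient of ln π(A_t) with respect to the logits is simply onehot(A_t) minus π, which lets us write the update in a few lines:

```python
import math

def softmax(logits):
    z = max(logits)
    exps = [math.exp(l - z) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy policy over 3 actions whose parameters theta are the logits directly.
theta = [0.0, 0.0, 0.0]
alpha = 0.1   # step size
A_t = 1       # action taken at time step t
G_t = 2.0     # return observed from time step t onward

pi = softmax(theta)
p_before = pi[A_t]

# For a softmax policy, grad_theta ln pi(A_t) = onehot(A_t) - pi.
grad_log_pi = [(1.0 if a == A_t else 0.0) - pi[a] for a in range(len(theta))]

# REINFORCE update: theta <- theta + alpha * G_t * grad ln pi(A_t | S_t, theta)
theta = [t + alpha * G_t * g for t, g in zip(theta, grad_log_pi)]

p_after = softmax(theta)[A_t]
```

Since G_t is positive here, the probability of the taken action A_t increases after the update, which is exactly the intuition above.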
The final REINFORCE algorithm
Connecting this to text GANs
Now that we have all the pieces we need, we are finally ready to construct a GAN for text generation.
Let’s look at the easiest part first — the discriminator network. The discriminator network simply takes a sentence as input and outputs a value that signifies how “real” the sentence looks. This can be done using a CNN/RNN based model. Note that this value output by the discriminator will also be used as the reward corresponding to the input sentence, and will be provided back to the generator to update its policy at the end of each episode.
What about the generator? It is going to represent the parameterized policy π(a | s, θ), where θ corresponds to the parameters of the generator network which will be updated according to the REINFORCE algorithm given above. Therefore, the generator should be able to take a sequence of words as input and output a probability distribution over the next word. This can be achieved using an RNN by feeding in a word at each time step and obtaining the final hidden state. The final hidden state of the RNN is then mapped to a probability distribution using a feed-forward network followed by softmax.
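Here is a rough pure-Python sketch of such a generator: a tiny RNN cell whose final hidden state is mapped through a linear layer and softmax to a distribution over the next word. The vocabulary, sizes, and random initialisation are all made up for illustration (a real implementation would be a trained model in a deep-learning framework):

```python
import math
import random

random.seed(0)

VOCAB = ["<s>", "the", "blue", "house", "</s>"]
H = 8  # toy hidden size

def rand_matrix(rows, cols):
    """Random toy parameters; in practice these are learned via REINFORCE."""
    return [[random.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

W_xh = rand_matrix(len(VOCAB), H)   # one-hot word -> hidden
W_hh = rand_matrix(H, H)            # hidden -> hidden
W_ho = rand_matrix(H, len(VOCAB))   # hidden -> output logits

def rnn_step(h, word_idx):
    """One RNN step: combine the previous hidden state with the current word."""
    new_h = []
    for j in range(H):
        s = W_xh[word_idx][j] + sum(h[i] * W_hh[i][j] for i in range(H))
        new_h.append(math.tanh(s))
    return new_h

def policy(prefix):
    """pi(a | s, theta): map a word prefix to a distribution over the next word."""
    h = [0.0] * H
    for w in prefix:
        h = rnn_step(h, VOCAB.index(w))
    logits = [sum(h[i] * W_ho[i][k] for i in range(H)) for k in range(len(VOCAB))]
    z = max(logits)
    exps = [math.exp(l - z) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = policy(["<s>", "the", "blue"])
```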
There is still one piece left to clarify: the intermediate rewards R_k that we use to calculate G_t. As I mentioned above, the discriminator network provides a reward only at the end of the episode, i.e., after it looks at the entire generated sentence. So, where do we get the intermediate rewards from? The SeqGAN paper solves this issue using Monte-Carlo rollouts. Hold on, you don't have to know what that means; the idea is actually pretty straightforward.
The idea behind Monte-Carlo rollouts is to start from the current state S_t, and simulate N (10–20) different episodes following the current policy π. In our case, this means generating N different sentences whose prefix is the incomplete sentence corresponding to state S_t.
For example, let S_t = “<s> The blue house”
We then produce rollouts such as:
- “<s> The blue house is across the road </s>”
- “<s> The blue house is large </s>”
- “<s> The blue house belongs to Ms. Smith </s>”
And so on.
Then, we get the rewards corresponding to these N sentences using the discriminator and average them out to get the expected reward R_t at time step t. This makes sense because it tells us what reward to expect when we proceed till the end from state S_t.
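The rollout-and-average step can be sketched as follows. The uniform rollout policy and random-score discriminator are stand-ins for the real generator and discriminator networks; only the averaging logic is the point here:

```python
import random
random.seed(2)

VOCAB = ["is", "large", "red", "</s>"]

def rollout(prefix, max_len=8):
    """Placeholder rollout policy: uniform over the vocabulary.
    The real rollouts are sampled from the current generator policy."""
    sent = list(prefix)
    while len(sent) < max_len:
        w = random.choice(VOCAB)
        sent.append(w)
        if w == "</s>":
            break
    return sent

def discriminator(sentence):
    """Placeholder reward in [0, 1]; the real one is a trained network."""
    return random.random()

def expected_reward(prefix, n_rollouts=16):
    """Estimate R_t at state S_t: complete the prefix N times and average scores."""
    scores = [discriminator(rollout(prefix)) for _ in range(n_rollouts)]
    return sum(scores) / len(scores)

R_t = expected_reward(["<s>", "The", "blue", "house"])
```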
The issue with RL-based methods
Since we are just using a few samples to estimate the gradient of the policy at each time step, there is very high variance in the gradient estimate from episode to episode. This makes the training process unstable, and convergence is very slow. The SeqGAN paper attempts to speed up training by pre-training the generator as a standard language model using MLE (and pre-training the discriminator to distinguish real sentences from the pre-trained generator's samples) before adversarial training begins.
Also, policy gradient methods tend to converge to a local maximum, especially in cases such as ours where the state-action space is huge. Note that we have a choice of |V| actions at each time step, where V is our vocabulary (which can be on the order of 100,000 words).
In the next post (coming soon), we will look at several methods that don't use RL at all, and instead attempt to solve the problem by avoiding discrete spaces altogether.