WEBVTT

00:00.270 --> 00:02.220
Hello and welcome to this tutorial.

00:02.220 --> 00:03.090
In today's tutorial,

00:03.090 --> 00:06.120
we are going to synchronize with the shared model.

00:06.120 --> 00:07.860
So what we're gonna do

00:07.860 --> 00:10.290
is still inside the function, of course,

00:10.290 --> 00:14.190
and then initialize the length of one episode.

00:14.190 --> 00:16.740
So we're gonna call the length of an episode,

00:16.740 --> 00:20.220
episode_length.

00:20.220 --> 00:21.180
There we go.

00:21.180 --> 00:23.880
And we are going to initialize it to zero

00:23.880 --> 00:27.510
but then this episode length will be incremented.

00:27.510 --> 00:29.100
And speaking of incrementing it,

00:29.100 --> 00:30.630
that's exactly what we'll do.

00:30.630 --> 00:35.630
So we're gonna use a while loop with this trick:

00:36.360 --> 00:41.340
while True, followed by a colon, to repeat what's gonna happen now,

00:41.340 --> 00:44.160
what's gonna happen inside this while loop.

00:44.160 --> 00:45.600
And so the first thing that's gonna happen

00:45.600 --> 00:46.590
in this while loop

00:46.590 --> 00:50.220
is incrementing the length of the episode.

00:50.220 --> 00:51.750
So the first thing that we're gonna do

00:51.750 --> 00:53.910
is increment it by one.

00:53.910 --> 00:58.290
And to do so, we can simply take episode_length

00:58.290 --> 01:03.180
and add plus equals one: episode_length += 1.
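
NOTE
A minimal Python sketch of the loop skeleton described above, using only the standard library. The max_iterations guard is a hypothetical addition so that the example terminates; in the training function itself, the while True loop keeps running.
```python
def count_iterations(max_iterations):
    """Worker loop skeleton; max_iterations is a hypothetical guard for this demo."""
    episode_length = 0  # length of the current episode, starts at zero
    while True:  # repeat indefinitely
        episode_length += 1  # first step of every iteration: increment the counter
        # ... synchronization with the shared model and exploration would go here ...
        if episode_length >= max_iterations:  # demo-only exit condition
            break
    return episode_length
print(count_iterations(3))  # prints 3
```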

01:03.180 --> 01:06.510
And now, we are going to synchronize with the shared model.

01:06.510 --> 01:09.340
This is the point where the agent

01:10.335 --> 01:13.680
will use the shared model to do its little exploration

01:13.680 --> 01:15.510
over a certain number of steps.

01:15.510 --> 01:18.840
And how is the model going to get this shared model?

01:18.840 --> 01:21.750
Well, we need to take our model, then a dot,

01:21.750 --> 01:26.750
and then use the load_state_dict method

01:27.180 --> 01:32.180
because we're gonna use it to load the state dictionary

01:32.190 --> 01:34.260
of our shared model.

01:34.260 --> 01:37.170
So inside the parentheses, we put shared_model first

01:37.170 --> 01:40.920
and then apply the state_dict method to get the parameters

01:40.920 --> 01:41.910
of the shared model.

01:41.910 --> 01:45.450
And that's how our model here gets the parameters of the shared model

01:45.450 --> 01:47.073
to do its little exploration.
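
NOTE
A minimal PyTorch sketch of the synchronization step. The two nn.Linear layers are toy stand-ins (an assumption) for the worker's model and the shared model; any two modules with identical architectures behave the same way. load_state_dict copies the values, so afterwards the local model holds equal parameters but does not share storage with the shared model.
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
# Toy stand-ins (assumption): two modules with the same architecture.
shared_model = nn.Linear(4, 2)  # parameters shared by all workers
model = nn.Linear(4, 2)         # this worker's local copy, different initial weights
# Synchronize: copy every parameter of the shared model into the local model.
model.load_state_dict(shared_model.state_dict())
same = all(torch.equal(p, q) for p, q in zip(model.parameters(), shared_model.parameters()))
print(same)  # True
```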

01:48.180 --> 01:51.420
Okay, and once the model gets this shared model,

01:51.420 --> 01:53.820
now we have to distinguish two cases.

01:53.820 --> 01:58.820
The first one is if done, meaning if the game is done.

02:00.180 --> 02:03.510
So if the game is done, then what happens in that case?

02:03.510 --> 02:06.570
Well, we have to reinitialize the hidden state

02:06.570 --> 02:09.930
and the cell state of the LSTM in the model.

02:09.930 --> 02:14.070
And so that's why now I'm gonna take cx, the cell state

02:14.070 --> 02:16.980
and also hx, the hidden state,

02:16.980 --> 02:19.170
and I'm going to reinitialize them both.

02:19.170 --> 02:21.120
And how are we going to reinitialize them?

02:21.120 --> 02:23.130
Well, with only zeros.

02:23.130 --> 02:26.640
Each will be a vector of 256 zeros

02:26.640 --> 02:29.220
because, remember, the outputs of the LSTM

02:29.220 --> 02:31.800
have dimensions 1 by 256.

02:31.800 --> 02:32.633
So there we go.

02:32.633 --> 02:36.300
We're going to initialize them by using the torch library

02:36.300 --> 02:39.150
and its zeros function, torch.zeros.

02:39.150 --> 02:42.210
And since we want a vector of 256 zeros,

02:42.210 --> 02:45.330
we are gonna pass in the dimensions: 1

02:45.330 --> 02:49.350
for the batch, and 256 for the number of elements,

02:49.350 --> 02:50.790
which will all be zeros.

02:50.790 --> 02:51.780
And there we go.

02:51.780 --> 02:56.160
But then we will wrap that in a torch Variable

02:56.160 --> 02:58.620
because gradients will be computed later on.

02:58.620 --> 03:02.010
So we need this to work with autograd.

03:02.010 --> 03:02.843
All right?

03:02.843 --> 03:07.470
And we're gonna do the same for the hidden state just below

03:07.470 --> 03:09.930
and reinitialize it the same way.

03:09.930 --> 03:10.770
There we go.

03:10.770 --> 03:13.350
So that's if the game is done.
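
NOTE
A sketch of the reset when the game is done, assuming the 256-unit LSTM described here. Since PyTorch 0.4, Variable is a deprecated wrapper that simply returns a tensor, so torch.zeros(1, 256) alone behaves the same in current versions.
```python
import torch
from torch.autograd import Variable
# Reset the LSTM memory at the end of an episode.
# Shape (1, 256): a batch of one, 256 LSTM units as described in the tutorial.
cx = Variable(torch.zeros(1, 256))  # cell state
hx = Variable(torch.zeros(1, 256))  # hidden state
print(cx.shape, hx.abs().sum().item())  # torch.Size([1, 256]) 0.0
```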

03:13.350 --> 03:18.150
And now, the other case, which we handle with else.

03:18.150 --> 03:20.940
Else, then what happens in that case?

03:20.940 --> 03:23.910
Well, we're gonna keep the old cell state

03:23.910 --> 03:25.230
and hidden state.

03:25.230 --> 03:28.260
And very easily, we can keep the old ones this way

03:28.260 --> 03:33.260
by typing cx = Variable(cx.data), which keeps the values but detaches them from the previous computation graph.

03:34.470 --> 03:36.780
And the same for the hidden state:

03:36.780 --> 03:41.780
we simply add hx = Variable(hx.data).
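
NOTE
A sketch of what Variable(cx.data) does when the game is not done. The LSTMCell with input size 8 is a toy stand-in (an assumption), used only to produce states that are attached to a computation graph. The values are kept, but the new tensors no longer carry a grad_fn, so backpropagation will not flow into earlier rollouts. In current PyTorch, .detach() is the idiomatic equivalent.
```python
import torch
from torch.autograd import Variable
torch.manual_seed(0)
lstm = torch.nn.LSTMCell(8, 256)  # toy LSTM cell; input size 8 is an assumption
hx, cx = lstm(torch.randn(1, 8), (torch.zeros(1, 256), torch.zeros(1, 256)))
print(hx.grad_fn is not None)  # True: still attached to the previous graph
old_hx = hx
# Keep the values but cut the link to the previous computation graph.
cx = Variable(cx.data)
hx = Variable(hx.data)
print(hx.grad_fn is None, hx.requires_grad)  # True False
# Idiomatic modern equivalent: hx, cx = hx.detach(), cx.detach()
```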

03:45.090 --> 03:46.590
All right, that's done.

03:46.590 --> 03:49.290
Now we can get out of the else

03:49.290 --> 03:52.260
because we are basically done with these two cases,

03:52.260 --> 03:54.150
whether the game is over or not.

03:54.150 --> 03:55.710
But we stay in the while loop

03:55.710 --> 03:57.330
because now, we're gonna do some more things

03:57.330 --> 04:00.510
which together make up the training process.

04:00.510 --> 04:01.950
And so what we're gonna do now

04:01.950 --> 04:04.530
is initialize several variables,

04:04.530 --> 04:07.170
which are gonna be at the heart of the computations

04:07.170 --> 04:08.160
in the training.

04:08.160 --> 04:08.993
So let's do this.

04:08.993 --> 04:10.830
We're gonna need the values,

04:10.830 --> 04:13.710
which, remember, are the outputs of the critic.

04:13.710 --> 04:15.150
So that's the V function,

04:15.150 --> 04:19.170
and we will initialize them as an empty list this way.

04:19.170 --> 04:23.970
Then we're gonna need the log probabilities, so log_probs,

04:23.970 --> 04:27.540
and we will also initialize them as an empty list.

04:27.540 --> 04:30.060
Then, of course, we're gonna need the rewards,

04:30.060 --> 04:33.630
which we will also initialize as an empty list.

04:33.630 --> 04:37.860
And finally, we're gonna need the entropies.

04:37.860 --> 04:41.700
This is something new, but it is indeed at the heart

04:41.700 --> 04:43.260
of the training computations.

04:43.260 --> 04:45.150
So an empty list as well.

04:45.150 --> 04:47.550
So now that we've initialized these four variables,

04:47.550 --> 04:49.552
we can start a new for loop.

04:49.552 --> 04:50.385
And in this new for loop,

04:50.385 --> 04:53.443
we will update the values of these four variables.

04:53.443 --> 04:56.160
And so this new for loop is gonna be a for loop

04:56.160 --> 04:57.810
over the exploration steps.

04:57.810 --> 04:59.730
And therefore, the looping variable

04:59.730 --> 05:01.530
is going to be step.

05:01.530 --> 05:04.650
So for step in range

05:04.650 --> 05:09.650
and inside the parentheses, we can directly put params.num_steps

05:10.650 --> 05:13.620
because params.num_steps is exactly the number of steps

05:13.620 --> 05:15.180
of the exploration.
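
NOTE
A sketch of the four rollout buffers and the loop over exploration steps. The Params class and its value of 20 are hypothetical stand-ins for the tutorial's params object, and the loop body only records a placeholder reward where the real forward pass and environment step would go.
```python
class Params:
    """Hypothetical stand-in for the tutorial's params object."""
    num_steps = 20  # exploration steps per update (assumed value)
params = Params()
# Buffers filled during one exploration rollout.
values = []     # critic outputs V(s)
log_probs = []  # log-probabilities of the actions played
rewards = []    # rewards received from the environment
entropies = []  # entropy of the policy at each step
for step in range(params.num_steps):
    # The forward pass, action selection, and environment step would go here;
    # this sketch only records a placeholder reward of 0.0.
    rewards.append(0.0)
print(len(rewards))  # 20
```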

05:15.180 --> 05:19.350
So for all the steps in the exploration, what do we do?

05:19.350 --> 05:22.860
Well, we are gonna get the predictions of the model.

05:22.860 --> 05:24.537
That is, what the model returns.

05:24.537 --> 05:28.200
And to get these predictions, we can simply take the model

05:28.200 --> 05:29.940
and apply it to the input.

05:29.940 --> 05:31.350
So that's the input signal.

05:31.350 --> 05:33.810
It goes through the brain of the model

05:33.810 --> 05:35.580
and that will get us the outputs,

05:35.580 --> 05:37.530
but it will get us several outputs, you know?

05:37.530 --> 05:39.840
It will get us the value of the V function,

05:39.840 --> 05:42.180
which is the output of the critic.

05:42.180 --> 05:46.170
Then the Q-values, QSA, which are the output of the actor.

05:46.170 --> 05:48.630
But also don't forget that it will also output

05:48.630 --> 05:51.630
the tuple of the hidden state and cell state.

05:51.630 --> 05:54.090
Because remember, if we go back to our model,

05:54.090 --> 05:56.040
well, in the forward function,

05:56.040 --> 06:00.300
we can see that indeed it returns the output of the critic,

06:00.300 --> 06:03.300
that is the value of the V function, VS,

06:03.300 --> 06:07.950
then the output of the actor, which are the Q-values, QSA,

06:07.950 --> 06:10.560
and also the outputs of the LSTM,

06:10.560 --> 06:12.810
which is this tuple hx and cx,

06:12.810 --> 06:14.880
the hidden state and the cell state.

06:14.880 --> 06:16.920
So we must be careful with that.

06:16.920 --> 06:19.800
This is quite different from what happened before.

06:19.800 --> 06:22.140
And therefore, we're now going to apply the model

06:22.140 --> 06:24.540
to the input, which is the state.

06:24.540 --> 06:26.760
But first there are several things to do,

06:26.760 --> 06:29.280
which are related to torch and which, of course,

06:29.280 --> 06:31.170
give power to what we're doing.

06:31.170 --> 06:36.170
The first thing we need to do is to unsqueeze the state

06:36.480 --> 06:40.920
to add a fake batch dimension, which must be at index zero.

06:40.920 --> 06:42.840
That's because the model can only accept

06:42.840 --> 06:45.300
batches of inputs, not a single input on its own

06:45.300 --> 06:47.160
as a bare vector or tensor.

06:47.160 --> 06:49.650
So that's the first thing we must do, the unsqueeze,

06:49.650 --> 06:50.850
but that's not all:

06:50.850 --> 06:55.500
we need to convert our input state into a torch Variable.

06:55.500 --> 06:59.160
So I'm wrapping it in Variable here.
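
NOTE
A sketch of preparing a single state for the model. The (1, 42, 42) frame shape is an assumption for illustration; the point is that unsqueeze(0) inserts the batch dimension at index 0 before the state is wrapped in Variable.
```python
import torch
from torch.autograd import Variable
# Assumption: one preprocessed frame of shape (channels, height, width) = (1, 42, 42).
state = torch.zeros(1, 42, 42)
# unsqueeze(0) adds the fake batch dimension at index 0; Variable wraps it for autograd.
batched = Variable(state.unsqueeze(0))
print(tuple(state.shape), tuple(batched.shape))  # (1, 42, 42) (1, 1, 42, 42)
```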

06:59.160 --> 07:01.890
So now we're okay with the state, the input state,

07:01.890 --> 07:05.220
but remember that the inputs of the forward function

07:05.220 --> 07:06.960
are actually the input images,

07:06.960 --> 07:08.580
so that's what we just took care of,

07:08.580 --> 07:11.430
but also this tuple of hx, the hidden state,

07:11.430 --> 07:13.770
and cx, the cell state.

07:13.770 --> 07:16.500
And therefore, we need to add here

07:16.500 --> 07:21.500
this second part of the input with the tuple of hx and cx.

07:23.160 --> 07:26.400
All right, and we must take care of the parentheses.

07:26.400 --> 07:28.470
There we go, we have our two inputs.

07:28.470 --> 07:30.270
The first one is the input state.

07:30.270 --> 07:31.830
That is the input images,

07:31.830 --> 07:33.930
all converted into a torch variable

07:33.930 --> 07:37.200
and unsqueezed to add this fake dimension of the batch,

07:37.200 --> 07:40.380
and this tuple of the hidden state and the cell state.

07:40.380 --> 07:41.640
So we're all good to go.

07:41.640 --> 07:44.070
We are ready to get our predictions.

07:44.070 --> 07:48.180
And now, since this returns our three predictions,

07:48.180 --> 07:50.790
the output of the critic, the output of the actor,

07:50.790 --> 07:53.310
and the tuple of the hidden state and the cell state

07:53.310 --> 07:54.600
from the LSTM,

07:54.600 --> 07:57.990
we're going to introduce three new variables now

07:57.990 --> 07:59.910
which will be these three outputs.

07:59.910 --> 08:00.743
So there we go.

08:00.743 --> 08:03.810
The first output is the value of the V function

08:03.810 --> 08:05.490
which is the output of the critic.

08:05.490 --> 08:08.430
So we're gonna call it value.

08:08.430 --> 08:09.263
So there we go.

08:09.263 --> 08:10.320
That's the first output.

08:10.320 --> 08:12.960
Then the second output is going to be

08:12.960 --> 08:17.130
the output of the actor, and that's the Q-values, QSA.

08:17.130 --> 08:20.430
But since the Q-values are associated with the actions,

08:20.430 --> 08:23.793
we can also call them the action values, action_values.

08:24.720 --> 08:25.710
All right.

08:25.710 --> 08:28.860
And then the final output returned by the model

08:28.860 --> 08:32.040
that's the tuple of the hidden state, hx,

08:32.040 --> 08:34.170
and the cell state, cx.

08:34.170 --> 08:35.040
And there we go.

08:35.040 --> 08:39.210
We have our three outputs returned by the model.

08:39.210 --> 08:40.200
Perfect.
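
NOTE
A self-contained sketch of the forward call and the three-way unpacking. ToyActorCritic, critic_linear, actor_linear, and the linear encoder are hypothetical simplifications, not the tutorial's model; only the input format (a state batch plus the (hx, cx) tuple) and the three returned outputs follow the description above.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class ToyActorCritic(nn.Module):
    """Hypothetical simplified model: a linear encoder replaces the tutorial's layers."""
    def __init__(self, num_inputs=16, num_actions=4):
        super().__init__()
        self.encoder = nn.Linear(num_inputs, 32)
        self.lstm = nn.LSTMCell(32, 256)                 # 256 units, as in the tutorial
        self.critic_linear = nn.Linear(256, 1)           # V(s)
        self.actor_linear = nn.Linear(256, num_actions)  # one score per action
    def forward(self, inputs):
        x, (hx, cx) = inputs  # input batch plus the (hidden, cell) tuple
        x = F.elu(self.encoder(x))
        hx, cx = self.lstm(x, (hx, cx))
        return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)
model = ToyActorCritic()
state = torch.randn(16)  # one unbatched toy state
hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)
value, action_values, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
print(value.shape, action_values.shape, hx.shape)
# torch.Size([1, 1]) torch.Size([1, 4]) torch.Size([1, 256])
```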

08:40.200 --> 08:41.910
So now that we have the predictions,

08:41.910 --> 08:45.690
we need to use a softmax to choose the action to play.

08:45.690 --> 08:47.310
And so now that's gonna be exactly the same

08:47.310 --> 08:48.570
as what we did before.

08:48.570 --> 08:51.960
The next step is to get our probabilities

08:51.960 --> 08:53.913
so we can call them prob.

08:54.870 --> 08:58.020
And that's where we use the softmax function,

08:58.020 --> 09:00.720
which we take from the functional module,

09:00.720 --> 09:03.363
imported with the shortcut F, so F.softmax.

09:05.220 --> 09:08.340
And that will generate a probability distribution

09:08.340 --> 09:11.940
from the input that we're about to pass in right now,

09:11.940 --> 09:14.130
which is, of course, action_values,

09:14.130 --> 09:18.180
that is, the Q-values, the output of the actor

09:18.180 --> 09:19.170
in the model.

09:19.170 --> 09:22.320
Okay, so now we have our probabilities, but as you noticed,

09:22.320 --> 09:24.510
we're gonna work with the entropy.

09:24.510 --> 09:25.590
And to get the entropy,

09:25.590 --> 09:27.570
we not only need the probabilities,

09:27.570 --> 09:29.820
but also the log probabilities

09:29.820 --> 09:32.910
because the entropy is the sum of the products,

09:32.910 --> 09:37.380
log prob times prob, all multiplied by minus one.
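
NOTE
Written as a formula, with π(a | s) the softmax probability of action a in state s, the entropy just described is:
```latex
H\big(\pi(\cdot \mid s)\big) = -\sum_{a} \pi(a \mid s)\,\log \pi(a \mid s)
```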

09:37.380 --> 09:41.460
And so we also need to get our log prob,

09:41.460 --> 09:46.440
which is generated the same way, but with log_softmax.

09:46.440 --> 09:49.290
So instead of taking a distribution of the probabilities,

09:49.290 --> 09:51.810
we take a distribution of the log probabilities,

09:51.810 --> 09:55.283
and we do that with the log_softmax function,

09:58.770 --> 10:01.360
which, just as before, we apply to the Q-values,

10:02.460 --> 10:04.470
which we call the action values.
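
NOTE
A sketch of computing prob and log_prob from the actor's output. The four action scores are made-up example values. dim=1 is passed explicitly because recent PyTorch versions warn when softmax is called without a dimension.
```python
import torch
import torch.nn.functional as F
# Made-up actor scores for 4 actions (batch of one).
action_values = torch.tensor([[1.0, 2.0, 0.5, -1.0]])
prob = F.softmax(action_values, dim=1)          # probabilities, each row sums to 1
log_prob = F.log_softmax(action_values, dim=1)  # log-probabilities, computed stably
print(torch.allclose(prob.sum(1), torch.ones(1)))  # True
print(torch.allclose(log_prob, prob.log()))        # True
```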

10:04.470 --> 10:07.590
All right, so now we have the prob and the log prob,

10:07.590 --> 10:10.170
and so we are ready to get the entropy.

10:10.170 --> 10:11.223
And the entropy,

10:12.240 --> 10:13.770
what is the formula for that?

10:13.770 --> 10:17.193
Well, as I just mentioned, we take the log prob,

10:18.150 --> 10:20.160
we multiply it by the prob,

10:20.160 --> 10:23.070
then we're gonna take the sum of all this.

10:23.070 --> 10:27.840
And to do that, we can add .sum(1) here.

10:27.840 --> 10:30.210
We've actually used this trick many times now.

10:30.210 --> 10:33.720
And as we said, we multiply all this by minus one.

10:33.720 --> 10:37.140
So it's minus the sum of the products,

10:37.140 --> 10:39.120
log prob times prob.

10:39.120 --> 10:39.953
Perfect.

10:39.953 --> 10:42.240
And now, we are gonna store this entropy

10:42.240 --> 10:45.480
that was just computed in our list of entropies,

10:45.480 --> 10:46.313
because now

10:46.313 --> 10:48.570
we have the latest computed entropy,

10:48.570 --> 10:51.690
and so we need to store it in the entropies list.

10:51.690 --> 10:53.250
And to do this, nothing could be simpler:

10:53.250 --> 10:55.320
we're gonna use the append function, of course,

10:55.320 --> 10:57.480
because entropies is a list.

10:57.480 --> 11:00.510
So we take our entropies list then dot

11:00.510 --> 11:04.470
and we use the append function to add the entropy

11:04.470 --> 11:05.620
that was just computed.
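
NOTE
A sketch of the entropy computation and the append step, reusing made-up action scores. As a sanity check, equal scores give a uniform distribution over 4 actions, which has the maximum possible entropy, log 4.
```python
import math
import torch
import torch.nn.functional as F
action_values = torch.tensor([[1.0, 2.0, 0.5, -1.0]])  # made-up actor scores
prob = F.softmax(action_values, dim=1)
log_prob = F.log_softmax(action_values, dim=1)
entropies = []
# Entropy: minus the sum over actions of log_prob * prob; .sum(1) sums along the action dimension.
entropy = -(log_prob * prob).sum(1)
entropies.append(entropy)  # store this step's entropy
# Sanity check: equal scores give a uniform distribution, whose entropy is log(4).
uniform = torch.zeros(1, 4)
uniform_entropy = -(F.log_softmax(uniform, dim=1) * F.softmax(uniform, dim=1)).sum(1)
print(entropy.shape, len(entropies))                     # torch.Size([1]) 1
print(abs(uniform_entropy.item() - math.log(4)) < 1e-6)  # True
```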

11:06.540 --> 11:08.460
All right, so we're gonna take a break now.

11:08.460 --> 11:10.170
We're gonna do this step by step.

11:10.170 --> 11:12.300
In the next tutorial, we will play the action

11:12.300 --> 11:15.600
by taking a random draw from this generated distribution

11:15.600 --> 11:17.010
of probabilities.

11:17.010 --> 11:18.270
And after we play the action,

11:18.270 --> 11:19.980
we will get the value of the state,

11:19.980 --> 11:23.190
and we will eventually store our new transition,

11:23.190 --> 11:25.140
state, reward, and done.

11:25.140 --> 11:26.910
So that will be another big step completed

11:26.910 --> 11:29.550
and we will complete that in the next tutorial.

11:29.550 --> 11:31.353
Until then, enjoy AI.
