WEBVTT

00:00.240 --> 00:02.790
Hello and welcome to this Python tutorial.

00:02.790 --> 00:05.940
All right, so today, we will be making the update function,

00:05.940 --> 00:08.850
which will update everything there is to update

00:08.850 --> 00:12.480
as soon as the AI reaches a new state.

00:12.480 --> 00:13.920
So when it reaches a new state,

00:13.920 --> 00:16.170
we need to update the action:

00:16.170 --> 00:19.530
the last action becomes the new action that was played,

00:19.530 --> 00:22.650
but also the last state, which becomes the new state,

00:22.650 --> 00:24.180
and finally, the last reward

00:24.180 --> 00:27.450
which becomes the new reward we get when we play the action.

00:27.450 --> 00:29.310
So that's the logical path

00:29.310 --> 00:31.530
that happens right after selecting an action.

00:31.530 --> 00:35.250
We need to update all the elements of the transition.

00:35.250 --> 00:37.380
And of course, we will get a new transition,

00:37.380 --> 00:40.320
so we will have to append this new transition to the memory.

00:40.320 --> 00:43.920
And finally, we will also update our reward window

00:43.920 --> 00:47.190
to keep an eye on how the training is going

00:47.190 --> 00:49.650
and how the exploration is going.
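
NOTE
A minimal, framework-free sketch of the update flow just described, assuming a hypothetical SketchBrain class: plain lists stand in for Torch tensors and the replay memory, and a placeholder replaces the real select_action, so only the order of the updates is illustrated, not the tutorial's actual code.
```python
class SketchBrain:
    """Hypothetical stand-in for the DQN object; not the tutorial's class."""
    def __init__(self, window_size=1000):
        self.memory = []          # plain list standing in for the replay memory
        self.reward_window = []   # recent rewards, used to monitor training
        self.window_size = window_size
        self.last_state = [0.0] * 5
        self.last_action = 0
        self.last_reward = 0.0
    def select_action(self, state):
        # Placeholder policy (always action 0); the real one uses the network.
        return 0
    def update(self, reward, new_signal):
        new_state = list(new_signal)                      # 1. the new state
        self.memory.append((self.last_state, new_state,
                            self.last_action, self.last_reward))  # 2. store the transition
        action = self.select_action(new_state)            # 3. choose the next action
        self.last_action = action                         # 4. new values become "last"
        self.last_state = new_state
        self.last_reward = reward
        self.reward_window.append(reward)                 # 5. track recent rewards
        if len(self.reward_window) > self.window_size:
            del self.reward_window[0]
        return action
```
In this sketch, calling update(-0.2, [0.1, 0.0, 0.3, 0.5, -0.5]) on a fresh instance stores one transition whose reward slot still holds the initial 0.0, because the incoming reward only becomes last_reward at step 4.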

00:49.650 --> 00:51.990
But what's most important for you to understand

00:51.990 --> 00:54.840
is that now we can finally make a connection

00:54.840 --> 00:59.040
between the AI that we're implementing right now and our map.

00:59.040 --> 01:01.500
Because, if we go back to our map, remember,

01:01.500 --> 01:06.210
there is this big update function in the Game class,

01:06.210 --> 01:09.570
so that's where we're actually making the game with the car

01:09.570 --> 01:11.610
and defining how the car should be punished

01:11.610 --> 01:13.170
when it's making a mistake.

01:13.170 --> 01:16.860
And in this Game class, we have this update function.

01:16.860 --> 01:19.500
And in this update function, we notice this line,

01:19.500 --> 01:24.500
action = brain.update(last_reward, last_signal).

01:24.870 --> 01:28.470
And actually, this is exactly what we're about to make.

01:28.470 --> 01:31.350
We are about to make this update function

01:31.350 --> 01:35.460
that will take the last reward and the last signal

01:35.460 --> 01:37.830
to get the next action to play.

01:37.830 --> 01:39.510
So not only will we update

01:39.510 --> 01:42.990
all the different elements of the transition, but mostly,

01:42.990 --> 01:46.470
we will be playing the action that we should play

01:46.470 --> 01:49.410
when getting the last reward and the last signal.

01:49.410 --> 01:51.990
And so, of course, in this update function,

01:51.990 --> 01:55.350
we will use the select_action function

01:55.350 --> 01:57.240
that we just implemented before.

01:57.240 --> 02:00.270
We will integrate the select_action function

02:00.270 --> 02:03.540
into the update function that we're about to make

02:03.540 --> 02:05.700
to select the right action to play

02:05.700 --> 02:07.830
besides making all the updates.

02:07.830 --> 02:09.960
So that's really important to make this connection

02:09.960 --> 02:11.280
with the map right now.

02:11.280 --> 02:13.710
What we're about to make is, ultimately,

02:13.710 --> 02:17.820
the connection between our AI and the game,

02:17.820 --> 02:19.560
the game that we make in this class.

02:19.560 --> 02:24.060
And so, what we can do now is directly take this

02:24.060 --> 02:25.680
update(last_reward, last_signal),

02:25.680 --> 02:28.830
because that's exactly the function that we will be making

02:28.830 --> 02:30.690
with these two arguments here.

02:30.690 --> 02:34.470
And just as a quick reminder, brain is our AI object,

02:34.470 --> 02:38.520
that is, it's the object of the DQN class.

02:38.520 --> 02:41.370
So what we're gonna do now is we are going to copy this,

02:41.370 --> 02:44.310
update(last_reward, last_signal),

02:44.310 --> 02:48.450
and that's gonna be our next function we're making.

02:48.450 --> 02:50.490
And therefore, I'm pasting that here.

02:50.490 --> 02:52.320
Then, just to be careful, I would just like

02:52.320 --> 02:56.340
to use different names from the ones we have here.

02:56.340 --> 02:57.480
We have last reward here

02:57.480 --> 03:00.870
and I don't wanna confuse this last reward with this one.

03:00.870 --> 03:02.190
That can be dangerous.

03:02.190 --> 03:06.390
So I'm going to replace last_reward here with reward,

03:06.390 --> 03:08.910
and do the same for last_signal.

03:08.910 --> 03:13.200
Let's just call it signal, or even new_signal, to specify

03:13.200 --> 03:16.315
that we want to make the update when reaching a new state,

03:16.315 --> 03:18.900
and therefore, getting a new signal.

03:18.900 --> 03:21.120
But then, of course, this reward here

03:21.120 --> 03:26.120
is going to be the last_reward that we set here,

03:26.580 --> 03:29.550
when driving onto some sand or, worse,

03:29.550 --> 03:31.500
when getting too close to an edge of the map.

03:31.500 --> 03:34.170
That's where we define the last reward.

03:34.170 --> 03:36.330
This last_reward is going to be the input

03:36.330 --> 03:37.380
of the update function,

03:37.380 --> 03:39.570
so that's why we have last_reward here.

03:39.570 --> 03:42.630
But right here, I'm just giving another name

03:42.630 --> 03:44.490
for the argument reward

03:44.490 --> 03:47.610
to not confuse it with last reward here.
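
NOTE
To see why the parameter is renamed, here is a small sketch with a hypothetical EchoBrain stub and a hypothetical game_step function: the game keeps its own last_reward and last_signal and passes them positionally, so inside update(self, reward, new_signal) they are known only by the new names.
```python
class EchoBrain:
    """Hypothetical stub: records what update() receives and returns action 1."""
    def __init__(self):
        self.received = None
    def update(self, reward, new_signal):
        # Inside update, the game's values are visible only as reward / new_signal.
        self.received = (reward, new_signal)
        return 1
def game_step(brain, last_reward, last_signal):
    # Mirrors the map's line: action = brain.update(last_reward, last_signal)
    action = brain.update(last_reward, last_signal)
    return action
```
Because the arguments are matched by position, the names on each side can differ freely, which is what lets the function use reward and new_signal without clashing with the game's last_reward.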

03:47.610 --> 03:50.280
All right, so this is the update function.

03:50.280 --> 03:54.060
And now, let's go inside it and let's do these two things,

03:54.060 --> 03:56.640
that is, update all the elements of our transition

03:56.640 --> 03:59.250
and, of course, select the action.

03:59.250 --> 04:01.920
Okay, so what do we need to update first?

04:01.920 --> 04:05.070
Well, as you understood, we want to make the updates

04:05.070 --> 04:06.840
when reaching a new state.

04:06.840 --> 04:08.790
So the first thing we'll be updating is,

04:08.790 --> 04:10.830
obviously, this new state.

04:10.830 --> 04:12.810
That is the new state we're reaching.

04:12.810 --> 04:17.310
So I'm gonna call it new_state, and then equals.

04:17.310 --> 04:20.010
And so, how can we get this new state?

04:20.010 --> 04:22.440
Well, of course, that depends on the signal.

04:22.440 --> 04:25.830
that is, the new signal that the sensors just detected.

04:25.830 --> 04:29.070
And as a reminder, the state is the signal itself

04:29.070 --> 04:32.340
composed of the three signals of the sensors,

04:32.340 --> 04:34.800
signal one, signal two and signal three,

04:34.800 --> 04:37.500
plus orientation and minus orientation.

04:37.500 --> 04:40.230
That's our state, so be sure to understand

04:40.230 --> 04:42.660
that the signal is the state.
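
NOTE
As a sketch of that five-element list, assuming a hypothetical helper build_signal that receives the three sand readings and the orientation value (in whatever units the map uses):
```python
def build_signal(signal1, signal2, signal3, orientation):
    """Hypothetical helper: assemble the 5-element state the DQN receives."""
    # Three sensor readings, then the orientation and its negation.
    return [signal1, signal2, signal3, orientation, -orientation]
```
For example, build_signal(0.1, 0.0, 0.5, 0.25) gives [0.1, 0.0, 0.5, 0.25, -0.25].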

04:42.660 --> 04:46.170
But right now, it is a simple list of five elements.

04:46.170 --> 04:47.900
And since this is going to be the input

04:47.900 --> 04:49.650
of the neural network, remember,

04:49.650 --> 04:52.470
we have to convert it into a torch tensor.

04:52.470 --> 04:54.900
So that's exactly what we're gonna do right now.

04:54.900 --> 04:59.160
We are going to take our torch library

04:59.160 --> 05:02.520
and then take the Tensor class, torch.Tensor, there we go,

05:02.520 --> 05:07.520
which will convert our new signal into a torch tensor.

05:09.150 --> 05:10.680
Then, it's better to make sure

05:10.680 --> 05:14.070
that all the elements of the torch tensor are floats.

05:14.070 --> 05:16.500
So I'm going to make a type conversion

05:16.500 --> 05:19.860
to convert them into float, like this.

05:19.860 --> 05:22.590
And then, finally, and try to make this a reflex,

05:22.590 --> 05:24.480
here is what we need to do next:

05:24.480 --> 05:26.820
we need to create that fake dimension,

05:26.820 --> 05:29.160
the dimension corresponding to the batch.

05:29.160 --> 05:33.630
And we do this, of course, with the unsqueeze method

05:33.630 --> 05:37.260
to which we have to input the index of this fake dimension

05:37.260 --> 05:40.710
we want to have for the batch, which is zero.
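
NOTE
A small sketch of that conversion, assuming PyTorch is installed; the signal values are made up for illustration. torch.Tensor(...) builds a float32 tensor of shape (5,), .float() makes the dtype explicit, and .unsqueeze(0) inserts the batch dimension at index 0, giving shape (1, 5).
```python
import torch
new_signal = [0.1, 0.0, 0.5, 0.25, -0.25]  # illustrative values, not real sensor data
# Convert the list to a float tensor, then add a batch dimension at index 0.
new_state = torch.Tensor(new_signal).float().unsqueeze(0)
print(new_state.shape)  # torch.Size([1, 5])
print(new_state.dtype)  # torch.float32
```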

05:40.710 --> 05:43.170
All right, and now we have our new state

05:43.170 --> 05:45.930
composed of the three signals of the three sensors,

05:45.930 --> 05:48.360
plus orientation, minus orientation.

05:48.360 --> 05:50.940
And of course, that will depend on the new signal

05:50.940 --> 05:55.560
we are getting with this update function right at this time.

05:55.560 --> 05:57.810
From the last signal, we get the three signals,

05:57.810 --> 05:59.820
plus orientation, minus orientation.

05:59.820 --> 06:01.860
And as a reminder, the three signals

06:01.860 --> 06:06.450
are the density of sand detected around the sensors.
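
NOTE
How the density is computed is not shown in this tutorial; one plausible sketch, assuming sand is stored as a 2D grid of 0/1 cells and each sensor looks at a hypothetical square window around its position, is the fraction of sand cells in that window. The function name and window size are illustrative assumptions.
```python
def sand_density(sand, x, y, half=10):
    """Hypothetical: fraction of sand cells in a (2*half) x (2*half) window at (x, y).
    Assumes the window lies fully inside the grid (no edge handling)."""
    rows = sand[x - half:x + half]
    cells = [cell for row in rows for cell in row[y - half:y + half]]
    # Average of 0/1 cells: 0.0 means no sand, 1.0 means fully covered.
    return sum(cells) / len(cells)
```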

06:06.450 --> 06:09.090
All right, so we just got our new state,

06:09.090 --> 06:11.670
so that means we reached the new state.

06:11.670 --> 06:13.710
And now, we have to make the next update.

06:13.710 --> 06:16.650
So, what do you think we need to update now?

06:16.650 --> 06:20.670
What would be the logical thing to update right now,

06:20.670 --> 06:22.620
after reaching this new state?

06:22.620 --> 06:25.950
Well, what we need to update now is the memory.

06:25.950 --> 06:26.970
Why is that?

06:26.970 --> 06:28.980
It's because, at each time t,

06:28.980 --> 06:32.820
a transition is composed of the current state, s_t,

06:32.820 --> 06:37.800
the next state, s_{t+1}, the reward, r_t, and the action, a_t.

06:37.800 --> 06:40.200
And right now, we already have s_t,

06:40.200 --> 06:43.200
we already have r_t and we already have a_t.

06:43.200 --> 06:47.550
And we just got the last element of the transition, s_{t+1}.

06:47.550 --> 06:51.510
So by getting this new state, s_{t+1},

06:51.510 --> 06:55.560
we are getting a brand-new transition for the memory.

06:55.560 --> 06:58.530
And therefore, we have to append this brand new transition

06:58.530 --> 07:01.920
to the memory because that's simply our next transition.
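
NOTE
One way to picture a single transition is a named tuple, sketched here with a hypothetical Transition type; the tutorial itself pushes a plain tuple in the order last state, new state, action, reward, and plain numbers stand in for the tensors.
```python
from collections import namedtuple
# Hypothetical container; field order matches the order used when pushing to memory.
Transition = namedtuple("Transition", ["state", "next_state", "action", "reward"])
t = Transition(state=[0.0, 0.0, 0.0, 0.3, -0.3],       # s_t
               next_state=[0.1, 0.0, 0.0, 0.25, -0.25], # s_{t+1}
               action=2,                                # a_t
               reward=-0.2)                             # r_t
print(t.action, t.reward)  # 2 -0.2
```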

07:01.920 --> 07:04.080
So that's why we have to update the memory right now.

07:04.080 --> 07:08.580
And therefore, what I'm gonna do is take my memory object

07:08.580 --> 07:10.800
created from the replay memory class,

07:10.800 --> 07:15.450
and therefore, I'm gonna take self.memory

07:15.450 --> 07:17.130
to refer to the object.

07:17.130 --> 07:18.690
But since I'm using self,

07:18.690 --> 07:22.890
I have to include self as the first parameter of the update function.

07:22.890 --> 07:25.500
So now, you can really see what self is for.

07:25.500 --> 07:27.810
You need it whenever you use a variable

07:27.810 --> 07:31.620
that you created and initialized in the init function.
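
NOTE
A tiny sketch of that difference, using a hypothetical Tracker class: an attribute created on self in __init__ survives between calls and must be reached through self, while a plain local name inside a method disappears when the method returns.
```python
class Tracker:
    """Hypothetical example of an attribute vs. a local variable."""
    def __init__(self):
        self.memory = []               # attribute: created in __init__, reached via self
    def update(self, new_signal):
        new_state = list(new_signal)   # local: exists only during this call
        self.memory.append(new_state)  # needs self to reach the attribute
        return len(self.memory)
```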

07:31.620 --> 07:34.170
So self.memory, and now we need to update it.

07:34.170 --> 07:37.140
And how do you think we're going to update it?

07:37.140 --> 07:38.280
Well, the good news is

07:38.280 --> 07:41.280
that we already made a function to do that.

07:41.280 --> 07:42.780
It's the push method,

07:42.780 --> 07:47.160
which appends an event or a transition to the memory.

07:47.160 --> 07:49.050
So that's exactly what we're gonna use now.

07:49.050 --> 07:50.820
We're gonna use the push method

07:50.820 --> 07:53.790
to append our new transition that we just made

07:53.790 --> 07:54.660
to the memory.
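
NOTE
For reference, a replay memory with a push method can be sketched as follows; ReplayMemorySketch is a hypothetical name, and evicting the oldest event once a capacity is exceeded is an assumption made for this illustration rather than a restatement of the class built earlier.
```python
class ReplayMemorySketch:
    """Hypothetical fixed-capacity replay memory."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
    def push(self, event):
        # event is one whole transition: (last_state, new_state, action, reward)
        self.memory.append(event)
        if len(self.memory) > self.capacity:
            del self.memory[0]  # drop the oldest transition
```
Because push receives the whole transition as a single tuple argument, the call needs its own pair of parentheses around the four elements, as in memory.push((s, s_next, a, r)).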

07:54.660 --> 07:58.410
So here, I'm not writing an equals sign,

07:58.410 --> 08:00.330
because we're gonna call a method,

08:00.330 --> 08:04.230
so we can directly use .push.

08:04.230 --> 08:06.552
And first, I'm going to add the transition,

08:06.552 --> 08:08.910
this new transition that we just got,

08:08.910 --> 08:10.830
and its first element is the last state.

08:10.830 --> 08:14.250
So, self.last_state.

08:14.250 --> 08:16.170
So that's s_t.

08:16.170 --> 08:18.660
That's exactly this one, it already exists.

08:18.660 --> 08:21.270
Then, the next element of this transition is, of course,

08:21.270 --> 08:23.820
the new state that we just reached.

08:23.820 --> 08:26.483
And therefore, since it is not a variable of the object

08:26.483 --> 08:30.060
that we created and initialized in this init function,

08:30.060 --> 08:31.560
we don't put a self here.

08:31.560 --> 08:33.603
We directly put the new state.

08:35.370 --> 08:38.940
Then, the next element of the transition is the action.

08:38.940 --> 08:40.860
And as we said, we already have the last action,

08:40.860 --> 08:43.830
which is this self.last_action here.

08:43.830 --> 08:46.230
It is initialized to zero, but then, of course,

08:46.230 --> 08:50.130
it will be updated with the select_action function.

08:50.130 --> 08:53.780
So that's this one: self.last_action.

08:55.860 --> 08:57.150
But now, be careful.

08:57.150 --> 08:59.730
The elements that we're including in this transition

08:59.730 --> 09:01.860
should all be torch tensors.

09:01.860 --> 09:03.870
As you can see, that's the case for the last state.

09:03.870 --> 09:05.310
It's a torch tensor.

09:05.310 --> 09:07.860
The new state is also a torch tensor.

09:07.860 --> 09:10.590
And so, this must be the same for the action

09:10.590 --> 09:12.180
and then the reward, of course.

09:12.180 --> 09:15.337
But now, you might be thinking, "How can it be a torch tensor,

09:15.337 --> 09:18.030
considering that it's simply a number?"

09:18.030 --> 09:20.880
The action is either zero, one or two.

09:20.880 --> 09:22.350
But, in fact, that's not a problem.

09:22.350 --> 09:26.790
We can still convert this zero, one or two variable

09:26.790 --> 09:28.650
into a torch tensor.

09:28.650 --> 09:31.530
This will just be what we call a long tensor.

09:31.530 --> 09:33.990
Long is a data type, and a long tensor is a tensor

09:33.990 --> 09:35.850
that will contain an integer

09:35.850 --> 09:37.560
because the last action is an integer.

09:37.560 --> 09:39.660
It is zero, one or two.

09:39.660 --> 09:44.070
So what we're gonna take now is our torch library.

09:44.070 --> 09:46.140
Then, we're gonna take the long,

09:46.140 --> 09:49.170
here it is, the LongTensor class, torch.LongTensor.

09:49.170 --> 09:51.810
That will create an object,

09:51.810 --> 09:53.640
which will be the long tensor itself.

09:53.640 --> 09:57.060
And by taking this self.last_action variable as input,

09:57.060 --> 10:00.150
it will create this long tensor object,

10:00.150 --> 10:02.820
but it will still contain zero, one or two

10:02.820 --> 10:04.830
inside that LongTensor object.

10:04.830 --> 10:08.820
And that is just to be consistent with the transition

10:08.820 --> 10:10.860
that should only contain tensors

10:10.860 --> 10:12.390
because we're working with PyTorch

10:12.390 --> 10:14.130
and we're working with a neural network,

10:14.130 --> 10:16.140
so we have to work with tensors.

10:16.140 --> 10:18.570
So there we go, torch.LongTensor,

10:18.570 --> 10:20.820
and one last conversion to make.

10:20.820 --> 10:24.570
We must be sure that what's inside this long tensor

10:24.570 --> 10:25.830
is an integer.

10:25.830 --> 10:28.080
And to make sure of it, even if we already know

10:28.080 --> 10:30.150
that the last action is zero, one or two,

10:30.150 --> 10:30.983
to make sure of it,

10:30.983 --> 10:34.830
we're gonna make this int type conversion again.

10:34.830 --> 10:39.270
We convert our self.last_action into an integer.

10:39.270 --> 10:40.103
There we go.

10:40.103 --> 10:44.010
And then we just need to put int(self.last_action)

10:44.010 --> 10:46.140
inside square brackets, right here,

10:46.140 --> 10:49.350
so that now we get a long tensor of one element,

10:49.350 --> 10:52.800
which will be this last action zero, one or two itself.

10:52.800 --> 10:55.560
So the key point is that's just how you convert

10:55.560 --> 11:00.030
a simple number, zero, one or two into a tensor with torch.
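
NOTE
A quick check of that conversion, assuming PyTorch is installed: wrapping the integer in a list gives a one-element LongTensor (dtype torch.int64) of shape (1,).
```python
import torch
last_action = 2  # one of 0, 1, 2
# int() guards the type; the list wrapper makes a 1-element tensor holding the value.
action_tensor = torch.LongTensor([int(last_action)])
print(action_tensor)        # tensor([2])
print(action_tensor.dtype)  # torch.int64
```
The brackets matter: torch.LongTensor(2) without them is read as a size and creates a tensor of two uninitialized elements, not a tensor holding the value 2.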

11:00.030 --> 11:02.160
All right, and then, finally,

11:02.160 --> 11:04.320
the last element of the transition.

11:04.320 --> 11:06.780
And that's, of course, the last reward we got.

11:06.780 --> 11:09.900
So that's exactly the self.last_reward variable we created

11:09.900 --> 11:13.110
in the init function, where it was initialized to zero,

11:13.110 --> 11:17.220
but which will then carry the reward computed right here in the map code,

11:17.220 --> 11:18.990
either when we go into some sand,

11:18.990 --> 11:20.490
which is a negative reward,

11:20.490 --> 11:22.320
or if we get further away from the goal,

11:22.320 --> 11:24.150
that's again a negative reward.

11:24.150 --> 11:27.180
If we get closer to the goal, that's a positive reward.

11:27.180 --> 11:28.980
And the worst punishment,

11:28.980 --> 11:31.260
if we get too close to one edge of the map,

11:31.260 --> 11:34.260
well, that's a terrible negative reward, minus one,

11:34.260 --> 11:35.253
and that's all.
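
NOTE
A sketch of the reward scheme just described, as a hypothetical compute_reward function; apart from the -1 for getting too close to an edge, the numbers are placeholders chosen for illustration, not the values in the map code, and checking the edge first is an assumption.
```python
def compute_reward(on_sand, distance, last_distance, near_edge,
                   sand_penalty=-0.5, away_penalty=-0.2, closer_bonus=0.1):
    """Hypothetical reward rule; only the edge value (-1) comes from the tutorial."""
    if near_edge:
        return -1.0            # worst punishment: too close to an edge
    if on_sand:
        return sand_penalty    # driving on sand (placeholder value)
    if distance < last_distance:
        return closer_bonus    # moving toward the goal (placeholder value)
    return away_penalty        # moving away from the goal (placeholder value)
```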

11:36.210 --> 11:39.390
So let's add this last element of the transition,

11:39.390 --> 11:41.010
self.last_reward.

11:41.010 --> 11:42.870
So I'm copying this.

11:42.870 --> 11:44.040
Pasting it here.

11:44.040 --> 11:46.800
And now, we have to make another conversion,

11:46.800 --> 11:49.740
which will be, of course, exactly the same as this one.

11:49.740 --> 11:52.320
Only, since the reward is not an integer,

11:52.320 --> 11:53.760
but a float number,

11:53.760 --> 11:57.510
we will simply make a torch.Tensor conversion,

11:57.510 --> 11:58.680
but without the int.

11:58.680 --> 12:00.630
We will keep the brackets here because,

12:00.630 --> 12:03.360
first, we have to put the number into a list

12:03.360 --> 12:05.400
and then this list will go as input

12:05.400 --> 12:06.990
in the torch.Tensor class.

12:06.990 --> 12:09.090
But we don't have to make that int conversion

12:09.090 --> 12:11.490
because last_reward is a float number.

12:11.490 --> 12:16.490
So what we're gonna do is simply add torch.Tensor here.

12:16.800 --> 12:21.660
torch.Tensor, then parentheses, then square brackets,

12:21.660 --> 12:25.500
and we're gonna close the square brackets here,

12:25.500 --> 12:27.450
and we close the parentheses.

12:27.450 --> 12:28.830
There we go.
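
NOTE
The same pattern for the reward, assuming PyTorch is installed: torch.Tensor([...]) gives a one-element float32 tensor, with no int() conversion. Since float32 cannot represent -0.2 exactly, reading the value back gives a close approximation rather than the exact Python float.
```python
import torch
last_reward = -0.2  # illustrative value
reward_tensor = torch.Tensor([last_reward])  # one-element float tensor
print(reward_tensor.shape)  # torch.Size([1])
print(reward_tensor.dtype)  # torch.float32
```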

12:28.830 --> 12:31.080
So to summarize, with the new state

12:31.080 --> 12:33.150
that we just reached and the reward,

12:33.150 --> 12:35.520
we observe a new event, or transition,

12:35.520 --> 12:37.500
that we add to the memory.

12:37.500 --> 12:40.890
And this transition contains the last state, s_t,

12:40.890 --> 12:44.493
the new state, s_{t+1}, the last action played, a_t,

12:45.510 --> 12:47.970
and the last reward, r_t.

12:47.970 --> 12:51.390
All right, and now, we are good with our memory update.
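
NOTE
Putting the pieces together, here is a sketch of the transition pushed to memory with all four elements as tensors, assuming PyTorch is installed; a plain list's append stands in for self.memory.push, which likewise takes the whole transition as one tuple, and all values are illustrative.
```python
import torch
memory = []  # stand-in for the replay memory object
last_state = torch.Tensor([0.0, 0.0, 0.0, 0.3, -0.3]).unsqueeze(0)
new_signal = [0.1, 0.0, 0.0, 0.25, -0.25]
new_state = torch.Tensor(new_signal).float().unsqueeze(0)
last_action, last_reward = 1, -0.2
transition = (last_state,                            # s_t, shape (1, 5)
              new_state,                             # s_{t+1}, shape (1, 5)
              torch.LongTensor([int(last_action)]),  # a_t, shape (1,)
              torch.Tensor([last_reward]))           # r_t, shape (1,)
memory.append(transition)  # same role as self.memory.push(transition)
```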

12:51.390 --> 12:52.590
So let's have a quick break

12:52.590 --> 12:54.450
and we will take care of the next update

12:54.450 --> 12:55.950
in the next tutorial.

12:55.950 --> 12:57.543
Until then, enjoy AI.
