Teacher forcing for an LSTM network

Is there a way to implement teacher forcing for an LSTM network in MATLAB? Hopefully there is an option buried somewhere.

Answers (2)

Hi Philip,
This example shows how teacher forcing can be implemented with LSTMs in MATLAB: Sequence-to-Sequence Translation Using Attention.

1 Comment

Hi David,
Thank you for your help, but this is an attention model. I just need teacher forcing for a standard LSTM model.


David Willingham on 24 May 2022
Edited: David Willingham on 24 May 2022
Please see the attached example, trainLSTM_seq2seq. Is this what you were looking for, i.e. an example for a standard LSTM model?
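The attached file isn't reproduced here, but as a minimal sketch of how teacher forcing can be wired into a custom training loop for a plain LSTM (all names - TTrain, numClasses, numHiddenUnits and the hyperparameters - are placeholders, not taken from the attachment): at every time step the LSTM input is the one-hot ground truth from the previous step, never the network's own prediction.

numHiddenUnits = 128;
layers = [
    sequenceInputLayer(numClasses)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layerGraph(layers));

numEpochs = 70;
learnRate = 0.005;
trailingAvg = [];
trailingAvgSq = [];

for epoch = 1:numEpochs
    % Teacher forcing: shift the one-hot target sequence (C-by-B-by-T) by one
    % step so the input at time t is the true token at time t-1.
    X = dlarray(TTrain(:,:,1:end-1),'CBT');
    T = dlarray(TTrain(:,:,2:end),'CBT');

    [loss,gradients] = dlfeval(@modelLoss,net,X,T);
    [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);
end

function [loss,gradients] = modelLoss(net,X,T)
    % Forward pass on the teacher-forced inputs and cross-entropy loss
    % against the one-step-ahead targets.
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end

At inference time the same dlnetwork is run closed-loop, feeding each prediction back in as the next input, which is where teacher-forced training is supposed to pay off.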

5 Comments

Hi David,
Thank you for your help. If you look at the example for a plain vanilla LSTM network: once the network is defined, the training options are specified:
maxEpochs = 70;
miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');
and training begins:
net = trainNetwork(XTrain,YTrain,layers,options);
In your example, it looks as though one has to yank all the complexity out of the trainNetwork function and write a training loop manually, which is not ideal. Is it not possible for MathWorks to include a training option that allows teacher forcing? I am trying to avoid writing low-level ad hoc code as much as possible.
Hi Philip,
The above example uses the custom training loop functionality in Deep Learning Toolbox. We currently don't have support for teacher forcing as an option in trainNetwork.
Hi David - I am looking at this in more detail. Would open-loop forecasting (https://uk.mathworks.com/help/deeplearning/ug/time-series-forecasting-using-deep-learning.html) not be the same as teacher forcing? Could we not use this approach for training instead?
If your problem is a time-series problem, then it possibly could be. Have you tried adapting this example to meet your use case? If so, is it getting the result you expect?
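As a rough sketch of that idea for a single-channel autoregressive series (data, layers, options and numSteps are placeholder names, following the structure of the linked example): training on one-step-shifted sequences means the true previous value is always the network's input during training - the open-loop / teacher-forcing condition - and predictions are only fed back closed-loop at forecast time.

% Training: the input at time t is the observed value at time t and the
% target is the observed value at time t+1, so the ground truth always
% drives the input.
XTrain = data(1:end-1);
YTrain = data(2:end);
net = trainNetwork(XTrain,YTrain,layers,options);

% Forecasting (closed loop): feed each prediction back in as the next input.
net = resetState(net);
[net,lastPred] = predictAndUpdateState(net,XTrain);
forecast = zeros(1,numSteps);
for t = 1:numSteps
    [net,lastPred] = predictAndUpdateState(net,lastPred(end));
    forecast(t) = lastPred;
end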
Hi David,
Thank you for your email. I just implemented the basic code without teacher forcing yesterday, and during training it gives 30-40% accuracy. The author seems to suggest that the model should perform much better. There are two things missing from my model currently:
1) I have not embedded the tokens (page 27), and
2) no teacher forcing.
The problem with 1) is that these tokens are categorical, so after embedding (using the embed function) I don't know how to retrieve the original tokens from the embedded data. I also presume that the data has to be discretized, but there is no mention of this in the thesis, so I am not sure what is going on, to be honest. Perhaps you can shed some light on this (see the sketch after this comment).
2) I was going to try to use what you sent to implement teacher forcing, but I thought the open-loop approach is a much neater way to implement the solution in this case. I am also dubious about teacher forcing - I suspect that all it's going to do is overfit the data.
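On the embedding point, a small sketch, assuming the tokens live in a categorical array (tokens, scores and embeddingWeights are placeholder names). The embed function consumes positive integer indices, so the usual pattern is to keep the category list as a lookup table and map the network's predicted class indices back through it; the embedding itself never has to be inverted, and since the tokens are categorical they are already discrete values.

vocab = categories(tokens);   % index -> token lookup table, kept for decoding
idx   = double(tokens);       % token -> positive integer index; these indices
                              % are what the embed function (or an embedding
                              % layer) consumes

% ... pass idx through the embedding and the LSTM as usual ...

% The network output is a softmax distribution over the vocabulary, so tokens
% are recovered from the most likely class index, not from the embedded vectors.
[~,predIdx] = max(scores,[],1);
predictedTokens = categorical(vocab(predIdx));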

