Is it possible to use variable-length sequences with sequenceInputLayer in a custom training loop, i.e. as a dlarray?

I would like to train an LSTM network on sequences of differing lengths. I have followed the example "nnet/SequenceClassificationUsing1DConvolutionsExample", but for my own application I need to implement such a network as a dlnetwork and use a custom training loop. (It will be part of a generative model.)
% Edited example:
[XTrain,TTrain] = japaneseVowelsTrainData;
[XValidation,TValidation] = japaneseVowelsTestData;
numFeatures = size(XTrain{1},1);
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
filterSize = 3;
numFilters = 32;
layers = [ ...
sequenceInputLayer(numFeatures)
convolution1dLayer(filterSize,numFilters,Padding="causal")
reluLayer
layerNormalizationLayer
convolution1dLayer(filterSize,2*numFilters,Padding="causal")
reluLayer
layerNormalizationLayer
globalAveragePooling1dLayer
fullyConnectedLayer(numClasses)
softmaxLayer ]; % classificationLayer removed: dlnetwork does not accept output layers
% my code: create the dlnetwork
dlnetJap = dlnetwork( layers );
As I understand it from the documentation, a dlarray has fixed dimensions. That is not a problem for trainNetwork, because the example code stores the variable-length sequences in a cell array, e.g.
net = trainNetwork(XTrain,TTrain,layers,options);
However, a dlnetwork does not accept a cell array as input. (I use analyzeNetwork to demonstrate this rather than showing my whole code. I get the same error when I call forward() inside dlfeval.)
analyzeNetwork( dlnetJap, XTrain );
Error using analyzeNetwork (line 56)
Invalid argument at position 2. Example network inputs must be formatted dlarray objects.
Nor can I convert the cell array itself to a dlarray:
analyzeNetwork( dlnetJap, dlarray( XTrain, 'CB' ) ); % or any other format
Error using dlarray (line 151)
dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
I also tried converting the array within each cell to a dlarray, but to no avail.
XTrainDl = cell( size(XTrain) );
for i = 1:length(XTrainDl)
    XTrainDl{i} = dlarray( XTrain{i}, 'CB' );
end
analyzeNetwork( dlnetJap, XTrainDl );
Error using analyzeNetwork (line 56)
Invalid argument at position 2. Example network inputs must be formatted dlarray objects.
I can't see a way around this problem. I have already built the generative model with fully connected layers rather than LSTM. I suppose I could use an LSTM with a fixed-length input, but my time series differ in length. I would rather not time-normalise them to a standard length, as that distorts the data. Nor can I pad them at one end, because they come from cyclical data. An LSTM seems ideal for its ability to handle sequences of variable length, but how can I get that to work in a custom training loop?

Accepted Answer

Joss Knight on 22 Jan 2022
dlnetwork objects do not take your entire dataset as input; they expect to receive a single mini-batch at a time. You need to loop over your dataset and read the data one mini-batch at a time. Iterator objects such as minibatchqueue and arrayDatastore can help you do this.
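A minimal sketch of that pattern for the Japanese Vowels data, assuming the example helper japaneseVowelsTrainData is on the path as in the question. The helper function preprocessMiniBatch and the mini-batch size of 27 are illustrative choices, not part of the answer. arrayDatastore(...,OutputType="same") returns each stored 12-by-T matrix as-is, and padsequences pads only to the longest sequence within each mini-batch, so the dataset itself can stay a cell array of different-length sequences:

```matlab
% Assumes the example helper japaneseVowelsTrainData is on the path.
[XTrain,TTrain] = japaneseVowelsTrainData;

% OutputType="same" makes each read return the stored matrix (in a 1-by-1 cell)
% instead of wrapping it in a second cell.
adsX = arrayDatastore(XTrain,OutputType="same");
adsT = arrayDatastore(TTrain);
cds = combine(adsX,adsT);

miniBatchSize = 27;   % illustrative value: 270 training sequences -> 10 mini-batches
mbq = minibatchqueue(cds, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["CTB" ""]);   % predictors: channel x time x batch; targets unformatted

[X,T] = next(mbq);
disp(dims(X))    % CBT
disp(size(X))    % 12 x 27 x (longest sequence in this mini-batch)
disp(size(T))    % 9 x 27 one-hot targets

function [X,T] = preprocessMiniBatch(dataX,dataT)
    % Pad every sequence in the mini-batch to the longest one along dimension 2 (time).
    X = padsequences(dataX,2);
    % Concatenate the categorical labels and one-hot encode them along dimension 1.
    T = onehotencode(cat(2,dataT{:}),1);
end
```

A formatted dlarray stores its dimensions in "CBT" order, which is why size(X) reports channels, then batch, then time, even though the format was specified as "CTB".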
The purpose of dlnetwork is to give you much greater control over how to iterate over your dataset and flexibility over how your network is built, but as a result you need to write more of the code yourself.
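As a sketch of the extra code this involves, here is one possible custom training loop for the network in the question. It assumes japaneseVowelsTrainData is on the path; the helper functions modelLoss and preprocessMiniBatch, the Adam optimizer, and the values of numEpochs, learnRate and the mini-batch size are illustrative choices rather than anything prescribed by the answer:

```matlab
% Assumes the example helper japaneseVowelsTrainData is on the path.
[XTrain,TTrain] = japaneseVowelsTrainData;
numFeatures = size(XTrain{1},1);
numClasses = numel(categories(TTrain));

layers = [ ...
    sequenceInputLayer(numFeatures)
    convolution1dLayer(3,32,Padding="causal")
    reluLayer
    layerNormalizationLayer
    convolution1dLayer(3,64,Padding="causal")
    reluLayer
    layerNormalizationLayer
    globalAveragePooling1dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);   % sequenceInputLayer supplies the input size

cds = combine(arrayDatastore(XTrain,OutputType="same"),arrayDatastore(TTrain));
mbq = minibatchqueue(cds, ...
    MiniBatchSize=27, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["CTB" ""]);

numEpochs = 15;      % illustrative
learnRate = 0.01;    % illustrative
avgGrad = [];
avgSqGrad = [];
iteration = 0;

for epoch = 1:numEpochs
    shuffle(mbq);                        % new observation order each epoch
    while hasdata(mbq)
        iteration = iteration + 1;
        [X,T] = next(mbq);               % one padded mini-batch
        % dlfeval traces the forward pass so that dlgradient can differentiate it.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);
        [net,avgGrad,avgSqGrad] = adamupdate(net,gradients, ...
            avgGrad,avgSqGrad,iteration,learnRate);
    end
    fprintf("Epoch %d, last mini-batch loss %.4f\n",epoch,gather(extractdata(loss)));
end

function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);                  % class probabilities for each sequence
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end

function [X,T] = preprocessMiniBatch(dataX,dataT)
    X = padsequences(dataX,2);           % pad along the time dimension
    T = onehotencode(cat(2,dataT{:}),1); % numClasses-by-miniBatchSize
end
```

This network has no state parameters (layer normalization keeps none), so net.State is not updated here; a network with batch normalization would also need the state returned by forward. Note that the zero-padded time steps are included in the global average, which is one reason to keep padding small.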
dlnetworks are not restricted to fixed-size inputs; they sometimes need example inputs in order to initialize, but that does not mean the input size cannot change once they are initialized. In your code, example input data is not needed for dlnetwork or analyzeNetwork, because your sequenceInputLayer already provides that information.
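A short sketch illustrating this point: the network below is initialized from sequenceInputLayer(12) alone, and the same initialized network then accepts inputs with different numbers of time steps and different batch sizes. The random data is a stand-in for real sequences, used only for illustration:

```matlab
layers = [ ...
    sequenceInputLayer(12)               % 12 features, sequence length left open
    convolution1dLayer(3,32,Padding="causal")
    reluLayer
    layerNormalizationLayer
    convolution1dLayer(3,64,Padding="causal")
    reluLayer
    layerNormalizationLayer
    globalAveragePooling1dLayer
    fullyConnectedLayer(9)
    softmaxLayer];
net = dlnetwork(layers);                 % no example input needed

% Random stand-in data: one sequence of 10 steps, then four sequences of 25 steps.
X1 = dlarray(rand(12,10,1,"single"),"CTB");
X2 = dlarray(rand(12,25,4,"single"),"CTB");
Y1 = predict(net,X1);
Y2 = predict(net,X2);
disp(size(Y1))                           % 9 class scores for 1 sequence
disp(size(Y2))                           % 9 class scores for each of 4 sequences

analyzeNetwork(net)                      % opens Network Analyzer; no example input needed
```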
Mark White on 25 Jan 2022
Thank you. That makes sense. I'd overlooked the OutputType setting.
Having looked at the examples, I see that I will have to pad the data, but sorting the series by length mitigates my concern.
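One way to do the sorting, sketched under the assumption that japaneseVowelsTrainData is on the path and a mini-batch size of 27 is used (both illustrative). The hypothetical helper paddingFraction measures how much of the padded data is padding, so the effect of sorting can be checked directly:

```matlab
% Assumes the example helper japaneseVowelsTrainData is on the path.
[XTrain,TTrain] = japaneseVowelsTrainData;
miniBatchSize = 27;                                   % illustrative value

% Sort sequences by length (number of time steps = number of columns).
lengthsUnsorted = cellfun(@(x) size(x,2),XTrain);
[seqLengths,idx] = sort(lengthsUnsorted);
XSorted = XTrain(idx);
TSorted = TTrain(idx);

% Compare how much of the padded mini-batches is padding.
fracUnsorted = paddingFraction(lengthsUnsorted,miniBatchSize);
fracSorted = paddingFraction(seqLengths,miniBatchSize);
fprintf("Padded fraction: unsorted %.3f, sorted %.3f\n",fracUnsorted,fracSorted);

% Queue over the sorted data. Do not call shuffle(mbq) on this queue:
% shuffling the observations would undo the length ordering.
cds = combine(arrayDatastore(XSorted,OutputType="same"),arrayDatastore(TSorted));
mbq = minibatchqueue(cds, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["CTB" ""]);

function frac = paddingFraction(lengths,miniBatchSize)
    % Fraction of time steps that are padding when consecutive groups of
    % miniBatchSize sequences are each padded to their longest member.
    numObs = numel(lengths);
    totalSteps = 0;
    for i = 1:miniBatchSize:numObs
        batch = lengths(i:min(i+miniBatchSize-1,numObs));
        totalSteps = totalSteps + numel(batch)*max(batch);
    end
    frac = 1 - sum(lengths)/totalSteps;
end

function [X,T] = preprocessMiniBatch(dataX,dataT)
    X = padsequences(dataX,2);                        % pad along time
    T = onehotencode(cat(2,dataT{:}),1);              % numClasses-by-miniBatchSize
end
```

The trade-off is that the mini-batches always arrive in the same length order, so the network sees short sequences first in every epoch.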