How do I create a custom layer with 2 inputs?

Mark White on 3 Feb 2022
Answered: MR on 26 Oct 2022
I have defined a custom layer endStateLayer to take two inputs, but when I try to create the dlnetwork I get an error saying the second input is not connected.
endStateLayer is designed to take the sequence output from an LSTM layer together with the sequence lengths, so that the end state of each observation can be identified. (The sequences are padded at the end.) Another custom layer, sequenceLengthLayer, determines the sequence lengths.
This is my network:
[XTrain,YTrain] = japaneseVowelsTrainData;
numObservations = length(XTrain);
embedDim = size(XTrain{1},1);
latentDim = 9;
numHiddenUnits = 100;
filterSize = 3;
numFilters = 32;
% define encoder network
layersEnc = [
sequenceInputLayer( embedDim, 'Name', 'in' )
lstmLayer( numHiddenUnits, 'OutputMode', 'sequence' )
endStateLayer( 'Name', 'endstate' )
fullyConnectedLayer( latentDim ) ];
lgraphEnc = layerGraph( layersEnc );
lgraphEnc = addLayers( lgraphEnc, ...
sequenceLengthLayer( 0, 'Name', 'seqlen' ) ); % padding indicator is 0
lgraphEnc = connectLayers( lgraphEnc, ...
'in', 'seqlen' );
lgraphEnc = connectLayers( lgraphEnc, ...
'seqlen', 'endstate/len' );
dlnetEnc = dlnetwork( layersEnc );
According to analyzeNetwork( lgraphEnc ), the graph appears to be fully connected. The errors analyzeNetwork reports are to be expected, because a layer graph intended for a dlnetwork has no output layer.
However, I get the following error when I try to create the dlnetwork:
dlnetEnc = dlnetwork( layersEnc );
Error using dlnetwork/initialize (line 481)
Invalid network.
Error in dlnetwork (line 218)
net = initialize(net, dlX{:});
Error in lstmTest (line 26)
dlnetEnc = dlnetwork( layersEnc );
Caused by:
Example inputs: Incorrect number of example network inputs. 0 example network inputs provided but network has 2 inputs including 1 unconnected layer inputs.
Layer 'endstate': Unconnected input. Each input must be connected to input data or to the output of another layer.
Detected unconnected inputs:
input 'len'
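The same error can be reproduced with a built-in two-input layer, which helps separate a wiring problem from a problem in the custom layers. In a layer array the layers are connected in sequence, so a multi-input layer only receives the previous layer's output on its first input; any other input has to be connected with connectLayers on a layerGraph, and the dlnetwork then has to be built from that graph. This is a minimal sketch, assuming Deep Learning Toolbox; the layer names and sizes are illustrative.

```matlab
% Illustrative network: 'add' has two inputs, 'add/in1' and 'add/in2'
layers = [
    featureInputLayer( 3, 'Name', 'in' )
    fullyConnectedLayer( 3, 'Name', 'fc' )
    additionLayer( 2, 'Name', 'add' ) ];

% Building from the array leaves 'add/in2' unconnected, so this errors
arrayFailed = false;
try
    dlnetwork( layers );
catch err
    arrayFailed = true;
    disp( err.message )
end

% Connect the second input in a layer graph, then build from the graph
lgraph = layerGraph( layers );
lgraph = connectLayers( lgraph, 'in', 'add/in2' );
disp( lgraph.Connections )   % lists the extra in -> add/in2 connection
net = dlnetwork( lgraph );
```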
This is the sequenceLengthLayer definition. I don't think the particulars of the predict function matter here, but I'm showing it for completeness. (I haven't been able to check it fully because I can't set up the network.)
classdef sequenceLengthLayer < nnet.layer.Layer & ...
nnet.layer.Formattable
properties
% (Optional) Layer properties.
PaddingIndicator
end
properties (Learnable)
% Layer learnable parameters.
end
methods
function layer = sequenceLengthLayer( padIndicator, NameValueArgs )
% layer = sequenceLengthLayer( padIndicator )
% creates a sequenceLengthLayer object that determines
% the length of the input sequence
% Parse input arguments.
arguments
padIndicator = 0;
NameValueArgs.Name = '';
end
name = NameValueArgs.Name;
% Set layer name.
layer.Name = name;
% Set layer description.
layer.Description = "Sequence length layer for padding " ...
+ join(string(padIndicator));
% Set layer type.
layer.Type = "Sequence Length";
% set the padding indicator.
layer.PaddingIndicator = padIndicator;
end
function L = predict( layer, X )
% Forward input data through the layer at prediction time and
% output the result.
%
% Inputs:
% layer - Layer to forward propagate through
% X - Input data, specified as a formatted dlarray
% with a 'T' and 'C' dimension
% Outputs:
% L - Sequence lengths, returned as
% a formatted dlarray with format 'CB'.
maxLength = size( X, 3 );
miniBatchSize = size( X, 2 );
% find where X contains padding using the first channel;
% reshape the 1-by-B-by-T mask to B-by-T
isPadding = gather( extractdata( X(1,:,:) ) ) == layer.PaddingIndicator;
isPadding = reshape( isPadding, miniBatchSize, maxLength );
L = zeros( 1, miniBatchSize );
for i = 1:miniBatchSize
% the length is the last time step that is not padding,
% so padding-valued entries inside the data are ignored
lastData = find( ~isPadding(i,:), 1, 'last' );
if isempty( lastData )
% all padding - fall back to the first time step
lastData = 1;
end
L(i) = lastData;
end
L = dlarray( L, 'CB' );
end
end
end
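The length calculation can also be done without a loop. This sketch swaps in a cumprod-based count of the padding steps at the end of each sequence, which is not the method used in the layer above. It assumes channels-by-batch-by-time data, padding appended at the end, and the padding value checked in the first channel; the example values are made up.

```matlab
padIndicator = 0;

% Illustrative data: 2 channels, 3 observations, 6 time steps,
% with padding appended at the end of observations 1 and 3
X = ones( 2, 3, 6 );
X(:, 1, 5:6) = padIndicator;   % observation 1: length 4
X(1, 1, 2)   = padIndicator;   % a padding-valued entry inside real data
X(:, 3, 3:6) = padIndicator;   % observation 3: length 2

[~, miniBatchSize, maxLength] = size( X );

% Padding mask from the first channel, reshaped from 1-by-B-by-T to B-by-T
isPadding = reshape( X(1,:,:) == padIndicator, miniBatchSize, maxLength );

% Count the trailing padding steps: flip the time axis, then cumprod
% stays 1 only until the first non-padding step is reached
numTrailing = sum( cumprod( double( fliplr( isPadding ) ), 2 ), 2 );

% Sequence length per observation as a 1-by-B row; clamp to at least 1
% so an all-padding observation still gives a valid time index
L = max( maxLength - numTrailing.', 1 );   % L is [4 6 2]
```

The padding-valued entry at time step 2 of observation 1 does not shorten that sequence, because only the padding that runs through to the last time step is counted.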
This is the endStateLayer definition, where I think the problem lies: it is the layer reported as not receiving its second input. Is my constructor function correct?
classdef endStateLayer < nnet.layer.Layer & ...
nnet.layer.Formattable
properties
% (Optional) Layer properties.
end
properties (Learnable)
% Layer learnable parameters.
end
methods
function layer = endStateLayer( NameValueArgs )
% layer = endStateLayer( 'Name', name )
% creates an endStateLayer object that extracts the
% state of a sequence, X, at a specified time step, L
% Parse input arguments.
arguments
NameValueArgs.Name = '';
end
name = NameValueArgs.Name;
% Set layer name.
layer.Name = name;
% Set layer description.
layer.Description = "End state layer ";
% Set layer type.
layer.Type = "End State";
% set the inputs.
layer.NumInputs = 2;
layer.InputNames = { 'in', 'len' };
end
function Z = predict( layer, X, L )
% Forward input data through the layer at prediction time and
% output the result.
%
% Inputs:
% layer - Layer to forward propagate through
% X - Input sequence data, specified as a
% formatted dlarray with a 'T' and 'C' dims
% L - Input sequence length data.
% Outputs:
% Z - Output of layer forward function returned as
% a formatted dlarray with format 'CB'.
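% remove the labels so X is indexed in its 'CBT' storage order
X = stripdims(X);
% convert the lengths to plain numeric indices
idx = gather( extractdata(L) );
miniBatchSize = size(X, 2);
Z = zeros( size(X,1), miniBatchSize, 'like', X);
for i = 1:miniBatchSize
Z(:,i) = X(:, i, idx(i));
end
Z = dlarray(Z, 'CB');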
end
end
end
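The end-state lookup in predict can be checked outside the network on plain arrays, and it can also be written without a loop by converting (channel, observation, time) subscripts to linear indices with sub2ind. This is a sketch assuming the channels-by-batch-by-time ('CBT') order of an LSTM layer's sequence output; the sizes and lengths below are made up.

```matlab
% Illustrative sizes: 4 channels, 3 observations, 5 time steps
numChannels = 4;
miniBatchSize = 3;
maxLength = 5;
X = reshape( 1:numChannels*miniBatchSize*maxLength, ...
    numChannels, miniBatchSize, maxLength );

% Hypothetical unpadded lengths of the three observations
len = [5 2 3];

% Loop form: pick time step len(i) for observation i
Z = zeros( numChannels, miniBatchSize );
for i = 1:miniBatchSize
    Z(:,i) = X(:, i, len(i));
end

% Vectorized form: build channel, observation and time subscripts
% for every output element and gather all end states at once
[c, b] = ndgrid( 1:numChannels, 1:miniBatchSize );
t = repmat( len, numChannels, 1 );
Zvec = X( sub2ind( size(X), c, b, t ) );
```

Both forms give a channels-by-batch matrix, ready to be labeled 'CB'.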
I can't see what I have done wrong. I have reviewed the examples online. What have I missed?

Answers (1)

MR on 26 Oct 2022
Hi Mark,
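I ran your code in MATLAB R2022b and was able to reproduce your error message. Then I ran analyzeNetwork(lgraphEnc,"TargetUsage","dlnetwork") and didn't get any error message. In a nutshell, I think it is just a typo: the 'seqlen' layer and its connection to 'endstate/len' were added to the layer graph lgraphEnc, but dlnetwork was called on the original layer array layersEnc, in which that input is still unconnected. Write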
dlnetEnc = dlnetwork(lgraphEnc); instead of dlnetEnc = dlnetwork( layersEnc );
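With that change, the network construction from the question looks something like the following. This sketch assumes Deep Learning Toolbox, the example data function japaneseVowelsTrainData used in the question, and that the classdef files for sequenceLengthLayer and endStateLayer are on the MATLAB path; whether initialization then succeeds also depends on the predict methods of those two layers. The TargetUsage option is the one used in the answer.

```matlab
[XTrain, ~] = japaneseVowelsTrainData;
embedDim = size( XTrain{1}, 1 );
latentDim = 9;
numHiddenUnits = 100;

layersEnc = [
    sequenceInputLayer( embedDim, 'Name', 'in' )
    lstmLayer( numHiddenUnits, 'OutputMode', 'sequence' )
    endStateLayer( 'Name', 'endstate' )   % only 'endstate/in' is connected here
    fullyConnectedLayer( latentDim ) ];

% Add the length branch and connect it to the second input
lgraphEnc = layerGraph( layersEnc );
lgraphEnc = addLayers( lgraphEnc, sequenceLengthLayer( 0, 'Name', 'seqlen' ) );
lgraphEnc = connectLayers( lgraphEnc, 'in', 'seqlen' );
lgraphEnc = connectLayers( lgraphEnc, 'seqlen', 'endstate/len' );

% Check the graph against dlnetwork rules (no output layer required)
analyzeNetwork( lgraphEnc, 'TargetUsage', 'dlnetwork' );

% Build the dlnetwork from the graph, which holds the extra connection
dlnetEnc = dlnetwork( lgraphEnc );
```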

Release

R2021b