Flatten and unflatten a neural network

I've been working on optimizing a neural network. I cannot use the built-in routines as such, since the ANN is embedded in a constrained optimization. It would be nice to have a pair of functions: one that takes the ANN parameters (weights and biases) and flattens them into a vector, and one that takes such a vector and unflattens it back into a network. I've done this manually for an ANN with one hidden layer, but I want to explore deeper networks, so a general function pair would help.

Accepted Answer

Greg Heath on 2 Jul 2017
help getwb
doc getwb
help setwb
doc setwb
Hope this helps.
Thank you for formally accepting my answer
Greg
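A minimal round-trip sketch of the getwb/setwb pair, assuming the Deep Learning Toolbox and a network whose input and output sizes have already been set (here with configure). The hidden-layer sizes [10 5] and the random data are illustrative only.

```matlab
% Illustrative data: 3 input features, 2 targets, 50 samples.
x = rand(3, 50);
t = rand(2, 50);

% Two hidden layers; configure sets input/output sizes and initializes weights.
net = configure(feedforwardnet([10 5]), x, t);

% Flatten: every weight and bias of the network as one column vector.
wb = getwb(net);

% Change the vector (for example, a step proposed by an optimizer) ...
wbNew = wb + 0.01 * randn(size(wb));

% ... and unflatten it into a copy of the network with the same architecture.
net2 = setwb(net, wbNew);

% net is untouched; net2 carries the new parameters.
assert(isequal(getwb(net2), wbNew));
```

Because setwb writes the parameters back in the same order that getwb reads them, the two functions already act as a flatten/unflatten pair without any per-layer bookkeeping.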

More Answers (1)

Walter Roberson on 1 Jul 2017
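% thiscell(:) stacks the pieces vertically, so arrays with different numbers of elements can be concatenated
flatten = @(thiscell) cell2mat( cellfun(@(M) M(:), thiscell(:), 'uniform', 0) );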
To unflatten, you would use mat2cell() and then reshape() each piece to the required size. That requires information about the desired sizes to be passed in.
sizevec = @(Sizes) cellfun(@(thiscell) prod(thiscell), Sizes(:).')
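% mat2cell splits a column vector; 'uniform', 0 is needed because reshape returns arrays, not scalars
unflatten = @(thisvector, Sizes) reshape( cellfun(@(thiscell, thissize) reshape(thiscell, thissize), mat2cell(thisvector(:), sizevec(Sizes), 1), Sizes(:), 'uniform', 0), size(Sizes) );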
The above is not tested.
Input to flatten is a cell array of numeric arrays.
Input to unflatten is a numeric vector and a cell array of size vectors, one per piece, describing how to break the vector up.
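A self-contained sketch of the cell-array approach, checked with a round trip. Three details matter: the cell array is turned into a column before cell2mat so pieces of different lengths stack, mat2cell splits a column rather than a row, and cellfun is told that reshape returns non-scalar output. The weight matrices W1, b1, W2, b2 are hypothetical stand-ins for a small two-layer network, and the sizes are taken from them with cellfun(@size, ...).

```matlab
% Flatten a cell array of numeric arrays into one column vector.
flatten = @(pieces) cell2mat(cellfun(@(M) M(:), pieces(:), 'UniformOutput', false));

% Number of elements implied by each size vector, as a row vector.
sizevec = @(Sizes) cellfun(@prod, Sizes(:).');

% Split a vector into pieces, reshape each one, and return a cell shaped like Sizes.
unflatten = @(v, Sizes) reshape( ...
    cellfun(@(chunk, sz) reshape(chunk, sz), ...
            mat2cell(v(:), sizevec(Sizes), 1), Sizes(:), 'UniformOutput', false), ...
    size(Sizes));

% Example: parameters of a small two-layer network held as plain matrices.
W1 = rand(4, 3);  b1 = rand(4, 1);
W2 = rand(2, 4);  b2 = rand(2, 1);
params = {W1, b1, W2, b2};

v = flatten(params);                                     % 26x1 column vector
Sizes = cellfun(@size, params, 'UniformOutput', false);  % {[4 3], [4 1], [2 4], [2 1]}
back = unflatten(v, Sizes);                              % 1x4 cell, same sizes as params

assert(isequal(back, params));
```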
  2 Comments
Charles Haas on 1 Jul 2017
That's a start. Ideally, though, I'd like this sort of pair:
flatvec = flatten(ANN)
unflat = unflatten(flatvec, ANNtemplate)
where ANNtemplate has the structure of the neural network object to be returned by unflatten. These functions would need to determine the vector sizes from the network object itself.
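The requested signatures can be sketched as thin wrappers around getwb and setwb, which read the sizes from the network object themselves. This assumes the Deep Learning Toolbox plus, for the constrained fit, the Optimization Toolbox (fmincon). The names flattenNet and unflattenNet, the bound of 5 in magnitude on every parameter, and the iteration cap are hypothetical choices for illustration.

```matlab
% Hypothetical wrappers with the requested signatures.
flattenNet   = @(ANN) getwb(ANN);
unflattenNet = @(flatvec, ANNtemplate) setwb(ANNtemplate, flatvec);

% Illustrative data and a network with two hidden layers.
x = rand(3, 40);
t = rand(1, 40);
ANNtemplate = configure(feedforwardnet([6 4]), x, t);

% Objective: mean squared error of the network rebuilt from the vector wb.
objective = @(wb) mean((sim(unflattenNet(wb, ANNtemplate), x) - t).^2);

% Assumed constraint: every weight and bias stays in [-5, 5].
wb0 = flattenNet(ANNtemplate);
lb  = -5 * ones(size(wb0));
ub  =  5 * ones(size(wb0));

opts  = optimoptions('fmincon', 'Display', 'off', 'MaxIterations', 10);
wbOpt = fmincon(objective, wb0, [], [], [], [], lb, ub, [], opts);

% Rebuild a network object from the optimized vector.
netOpt = unflattenNet(wbOpt, ANNtemplate);
```

Linear or nonlinear constraints on the parameters can then be written directly in terms of the flat vector, whatever the depth of the network.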
Walter Roberson on 1 Jul 2017
Well, you could certainly enhance the code I posted to do that. It would probably be easiest to write a pair of real functions rather than trying to pack everything into anonymous functions, which tends to compromise readability.
Consider constructing a cell array of strings holding the field names to extract from the net object. One arrayfun call over that list can then extract the required content into a cell array, which can be flattened. Likewise, you could pass the existing net object in to unflatten, run through it to extract the sizes for reconstruction, and then, once the cell array of reconstructed data has been created, loop through and assign the various properties.
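A sketch of this property-extraction idea, using the network's IW (input weights), LW (layer weights) and b (biases) properties, which are cell arrays whose unconnected entries are empty. It is written as straight-line script so it can be pasted into the command window; wrapping each half in its own function file would give the flatten(ANN) / unflatten(flatvec, ANNtemplate) pair. It assumes the Deep Learning Toolbox. Its parameter order (all input weights, then layer weights, then biases) is its own choice and not necessarily the order getwb uses, so vectors from the two approaches should not be mixed.

```matlab
% Template network (illustrative sizes); configure fixes every matrix size.
x = rand(3, 30);
t = rand(2, 30);
template = configure(feedforwardnet([8 4]), x, t);

% ---- Flatten: collect IW, LW and b into one column of cells, then stack. ----
IW = template.IW;  LW = template.LW;  b = template.b;
parts   = [IW(:); LW(:); b(:)];                % empty cells mean "no connection"
cols    = cellfun(@(M) M(:), parts, 'UniformOutput', false);
flatvec = vertcat(cols{:});

% ---- Unflatten: walk the template's sizes and assign pieces back. ----
newvec = 2 * flatvec;                          % any vector of the right length
assert(numel(newvec) == sum(cellfun(@numel, parts)), ...
       'Vector length does not match the template.');

net = template;                                % network objects are copied by value
nIW = numel(IW);  nLW = numel(LW);
offset = 0;
for k = 1:numel(parts)
    n = numel(parts{k});
    if n == 0, continue; end                   % skip unconnected weights
    M = reshape(newvec(offset + (1:n)), size(parts{k}));
    offset = offset + n;
    if k <= nIW
        [i, j] = ind2sub(size(IW), k);
        net.IW{i, j} = M;
    elseif k <= nIW + nLW
        [i, j] = ind2sub(size(LW), k - nIW);
        net.LW{i, j} = M;
    else
        net.b{k - nIW - nLW} = M;
    end
end
```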
