Training a multitask network using custom training loops with dlnetwork()

I want to train a multitask network, i.e. a neural network that has one input and produces two outputs from separate output layers, connected as shown in the image below. Each output should have its own loss function: OUTPUT1 is the input to LOSS1 and OUTPUT2 is the input to LOSS2. The total loss used for backpropagation should be TOTAL_LOSS = LOSS1 + LOSS2.
As the image below shows, warnings are given for the final layers. This is because I have not specified loss functions when creating these layers, so they are not output layers. I left the loss functions out because the only example I have found of training a network with multiple outputs and a combined loss function (https://se.mathworks.com/help/deeplearning/ug/train-network-with-multiple-outputs.html) does not use the dlnetwork() function to create the network. Instead, it uses a model function in which all operations are carried out explicitly. The problem is that my network is a modified U-Net, and to write a model function for it I would have to implement every layer function, for example dropout, from scratch. This would be very tedious.
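For reference, here is a minimal sketch of a two-headed network built as a dlnetwork without any output layers, assuming Deep Learning Toolbox R2019b or later. The small convolutional trunk stands in for the modified U-Net, and the layer names (output1, output2, fc2, drop), sizes, and head types are hypothetical. As I understand the documentation, a dlnetwork takes its outputs from the layers whose outputs are left unconnected, and predict/forward evaluate every built-in layer, including dropout, so nothing has to be reimplemented by hand.

```matlab
% Shared trunk: a small hypothetical stand-in for the modified U-Net
trunk = [
    imageInputLayer([32 32 1], 'Normalization', 'none', 'Name', 'in')
    convolution2dLayer(3, 8, 'Padding', 'same', 'Name', 'conv')
    reluLayer('Name', 'relu')
    dropoutLayer(0.5, 'Name', 'drop')];
lgraph = layerGraph(trunk);

% Two heads with no output layers (no loss attached). Their unconnected
% outputs become the network outputs OUTPUT1 and OUTPUT2.
lgraph = addLayers(lgraph, convolution2dLayer(1, 1, 'Name', 'output1'));
lgraph = addLayers(lgraph, [
    fullyConnectedLayer(3, 'Name', 'fc2')
    softmaxLayer('Name', 'output2')]);
lgraph = connectLayers(lgraph, 'drop', 'output1');
lgraph = connectLayers(lgraph, 'drop', 'fc2');

net = dlnetwork(lgraph);
disp(net.OutputNames)

% A batch of 4 random images, formatted as spatial x spatial x channel x batch
X = dlarray(rand(32, 32, 1, 4, 'single'), 'SSCB');

% predict runs in inference mode (dropout disabled), so repeated calls agree
[Y1a, Y2a] = predict(net, X, 'Outputs', {'output1', 'output2'});
[Y1b, Y2b] = predict(net, X, 'Outputs', {'output1', 'output2'});

% forward runs in training mode (dropout active); this is the call to use
% inside a loss function evaluated with dlfeval
[Y1f, Y2f] = forward(net, X, 'Outputs', {'output1', 'output2'});
```

Passing the 'Outputs' option explicitly makes the order of the returned outputs match the order of the listed layer names, rather than relying on the order stored in net.OutputNames.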
Is there any way to train a network with multiple outputs and a combined loss function, as described above, using the dlnetwork() function without having to write a model function from scratch?
Note: a few months ago I asked a similar but less direct question about how to train a multitask network in MATLAB, and received no answer:
I then switched to Keras and Python for this problem, with great results. Now I would like to know whether similar results can be obtained in MATLAB.
Sven on 3 May 2023
Julius, did you ever succeed in what you were trying to do? We are in the same boat: we were able to create a network with two separate outputs in Python and would love to do the same in MATLAB, but we keep hitting the same brick wall you described.


Answers (1)

Nomit Jangid on 30 Nov 2020
Hi,
The documentation example you mention does use the dlnetwork function. I am not sure whether the article has been updated since you last visited it.
Have a look at the following article.
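A minimal sketch of such a custom training loop, assuming Deep Learning Toolbox R2019b or later and random toy data. The network, the layer names (out1, out2, drop, fc1), and the choice of crossentropy for LOSS1 and mse for LOSS2 are hypothetical placeholders. Only the loss function is written by hand; forward evaluates all layers of the dlnetwork, so the model itself never has to be reimplemented.

```matlab
% Toy two-headed network (hypothetical stand-in for the modified U-Net)
lgraph = layerGraph([
    imageInputLayer([28 28 1], 'Normalization', 'none', 'Name', 'in')
    convolution2dLayer(3, 8, 'Padding', 'same', 'Name', 'conv')
    reluLayer('Name', 'relu')
    dropoutLayer(0.2, 'Name', 'drop')
    fullyConnectedLayer(10, 'Name', 'fc1')
    softmaxLayer('Name', 'out1')]);           % OUTPUT1: 10-class scores
lgraph = addLayers(lgraph, fullyConnectedLayer(1, 'Name', 'out2')); % OUTPUT2: scalar
lgraph = connectLayers(lgraph, 'drop', 'out2');
net = dlnetwork(lgraph);

% Random toy data: 32 images, one-hot class targets T1, regression targets T2
rng(0);
numObs = 32;
X = dlarray(rand(28, 28, 1, numObs, 'single'), 'SSCB');
labels = randi(10, 1, numObs);
T1 = zeros(10, numObs, 'single');
T1(sub2ind(size(T1), labels, 1:numObs)) = 1;
T1 = dlarray(T1, 'CB');
T2 = dlarray(rand(1, numObs, 'single'), 'CB');

% Adam optimizer state (empty on the first iteration)
averageGrad = [];
averageSqGrad = [];
numIterations = 20;
learnRate = 1e-3;
losses = zeros(1, numIterations);

for iteration = 1:numIterations
    % Evaluate TOTAL_LOSS and its gradients with automatic differentiation
    [loss, gradients, state] = dlfeval(@modelLoss, net, X, T1, T2);
    % Keep layer state (e.g. batch normalization statistics) up to date
    net.State = state;
    % Update all learnable parameters of both heads and the shared trunk
    [net, averageGrad, averageSqGrad] = adamupdate(net, gradients, ...
        averageGrad, averageSqGrad, iteration, learnRate);
    losses(iteration) = double(extractdata(loss));
end

function [loss, gradients, state] = modelLoss(net, X, T1, T2)
    % One forward pass in training mode returns both outputs in the listed order
    [Y1, Y2, state] = forward(net, X, 'Outputs', {'out1', 'out2'});
    loss1 = crossentropy(Y1, T1);   % hypothetical LOSS1 on OUTPUT1
    loss2 = mse(Y2, T2);            % hypothetical LOSS2 on OUTPUT2
    loss = loss1 + loss2;           % TOTAL_LOSS = LOSS1 + LOSS2
    gradients = dlgradient(loss, net.Learnables);
end
```

Requesting both outputs in a single forward call runs the shared trunk once, and because TOTAL_LOSS is the sum of both losses, dlgradient backpropagates it through both heads and the shared layers. A weighted sum (for example w1*loss1 + w2*loss2) would be a straightforward variation if the two tasks need balancing.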
