gan 코드 다른 데이터셋으로 훈련

6 views (last 30 days)
성목
성목 on 5 Feb 2024
Answered: Angelo Yeo on 17 Feb 2024
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'image2500');
digitData = imageDatastore(digitDatasetPath, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% 데이터 증강 설정
augmenter = imageDataAugmenter('RandXReflection', true);
augimds = augmentedImageDatastore([128 128], digitData, 'DataAugmentation', augmenter)
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;
projectionSize = [4 4 512];
layersGenerator = [
featureInputLayer(numLatentInputs)
projectAndReshapeLayer(projectionSize)
transposedConv2dLayer(filterSize,4*numFilters)
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
tanhLayer];
netG = dlnetwork(layersGenerator);
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;
inputSize = [64 64 3];
filterSize = 5;
layersDiscriminator = [
imageInputLayer(inputSize,Normalization="none")
dropoutLayer(dropoutProb)
convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
leakyReluLayer(scale)
convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(4,1)
sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
numEpochs = 50;
miniBatchSize = 128;
learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
augimds.MiniBatchSize = miniBatchSize;
mbq = minibatchqueue(augimds, ...
MiniBatchSize=miniBatchSize, ...
PartialMiniBatch="discard", ...
MiniBatchFcn=@preprocessMiniBatch, ...
MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPU
ZValidation = gpuArray(ZValidation);
end
numObservationsTrain = numel(imds.Files);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;
monitor = trainingProgressMonitor( ...
Metrics=["GeneratorScore","DiscriminatorScore"], ...
Info=["Epoch","Iteration"], ...
XLabel="Iteration");
groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"])
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Reset and shuffle datastore.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Generate latent inputs for the generator network. Convert to
% dlarray and specify the format "CB" (channel, batch). If a GPU is
% available, then convert latent inputs to gpuArray.
Z = randn(numLatentInputs,miniBatchSize,"single");
Z = dlarray(Z,"CB");
if canUseGPU
Z = gpuArray(Z);
end
% Evaluate the gradients of the loss with respect to the learnable
% parameters, the generator state, and the network scores using
% dlfeval and the modelLoss function.
[~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
netG.State = stateG;
% Update the discriminator network parameters.
[netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
trailingAvg, trailingAvgSqD, iteration, ...
learnRate, gradientDecayFactor, squaredGradientDecayFactor);
% Update the generator network parameters.
[netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
trailingAvgG, trailingAvgSqG, iteration, ...
learnRate, gradientDecayFactor, squaredGradientDecayFactor);
% Every validationFrequency iterations, display batch of generated
% images using the held-out generator input.
if mod(iteration,validationFrequency) == 0 || iteration == 1
% Generate images using the held-out generator input.
XGeneratedValidation = predict(netG,ZValidation);
% Tile and rescale the images in the range [0 1].
I = imtile(extractdata(XGeneratedValidation));
I = rescale(I);
% Display the images.
image(I)
xticklabels([]);
yticklabels([]);
title("Generated Images");
end
% Update the training progress monitor.
recordMetrics(monitor,iteration, ...
GeneratorScore=scoreG, ...
DiscriminatorScore=scoreD);
updateInfo(monitor,Epoch=epoch,Iteration=iteration);
monitor.Progress = 100*iteration/numIterations;
end
end
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
modelLoss(netG,netD,X,Z,flipProb)
% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X);
% Calculate the predictions for generated data with the discriminator
% network.
[XGenerated,stateG] = forward(netG,Z);
YGenerated = forward(netD,XGenerated);
% Calculate the score of the discriminator.
scoreD = (mean(YReal) + mean(1-YGenerated)) / 2;
% Calculate the score of the generator.
scoreG = mean(YGenerated);
% Randomly flip the labels of the real images.
numObservations = size(YReal,4);
idx = rand(1,numObservations) < flipProb;
YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx);
% Calculate the GAN loss.
[lossG, lossD] = ganLoss(YReal,YGenerated);
% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);
end
function [lossG,lossD] = ganLoss(YReal,YGenerated)
% Calculate the loss for the discriminator network.
lossD = -mean(log(YReal)) - mean(log(1-YGenerated));
% Calculate the loss for the generator network.
lossG = -mean(log(YGenerated));
end
function X = preprocessMiniBatch(data)
% Concatenate mini-batch
X = cat(4,data{:});
% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);
end
크기가 128x128 , 2500개로 모인 이미지 파일 'image2500'를 GAN 예제 코드로 돌려볼려고 하는데 잘 안됩니다... fullfile 함수로 경로 지정해주고 128 크기에 맞게 수치값도 바꿔준 거 같은데 훈련이 돌아가지만 새 영상 만들기도 잘 안되는 거 같습니다. 뭐가 문제일까요??

Answers (1)

Angelo Yeo
Angelo Yeo on 17 Feb 2024
좋은 질문 감사합니다. 설명에 따르면 아마 제공되는 예제의 코드를 그대로 이용했을 것으로 생각됩니다.
GAN은 훈련이 어렵기로 유명합니다. 특히, 판별자와 생성자를 동시에 훈련해주어야 하는데 둘의 레벨을 적절히 맞춰주면서 훈련시켜주어야 합니다. 판별자와 생성자의 각각의 loss, 판별자의 정답 여부, 생성된 그림의 상태 등을 통해 현상을 확인할 수 있습니다. 몇 가지 일반적인 가이드라인을 드리자면 아래와 같습니다.
  1. 판별자가 생성자보다 너무 뛰어나면 그림 생성에 대한 훈련이 진행되지 않습니다. 그럴 때는 dropout rate를 조정하거나, 판별자의 learning rate을 줄이거나, 판별자의 필터 수를 줄이는 등의 방법으로 판별자를 약화시켜야 합니다.
  2. 또, 생성자가 판별자보다 너무 뛰어나면 단순한 몇 개의 패턴만으로 판별자를 속여버리게 되고 다양한 출력을 생성하지 않습니다. 이 또한, 훈련한 생성자의 생성 결과물을 직접 보고 판단해야 합니다.
  3. loss가 줄어든다고 꼭 좋은 생성 결과물을 내는 것은 아닙니다. GAN의 loss는 훈련 받고 있는 시점의 판별자로부터 나오는 결과물이고 판별자는 뉴럴넷이 훈련되어 가면서 계속해서 개선되기 때문입니다. 그래서 "훈련이 돌아가지만 새 영상 만들기도 잘 안되는 거 같습니다." 라고 말씀하셨지만, loss가 계속 떨어진다는 사실만으로 GAN이 정상적으로 작동한다는 보장은 전혀 없습니다.
  4. 이 외의 하이퍼 파라미터들도 조금씩 조절해가면서 훈련이 잘 되는지 확인해야 합니다. GAN은 하이퍼파라미터를 약간만 바꿔도 훈련이 안될 수 있습니다.

Tags

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!