% Parameterised 2x2 linear system A(lambda)*x = f(lambda) and its solution
% x(lambda), obtained with the backslash (mldivide) operator.
% Note det(A(lambda)) = (lambda^2+1)*1 - lambda^2 = 1, so A is never singular.
A = @(t) [t^2 + 1, t; t, 1];
f = @(t) [t; 1 - t];
x = @(t) mldivide(A(t), f(t));

% Solve at one sample point and display the residual A*x - f (expected ~0).
lambda = 0.123;
xlambda = x(lambda);
A(lambda)*xlambda - f(lambda)
% Train a small MLP mapping a scalar lambda to a 2-vector, using the
% residual loss computed by modelLoss over a fixed grid of lambda values.
net = dlnetwork([featureInputLayer(1); fullyConnectedLayer(100); ...
    reluLayer; fullyConnectedLayer(2)]);

% Training grid: 10000 lambda samples in [-5, 5], labelled channel ("C")
% by batch ("B").
lambda = dlarray(linspace(-5,5,10000),"CB");

% SGDM settings; an empty velocity lets sgdmupdate initialise its state.
lr = 1e-4;
vel = [];
maxIters = 10000;

% Accelerated wrapper around the loss/gradient function.
lossFcn = dlaccelerate(@modelLoss);

for iter = 1:maxIters
    [loss,grad] = dlfeval(lossFcn,net,lambda);
    fprintf("Iter: %d, Loss: %.4f\n",iter,extractdata(loss));
    [net,vel] = sgdmupdate(net,grad,vel,lr);
end
function [loss,grad] = modelLoss(net,lambda)
% modelLoss  Residual loss of the network's prediction of x(lambda).
%   [loss,grad] = modelLoss(net,lambda) evaluates NET on the formatted
%   dlarray LAMBDA, treats each output column as x for the matching lambda,
%   and returns the l2loss between A(lambda)*x and f(lambda), where
%       A(lambda) = [lambda^2+1, lambda; lambda, 1]
%       f(lambda) = [lambda; 1-lambda]
%   together with its gradient with respect to net.Learnables.
%   Uses dlgradient, so it is meant to be evaluated through dlfeval.

% Network output has one column per observation; move the observation
% index into the third (page) dimension so pagemtimes batches over it.
xPages = stripdims(forward(net,lambda));
xPages = reshape(xPages,size(xPages,1),1,[]);

% One 1x1 page per lambda value.
lam = reshape(stripdims(lambda),1,1,[]);

% Page-wise system matrix and right-hand side.
Apages = [lam.^2 + 1, lam; lam, ones(size(lam),"like",lam)];
fPages = [lam; 1-lam];

lhs = pagemtimes(Apages,xPages);
loss = l2loss(lhs,fPages,DataFormat="CUB");
grad = dlgradient(loss,net.Learnables);
end
% Same parameterised system, but x(lambda) is now computed with the
% pseudoinverse pinv(A) instead of backslash.
f = @(t) [t; 1 - t];
A = @(t) [t^2 + 1, t; t, 1];
x = @(t) pinv(A(t)) * f(t);
function dxi = dxidlambda(lambda,i)
% dxidlambda  Derivative of the i-th component of x(lambda) w.r.t. lambda.
%   dxi = dxidlambda(lambda,i) forms
%       A(lambda) = [lambda^2+1, lambda; lambda, 1]
%       f(lambda) = [lambda; 1-lambda]
%   solves x = pinv(A)*f, and returns d x(i) / d lambda computed by
%   automatic differentiation (dlgradient).
%
%   I must index the 2-element solution (1 or 2). LAMBDA must be a scalar
%   dlarray, and the call must go through dlfeval so dlgradient can trace, e.g.
%       dx1 = dlfeval(@dxidlambda, dlarray(0.123), 1)
%
%   The output argument is declared explicitly because callers assign the
%   result of dlfeval; without it MATLAB raises "Too many output arguments".

% Local definitions: variables from the calling script are not visible
% inside a local function.
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) pinv(A(lambda))*f(lambda);

xlambda = x(lambda);
xlambdai = xlambda(i);
% dlgradient needs a scalar target, hence the single component.
dxi = dlgradient(xlambdai,lambda);
end
% Evaluate d x(1) / d lambda at lambda = 0.123 (traced dlarray input)
% and display the result.
lambda0 = dlarray(0.123);
dx1dlambda = dlfeval(@dxidlambda,lambda0,1)