How do I create a channel attention layer in MATLAB?
17 views (last 30 days)
Show older comments
classdef ChannelAttentionLayer < nnet.layer.Layer
    % ChannelAttentionLayer  Channel attention custom layer.
    %
    %   layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
    %   squeezes each channel of an H-by-W-by-C(-by-N) input with global
    %   average pooling and global max pooling, passes both C-by-N
    %   descriptors through a shared two-layer MLP (C -> C/r -> C with a
    %   ReLU in between), sums the two results, applies a sigmoid and
    %   rescales every channel of the input by the resulting weight.

    properties
        % Reduction ratio used in the channel attention mechanism
        ReductionRatio
    end

    properties (Learnable)
        % Layer learnable parameters, stored as 1x1 convolution kernels
        Weights1   % 1-by-1-by-input_channels-by-reduced_channels
        Bias1      % 1-by-1-by-reduced_channels
        Weights2   % 1-by-1-by-reduced_channels-by-input_channels
        Bias2      % 1-by-1-by-input_channels
    end

    methods
        function layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
            % Constructor for ChannelAttentionLayer
            %   reduction_ratio - positive scalar r; the hidden width of the
            %                     MLP is max(1, round(input_channels / r))
            %   input_channels  - number of channels C of the layer input
            %   name            - layer name (optional, defaults to '')
            if nargin < 3
                name = '';
            end
            layer.Name = name;
            layer.Description = sprintf('Channel attention (reduction ratio %g)', reduction_ratio);
            layer.ReductionRatio = reduction_ratio;

            % Calculate reduced channels based on reduction ratio
            reduced_channels = max(1, round(input_channels / reduction_ratio));

            % Initialize weights (He-style scaling by fan-in) and zero biases.
            % Unscaled randn would give MLP outputs with std ~sqrt(C), which
            % saturates the sigmoid and stalls training from the start.
            layer.Weights1 = randn([1, 1, input_channels, reduced_channels], 'single') * sqrt(2 / input_channels);
            layer.Bias1 = zeros([1, 1, reduced_channels], 'single');
            layer.Weights2 = randn([1, 1, reduced_channels, input_channels], 'single') * sqrt(2 / reduced_channels);
            layer.Bias2 = zeros([1, 1, input_channels], 'single');
        end

        function Z = forward(layer, X)
            % Forward pass for training mode.
            %   X - H-by-W-by-C-by-N input (N may be 1). Assumes the
            %       spatial-spatial-channel-batch layout; TODO confirm for
            %       the network this layer is used in.
            %   Z - same size as X, each channel of each sample scaled by
            %       its attention weight in (0, 1).
            if ~isa(X, 'dlarray')
                X = dlarray(X);
            end

            % Query dimensions explicitly: [H, W, C] = size(X) would fold
            % the batch dimension into C whenever N > 1.
            numChannels = size(X, 3);
            batchSize = size(X, 4);

            % Global average / global max pooling over height and width,
            % reshaped to C-by-N descriptor matrices.
            avgPool = reshape(mean(mean(X, 1), 2), numChannels, batchSize);
            maxPool = reshape(max(max(X, [], 1), [], 2), numChannels, batchSize);

            % The same MLP weights are applied to both descriptors.
            avgOut = sharedMLP(layer, avgPool);
            maxOut = sharedMLP(layer, maxPool);

            % Sigmoid of the summed branches gives per-channel weights,
            % broadcast back over height and width.
            attention = sigmoid(avgOut + maxOut);
            Z = X .* reshape(attention, 1, 1, numChannels, batchSize);
        end

        function Z = predict(layer, X)
            % Predict pass for inference mode; identical to the training pass.
            Z = forward(layer, X);
        end
    end
end

function out = sharedMLP(layer, pooled)
% Apply the shared C -> reduced -> C MLP (ReLU in between) to a C-by-N matrix.
hidden = relu(channelDense(pooled, layer.Weights1, layer.Bias1));
out = channelDense(hidden, layer.Weights2, layer.Bias2);
end

function out = channelDense(input, weights, bias)
% Fully connected operation equivalent to a 1x1 convolution.
%   input   - C_in-by-N matrix (one column per sample)
%   weights - 1-by-1-by-C_in-by-C_out kernel
%   bias    - 1-by-1-by-C_out vector
%   out     - C_out-by-N matrix
% Named channelDense so it does not shadow the toolbox function fullyconnect.
C_in = size(input, 1);
if C_in ~= size(weights, 3)
    error('Number of channels in input and weights do not match.');
end
% size(weights, 4) is 1 when the trailing singleton dimension was dropped.
weightMatrix = reshape(weights, size(weights, 3), size(weights, 4));
out = weightMatrix.' * input + reshape(bias, [], 1);
end
Answers (1)
See Also
Find more on Calculus in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!