You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
HA_DIGSIG/RCNN_for_traficsigns_stepth...

229 lines
7.0 KiB
Matlab

% 2023-01-15
% based on the script for task one, this is rcnnscript for direct
% classification
% Requirements:
% the data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
% in the same folder as this script
% (this script unzips the data and renames two files, but the archive also
% contains unlabeled images, which have to be removed by hand after unzipping)
% additional scriptfiles:
% - func_setupData_stepthree.m
% unpack data etc.
% - func_groundTruthFromLabelPic_stepthree.m
% (generates the ground-truth table from the image data)
% - augmentData_stepthree.m
% (data augmentation for the R-CNN)
% - helperSanitizeBoxes.m
% (part of augmentation)
% - preprocessData.m
% (resizes image and bounding boxes to targetSize)
%
% required add-on(s):
% - 'Deep Learning Toolbox Model for ResNet-50 Network'
% - 'image processing toolbox'
% - 'Computer Vision Toolbox '
% recommended add-on(s) - if gpu is apt for the job....
% - 'Parallel Computing Toolbox'
% adjustable parameters
% - if there is no trained net, it can be trained with this script:
% set doTraining to true
% - the training can use augmentation or not:
% set doAugmentation accordingly
% ----- Workspace reset and run configuration
close all;
clear;
doTraining = true;  % true: train a new detector; false: load a saved one

% ----- Data setup: first we need the data...
dataDir = 'Picturedata';  % destination folder for the provided image data
zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip';  % data provided by TA
grDataFile = 'signDatasetGroundTruth_stepthree.mat';
func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);

% Load the ground-truth table stored under the field 'DataSet'.
data = load(grDataFile);
traficSignDataset = data.DataSet;

% ----- Split the dataset into training, validation, and test sets.
% 60% of the rows are used for training, 10% for validation, and the
% remainder for testing the trained detector. The seed is fixed so the
% split is reproducible.
rng(0)
numSamples = height(traficSignDataset);
shuffledIndices = randperm(numSamples);
idx = floor(0.6 * numSamples);            % number of training rows
numValidation = floor(0.1 * numSamples);  % number of validation rows

trainingIdx = 1:idx;
% The range ends at idx+numValidation so that exactly numValidation rows are
% selected (an end of idx+1+numValidation would select one row too many).
validationIdx = idx+1 : idx+numValidation;
% Computed from the counts rather than validationIdx(end), so an empty
% validation range (very small dataset) does not raise an indexing error.
testIdx = idx+numValidation+1 : numSamples;

trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:);
validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:);
testDataTbl = traficSignDataset(shuffledIndices(testIdx),:);

% ----- Use imageDatastore and boxLabelDatastore to create datastores
% for loading the image and label data during training and evaluation.
% Each box-label datastore must be built from the SAME table as its image
% datastore; otherwise images are paired with labels of other images.
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:end));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,2:end));

% Combine image and box label datastores (each result is a CombinedDatastore).
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
% ----- Visual sanity check: show the first training sample with its boxes.
% 'I' stays in the workspace because the quick check after training reuses it.
data = read(trainingData);
[I, bbox] = data{1:2};  % first cell: image, second cell: bounding boxes
% The overlay is enlarged by a factor of 4 for display purposes only.
annotatedImage = imresize(insertShape(I,'Rectangle',bbox), 4);
figure
imshow(annotatedImage)
% ----- Create Faster R-CNN Detection Network
inputSize = [224 224 3];

% Note: this preprocessed datastore is only used to estimate the anchor boxes.
preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));

% Optional debugging aid: step through all preprocessed training samples and
% show each one with its boxes. Disabled by default.
showPreprocessedSamples = false;
if showPreprocessedSamples
    while hasdata(preprocessedTrainingData)
        data = read(preprocessedTrainingData);
        I = data{1};
        bbox = data{2};
        annotatedImage = insertShape(I,'Rectangle',bbox);
        annotatedImage = imresize(annotatedImage,4);  % for display only
        figure(1)
        imshow(annotatedImage)
        pause(0.100)
    end
    % Rewind so the anchor-box estimation below sees the full datastore.
    reset(preprocessedTrainingData);
end

% Selection of the anchor boxes.
% Background: https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors);

% Feature-extraction CNN.
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
% One class per label column of the ground-truth table, i.e. every column
% except one (presumably the image-filename column - confirm against the
% table produced by func_groundTruthFromLabelPic_stepthree).
numClasses = width(traficSignDataset)-1;
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
% Inspect the network:
% analyzeNetwork(lgraph)

% ----- Training data pipeline
% Set doAugmentation to false to train on the un-augmented images.
doAugmentation = true;
if doAugmentation
    augmentedTrainingData = transform(trainingData,@augmentData_stepthree);
else
    augmentedTrainingData = trainingData;
end
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));

options = trainingOptions('sgdm',...
    'MaxEpochs',10,...
    'MiniBatchSize',2,...
    'InitialLearnRate',1e-3,...
    'CheckpointPath',tempdir,...
    'ValidationData',validationData);

netname = "netDetectorResNet50_stepthree.mat";
if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        'NegativeOverlapRange',[0 0.3], ...
        'PositiveOverlapRange',[0.6 1]);
    % Function syntax is required so that the VALUE of netname is used as the
    % file name; command syntax ("save netname detector") writes netname.mat.
    save(netname, 'detector');
else
    % Load a previously trained detector.
    if isfile(netname)
        load(netname, 'detector');
    elseif isfile('netname.mat')
        % Backward compatibility: earlier runs of this script saved the
        % detector with command syntax, i.e. into the literal file netname.mat.
        load('netname.mat', 'detector');
    else
        error('rcnn:missingDetector', ...
            'No trained detector found (%s). Set doTraining = true to train one.', netname);
    end
end
% ----- Quick check: run the detector on the sample image shown earlier.
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);

% Display the result; the detection scores are used as box labels.
if height(bboxes) > 0
    I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
    sfigTitle = "Detected";
else
    sfigTitle = "Not Detected";
end
figure;
imshow(I);
annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)

% ----- Testing
testData = transform(testData,@(data)preprocessData(data,inputSize));

% Run the detector on all the test images.
detectionResults = detect(detector,testData,'MinibatchSize',4);

% Evaluate the object detector using the average precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);

% The precision/recall (PR) curve highlights how precise a detector is at
% varying levels of recall. The ideal precision is 1 at all recall levels.
% The use of more data can help improve the average precision but might
% require more training time. Plot the PR curve.
figure
if iscell(recall)
    % Multi-class detector: recall/precision are cell arrays with one entry
    % per class and ap is a vector, so draw one curve per class and report
    % the mean of the per-class average precisions.
    hold on
    for k = 1:numel(recall)
        plot(recall{k},precision{k})
    end
    hold off
    title(sprintf('Mean Average Precision = %.2f', mean(ap)))
else
    % Single-class detector: numeric vectors and a scalar ap.
    plot(recall,precision)
    title(sprintf('Average Precision = %.2f', ap))
end
xlabel('Recall')
ylabel('Precision')
grid on
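% ----- Helper functions (commented-out reference copies; the script calls
% augmentData_stepthree and preprocessData, which are not defined in this file)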
% function data = augmentData(data)
% % Randomly flip images and bounding boxes horizontally.
% tform = randomAffine2d('XReflection',true);
% sz = size(data{1});
% rout = affineOutputView(sz,tform);
% data{1} = imwarp(data{1},tform,'OutputView',rout);
%
% % Sanitize box data, if needed.
% data{2} = helperSanitizeBoxes(data{2}, sz);
%
% % Warp boxes.
% data{2} = bboxwarp(data{2},tform,rout);
% end
%
% function data = preprocessData(data,targetSize)
% % Resize image and bounding boxes to targetSize.
% sz = size(data{1},[1 2]);
% scale = targetSize(1:2)./sz;
% data{1} = imresize(data{1},targetSize(1:2));
%
% % Sanitize box data, if needed.
% data{2} = helperSanitizeBoxes(data{2}, sz);
%
% % Resize boxes.
% data{2} = bboxresize(data{2},scale);
% end