You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
HA_DIGSIG/RCNN_for_traficsigns_stepth...

190 lines
5.8 KiB
Matlab

2 years ago
% 2023-01-15
% Based on the script for task one; this is the R-CNN script for direct
% classification.
% Requirements:
% data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
% in script location
% (this script unzips the data and renames two files, but some unlabeled
% image material has to be removed by hand after unzipping)
% additional scriptfiles:
2 years ago
% - func_setupData_stepthree.m
2 years ago
% (unpacks the data, etc.)
2 years ago
% - func_groundTruthFromLabelPic_stepthree.m
2 years ago
% (generates the ground-truth table from the image data)
2 years ago
% - augmentData_stepthree.m
2 years ago
% (data augmentation for the R-CNN)
% - helperSanitizeBoxes.m
% (part of augmentation)
% - preprocessData.m
% (resizes image and bounding boxes to targetSize)
%
% Required add-on(s):
% - 'Deep Learning Toolbox Model for ResNet-50 Network'
% - 'Image Processing Toolbox'
% - 'Computer Vision Toolbox'
% Recommended add-on(s) - if a suitable GPU is available:
% - 'Parallel Computing Toolbox'
% Adjustable parameters:
% - doTraining: set to true to train a new detector with this script
%   (required if no trained net exists yet); false loads a saved detector.
2 years ago
% Start from a clean slate: close all figure windows, then clear the workspace.
close('all');
clear;
% Set to false to skip training and load a previously saved detector instead.
doTraining = true;
2 years ago
% ----- Data preparation
% Destination folder for the provided image data, the zip archive the data
% comes in (provided by the TA), and the MAT-file holding the ground truth.
dataDir = 'Picturedata';
zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip';
grDataFile = 'signDatasetGroundTruth_stepthree.mat';

% Unpack/prepare the data (details live in func_setupData_stepthree.m).
func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);

% Load the ground-truth table, stored as variable DataSet in the MAT-file.
data = load(grDataFile);
traficSignDataset = data.DataSet;
% ----- Split the dataset into training, validation, and test sets.
% 60% of the images are used for training, 10% for validation, and the
% rest for testing the trained detector. The shuffle is seeded so the
% split is reproducible between runs.
rng(0)
numImages = height(traficSignDataset);
shuffledIndices = randperm(numImages);
idx = floor(0.6 * numImages);            % number of training images
numValidation = floor(0.1 * numImages);  % number of validation images
trainingIdx = 1:idx;
trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:);
% Exactly numValidation images (the upper bound is idx+numValidation,
% not idx+1+numValidation, which would take one image too many).
validationIdx = idx+1 : idx+numValidation;
validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:);
% Computed from the counts rather than validationIdx(end), so this stays
% valid even when the validation range is empty.
testIdx = idx+numValidation+1 : numImages;
testDataTbl = traficSignDataset(shuffledIndices(testIdx),:);
% ----- Use imageDatastore and boxLabelDatastore to create datastores
% for loading the image and label data during training and evaluation.
% Each box-label datastore MUST be built from the same table as its image
% datastore so that every image stays paired with its own boxes.
% Column 1 is assumed to be imageFilename and every remaining column a
% box/label column (as the 2:end indexing implies) - TODO confirm.
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:end));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,2:end));
% Combine image and box-label datastores.
trainingData = combine(imdsTrain,bldsTrain); % creates a CombinedDatastore
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
% ----- Show one training image together with its ground-truth boxes.
data = read(trainingData);
I = data{1};     % the image
bbox = data{2};  % its bounding boxes
% Draw the boxes, then upscale 4x (for display only).
annotatedImage = imresize(insertShape(I, 'Rectangle', bbox), 4);
figure
imshow(annotatedImage)
% ----- Create the Faster R-CNN detection network
inputSize = [224 224 3];   % network input: height, width, channels

% Resized copy of the training data. It is used ONLY to estimate the
% anchor boxes; the datastore used for training is built separately.
preprocessedTrainingData = transform(trainingData, ...
    @(data) preprocessData(data, inputSize));

% Anchor-box selection, see:
% https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData, numAnchors);

% Feature-extraction CNN: pretrained ResNet-50, tapped at 'activation_40_relu'.
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';

% One class per table column except the first (imageFilename) column.
% NOTE(review): the original comment said this equals 1 (only traffic
% signs are detected); for this direct-classification variant it is
% presumably the number of sign classes - confirm.
numClasses = size(traficSignDataset, 2) - 1;

lgraph = fasterRCNNLayers(inputSize, numClasses, anchorBoxes, ...
    featureExtractionNetwork, featureLayer);
% To inspect the network:
% analyzeNetwork(lgraph)
% Training pipeline: augment with augmentData_stepthree, then resize images
% and boxes to the network input size. Validation data is only resized.
augmentedTrainingData = transform(trainingData, @augmentData_stepthree);
trainingData = transform(augmentedTrainingData, ...
    @(data) preprocessData(data, inputSize));
validationData = transform(validationData, ...
    @(data) preprocessData(data, inputSize));

% SGDM solver; checkpoints are written to tempdir.
options = trainingOptions('sgdm', ...
    'MaxEpochs',        10, ...
    'MiniBatchSize',    2, ...
    'InitialLearnRate', 1e-3, ...
    'CheckpointPath',   tempdir, ...
    'ValidationData',   validationData);
2 years ago
% File the trained detector is saved to / loaded from.
netname = "netDetectorResNet50_stepthree.mat";
if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with the ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        'NegativeOverlapRange',[0 0.3], ...
        'PositiveOverlapRange',[0.6 1]);
    % Function-call syntax is required: the command form
    % "save netname detector" would write a file literally named
    % "netname.mat" instead of the file name stored in netname.
    save(netname, 'detector');
else
    % Load the detector saved by a previous training run.
    load(netname, 'detector');
end
% ----- Quick sanity check: run the detector on the sample training image
% displayed earlier, resized to the network input size.
I = imresize(I, inputSize(1:2));
[bboxes, scores] = detect(detector, I);
% Placeholder for the result-figure title.
sfigTitle = "";
2 years ago
% Show the quick-check result: annotate any detected boxes with their
% scores and label the figure with the outcome.
if size(bboxes, 1) == 0
    sfigTitle = "Not Detected";
else
    I = insertObjectAnnotation(I, 'rectangle', bboxes, scores);
    sfigTitle = "Detected";
end
figure;
imshow(I);
annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)
% ----- Testing
% Apply the same resizing that was used for training to the test set.
testData = transform(testData,@(data)preprocessData(data,inputSize));
% Run the detector on all the test images.
detectionResults = detect(detector,testData,'MiniBatchSize',4);
% Evaluate the object detector using the average precision metric.
% With a single class, ap is a scalar and recall/precision are vectors.
% With several classes, ap is a vector and recall/precision are cell
% arrays holding one curve per class.
[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);
% Plot the precision/recall (PR) curve(s). The PR curve shows how precise a
% detector is at varying levels of recall; ideally precision is 1 at all
% recall levels. More data can improve the average precision but may
% require more training time.
figure
if iscell(recall)
    % Multi-class: one curve per class, title shows the mean AP.
    hold on
    for k = 1:numel(recall)
        plot(recall{k}, precision{k})
    end
    hold off
    title(sprintf('Mean Average Precision = %.2f', mean(ap)))
else
    plot(recall,precision)
    title(sprintf('Average Precision = %.2f', ap))
end
xlabel('Recall')
ylabel('Precision')
grid on