% 2023-01-15
% Based on the script for task one, this is the R-CNN script for direct
% classification.
% Requirements:
% data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
% in script location
% (this script unzips the data and renames two files; however, some unlabeled
% image material has to be removed by hand after unzipping)
% Additional script files:
% - func_setupData_stepthree.m
%   (unpack data etc.)
% - func_groundTruthFromLabelPic_stepthree.m
%   (generate the ground-truth table from the image data)
% - augmentData_stepthree.m
%   (data augmentation for R-CNN)
% - helperSanitizeBoxes.m
%   (part of the augmentation)
% - preprocessData.m
%   (resize image and bounding boxes to targetSize)
%
% Required add-on(s):
% - 'Deep Learning Toolbox Model for ResNet-50 Network'
% - 'Image Processing Toolbox'
% - 'Computer Vision Toolbox'
% Recommended add-on(s), if a suitable GPU is available:
% - 'Parallel Computing Toolbox'
% Adjustable parameters:
% - If there is no trained network yet, it can be trained with this script:
%   set doTraining to true.
% ----- Setup: clean workspace, unpack the data, load the ground truth
close all;
clear;

% true  -> train a new detector (and save it)
% false -> load a previously saved detector
doTraining = true;

% First we need the data...
dataDir = 'Picturedata';                                        % destination folder for the provided image data
zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip'; % data provided by the TA
grDataFile = 'signDatasetGroundTruth_stepthree.mat';            % ground-truth MAT-file
% NOTE(review): grDataFile is loaded right after this call, so the helper
% presumably creates it from the unzipped data -- confirm in
% func_setupData_stepthree.m.
func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);

% Load the ground-truth table (stored in the field DataSet of the MAT-file).
data = load(grDataFile);
traficSignDataset = data.DataSet;
% ----- Split the dataset into training, validation, and test sets.
% 60% of the images are used for training, 10% for validation, and the
% rest for testing the trained detector.
rng(0);  % fixed seed -> reproducible split
numImages = height(traficSignDataset);
shuffledIndices = randperm(numImages);

numTrain = floor(0.6 * numImages);
numValidation = floor(0.1 * numImages);

trainingIdx = 1:numTrain;
% Exactly numValidation images (the old range "idx+1 : idx+1+..." had one extra).
validationIdx = numTrain + 1 : numTrain + numValidation;
% Computed arithmetically so it stays valid even if validationIdx is empty
% (validationIdx(end) would error in that case).
testIdx = numTrain + numValidation + 1 : numImages;

trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx), :);
validationDataTbl = traficSignDataset(shuffledIndices(validationIdx), :);
testDataTbl = traficSignDataset(shuffledIndices(testIdx), :);

% ----- Use imageDatastore and boxLabelDatastore to create datastores for
% loading the image and label data during training and evaluation.
% The table is assumed to hold the file names in column 'imageFilename'
% (column 1) and the box labels in the remaining columns. Each box datastore
% must come from the SAME table as its image datastore, otherwise images and
% boxes are paired incorrectly.
imdsTrain = imageDatastore(trainingDataTbl{:, 'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));

imdsValidation = imageDatastore(validationDataTbl{:, 'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:, 2:end));

imdsTest = imageDatastore(testDataTbl{:, 'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));

% Combine image and box label datastores (creates CombinedDatastore objects).
trainingData = combine(imdsTrain, bldsTrain);
validationData = combine(imdsValidation, bldsValidation);
testData = combine(imdsTest, bldsTest);
% ----- Display one of the training images together with its box labels.
data = read(trainingData);
I = data{1};      % image
bbox = data{2};   % bounding boxes of that image
% read() advanced the datastore; rewind it so later consumers start at the
% first sample again.
reset(trainingData);

annotatedImage = insertShape(I, 'Rectangle', bbox);
annotatedImage = imresize(annotatedImage, 4);  % enlarged for display only
figure
imshow(annotatedImage)
% ----- Create the Faster R-CNN detection network
inputSize = [224 224 3];

% Note: this preprocessed datastore is only used to estimate the anchor boxes.
preprocessedTrainingData = transform(trainingData, @(data) preprocessData(data, inputSize));

% Choose the anchor boxes.
% More information: https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData, numAnchors);

% ... and the feature-extraction CNN.
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
% One class per label column (all columns except the image-file column).
% Here: 1, only traffic signs are to be detected.
numClasses = width(traficSignDataset) - 1;
lgraph = fasterRCNNLayers(inputSize, numClasses, anchorBoxes, featureExtractionNetwork, featureLayer);
% Inspect the network:
% analyzeNetwork(lgraph)

% Training data: augment first, then resize to the network input size.
augmentedTrainingData = transform(trainingData, @augmentData_stepthree);
trainingData = transform(augmentedTrainingData, @(data) preprocessData(data, inputSize));
% Validation data is only resized, not augmented.
validationData = transform(validationData, @(data) preprocessData(data, inputSize));

options = trainingOptions('sgdm', ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 2, ...
    'InitialLearnRate', 1e-3, ...
    'CheckpointPath', tempdir, ...
    'ValidationData', validationData);
% MAT-file in which the trained detector is saved / from which it is loaded.
netname = "netDetectorResNet50_stepthree.mat";
if doTraining
    % Train the Faster R-CNN detector.
    % * NegativeOverlapRange and PositiveOverlapRange are adjusted so that
    %   training samples tightly overlap with the ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData, lgraph, options, ...
        'NegativeOverlapRange', [0 0.3], ...
        'PositiveOverlapRange', [0.6 1]);
    % Function syntax is required here: the command form
    % "save netname detector" would write a file literally named netname.mat.
    save(netname, 'detector');
else
    % Load a previously trained detector into an explicit variable.
    loaded = load(netname, 'detector');
    detector = loaded.detector;
end
% ----- Quick check: run the detector on the sample image shown above
I = imresize(I, inputSize(1:2));
[bboxes, scores] = detect(detector, I);

% Display the results; annotate the image only if something was detected.
if height(bboxes) > 0
    I = insertObjectAnnotation(I, 'rectangle', bboxes, scores);
    sfigTitle = "Detected ";
else
    sfigTitle = "Not Detected ";
end
figure;
imshow(I);
annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)
% ----- Testing
% Resize the test images and their bounding boxes to the network input size.
testData = transform(testData, @(data) preprocessData(data, inputSize));
% Run the detector on all the test images.
detectionResults = detect(detector, testData, 'MinibatchSize', 4);
% Evaluate the object detector using the average precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(detectionResults, testData);
% The precision/recall (PR) curve shows how precise the detector is at
% varying levels of recall; ideally precision is 1 at all recall levels.
% More data can improve average precision but may need more training time.
figure
plot(recall, precision)
xlabel(' Recall')
ylabel(' Precision')
grid on
title(sprintf(' Average Precision = %.2f', ap))