% 2023-01-15
%
% Based on the script for task one, this is the R-CNN script for direct
% classification.
%
% Requirements:
% Data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
% in the script location.
% (This script unzips the data and renames two files, but there is unlabeled
% image material that has to be removed by hand after unzipping.)
%
% Additional script files:
% - func_setupData_stepthree.m
%   (unpack data etc.)
% - func_groundTruthFromLabelPic_stepthree.m
%   (generate the ground-truth table from the image data)
% - augmentData_stepthree.m
%   (data augmentation for the R-CNN)
% - helperSanitizeBoxes.m
%   (part of the augmentation)
% - preprocessData.m
%   (resize image and bounding boxes to targetSize;
%    a hedged sketch of its presumed shape follows this header)
%
% Required add-on(s):
% - 'Deep Learning Toolbox Model for ResNet-50 Network'
% - 'Image Processing Toolbox'
% - 'Computer Vision Toolbox'
%
% Recommended add-on(s) - if the GPU is up to the job:
% - 'Parallel Computing Toolbox'
%
% Adjustable parameters:
% - If there is no trained net, it can be trained with this script:
%   set doTraining to true.
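% The helper preprocessData.m is not part of this file. A minimal sketch of
% what it presumably does is shown below (commented out), assuming it follows
% the usual pattern of resizing the image and rescaling the boxes with
% bboxresize; the actual helper shipped with the assignment may differ.
%
%   function data = preprocessData(data, targetSize)
%   % Resize the image and scale the bounding boxes by the same factor.
%   sz = size(data{1}, [1 2]);
%   scale = targetSize(1:2) ./ sz;
%   data{1} = imresize(data{1}, targetSize(1:2));
%   data{2} = bboxresize(data{2}, scale);
%   end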
close all;
clear;

doTraining = true;
% first we need the data...
dataDir = 'Picturedata';   % destination folder for the provided image data
zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip';   % data provided by the TA
grDataFile = 'signDatasetGroundTruth_stepthree.mat';
func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);

% load data
data = load(grDataFile);
traficSignDataset = data.DataSet;

% ----- split the dataset into training, validation, and test sets.
% Select 60% of the data for training, 10% for validation, and the
% rest for testing the trained detector.
rng(0)
shuffledIndices = randperm(height(traficSignDataset));
idx = floor(0.6 * height(traficSignDataset));

trainingIdx = 1:idx;
trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:);

validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:);

testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = traficSignDataset(shuffledIndices(testIdx),:);
% ----- use imageDatastore and boxLabelDatastore to create datastores
% for loading the image and label data during training and evaluation.

imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end));

imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:end));

imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,2:end));

% combine image and box label datastores.

trainingData = combine(imdsTrain,bldsTrain); % creates a 'CombinedDatastore'
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
% display one of the training images and box labels.

data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,4); % enlarged for display only
figure
imshow(annotatedImage)
% ----- Create Faster R-CNN Detection Network

inputSize = [224 224 3];

preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
% Note: this datastore is only used to estimate the anchor boxes.

% selection of the anchor boxes
% background: https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors);
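% Optional check (not part of the original script): estimateAnchorBoxes can
% also return the mean IoU between the estimated anchors and the training
% boxes, which helps when experimenting with numAnchors.
% [anchorBoxes, meanIoU] = estimateAnchorBoxes(preprocessedTrainingData, numAnchors);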

% and the feature extraction CNN
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
numClasses = width(traficSignDataset)-1; % here: 1, only traffic signs are to be detected

lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);

% inspect the network
% analyzeNetwork(lgraph)

augmentedTrainingData = transform(trainingData,@augmentData_stepthree);
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));
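% Optional sanity check (not in the original workflow): read one augmented,
% preprocessed sample and confirm the boxes still line up with the image.
% sample = preview(trainingData);
% figure
% imshow(insertShape(sample{1},'Rectangle',sample{2}))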

options = trainingOptions('sgdm',...
    'MaxEpochs',10,...
    'MiniBatchSize',2,...
    'InitialLearnRate',1e-3,...
    'CheckpointPath',tempdir,...
    'ValidationData',validationData);
netname = "netDetectorResNet50_stepthree.mat";

if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        'NegativeOverlapRange',[0 0.3], ...
        'PositiveOverlapRange',[0.6 1]);
    save(netname, 'detector');
else
    % Load the pretrained detector.
    load(netname, 'detector');
end
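% Optional (an assumption, not part of the original script): with end-to-end
% training, the second output 'info' of trainFasterRCNNObjectDetector carries
% per-iteration metrics such as TrainingLoss, which can be plotted to judge
% convergence; field names may differ between releases.
% if doTraining
%     figure
%     plot(info.TrainingLoss)
%     xlabel('Iteration')
%     ylabel('Training loss')
% end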
% ----- quick check/test

I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);

% Display the results.

sfigTitle = "";
if ~isempty(bboxes)
    I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
    sfigTitle = "Detected";
else
    sfigTitle = "Not Detected";
end

figure;
imshow(I);
annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)
% ----- Testing

testData = transform(testData,@(data)preprocessData(data,inputSize));

% Run the detector on all the test images.

detectionResults = detect(detector,testData,'MiniBatchSize',4);

% Evaluate the object detector using the average precision metric.

[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);
% The precision/recall (PR) curve shows how precise the detector is at
% varying levels of recall. Ideal precision is 1 at all recall levels; more
% data can improve the average precision but may require more training time.
% Plot the PR curve.

figure
plot(recall,precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))