step three

main
v-h 2 years ago
parent 065f3d2bd1
commit eb15f88b6e

@ -0,0 +1,228 @@
% 2023-01-15
% based on the script for task one, this is rcnnscript for direct
% classification
% Requirements:
% data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
% in script location
% (this script unzips data and renames two files, but there is unlabeled
% image-material, that has to be removed by hand after unzipping)
% additional scriptfiles:
% - func_setupData.m
% unpack data etc.
% - func_groundTruthFromLabelPic.m
% (generate groundtruthtablle from image-data)
% - augmentData.m
% (dataaugmentation für RCNN)
% - helperSanitizeBoxes.m
% (part of augmentation)
% - preprocessData.m
% (Resize image and bounding boxes to targetSize.
%
% required add-on(s):
% - 'Deep Learning Toolbox Model for ResNet-50 Network'
% - 'image processing toolbox'
% - 'Computer Vision Toolbox '
% recommended add-on(s) - if gpu is apt for the job....
% - 'Parallel Computing Toolbox'
% adjustable parameters
% - if there is no trained net, it can be trained with this script:
% set doTraining to true
% - the training can use augmentation or not:
% set doAugmentation accordingly
close all;
clear;
doTraining = true;
% first we need the data...
dataDir = 'Picturedata'; % Destination-Folder for provided (img) Data
zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip'; %Data provided by TA
grDataFile = 'signDatasetGroundTruth_stepthree.mat';
func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);
%load data
data = load(grDataFile);
traficSignDataset = data.DataSet;
% ----- split the dataset into training, validation, and test sets.
% Select 60% of the data for training, 10% for validation, and the
% rest for testing the trained detector
rng(0)
shuffledIndices = randperm(height(traficSignDataset));
idx = floor(0.6 * height(traficSignDataset));
trainingIdx = 1:idx;
trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = traficSignDataset(shuffledIndices(testIdx),:);
% ----- use imageDatastore and boxLabelDatastore to create datastores
% for loading the image and label data during training and evaluation.
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end))
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(trainingDataTbl(:,2:end))
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(trainingDataTbl(:,2:end))
% combine image and box label datastores.
trainingData = combine(imdsTrain,bldsTrain); % erzeugt 'CombinedDatastore
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
% display one of the training images and box labels.
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,4); % nur fuer Darstellung
figure
imshow(annotatedImage)
% ----- Create Faster R-CNN Detection Network
inputSize = [224 224 3];
preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
% Achtung: dieser DS wird nur zur Ermittlung der BoundingBoxes verwendet
% display one of the training images and box labels.
while 1==0 %hasdata(preprocessedTrainingData)
data = read(preprocessedTrainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,4); % nur fuer Darstellung
figure(1)
imshow(annotatedImage)
pause(0.100)
end
% Auswahl der anchor boxes
% Infos dazu: https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors);
% und das feature CNN
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
numClasses = width(traficSignDataset)-1; % also hier: 1, es sollen nur Verkehrsschilder erkannt werden
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
% Netzwerk ansehen
% analyzeNetwork(lgraph)
augmentedTrainingData = transform(trainingData,@augmentData_stepthree);
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));
options = trainingOptions('sgdm',...
'MaxEpochs',10,...
'MiniBatchSize',2,...
'InitialLearnRate',1e-3,...
'CheckpointPath',tempdir,...
'ValidationData',validationData);
netname = "netDetectorResNet50_stepthree.mat"
if doTraining
% Train the Faster R-CNN detector.
% * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
% that training samples tightly overlap with ground truth.
[detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
'NegativeOverlapRange',[0 0.3], ...
'PositiveOverlapRange',[0.6 1]);
save netname detector;
else
% Load pretrained detector for the example.
load netname detector;
end
% ----- quick check/test
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);
% Display the results.
sfigTitle = ""
if height(bboxes) > 0
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
sfigTitle = "Detected";
else
sfigTitle = "Not Detected";
end
figure;
imshow(I);
annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)
% ----- Testing
testData = transform(testData,@(data)preprocessData(data,inputSize));
% Run the detector on all the test images.
detectionResults = detect(detector,testData,'MinibatchSize',4);
% Evaluate the object detector using the average precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);
% The precision/recall (PR) curve highlights how precise a detector is at varying levels of recall. The ideal precision is 1 at all recall levels. The use of more data can help improve the average precision but might require more training time. Plot the PR curve.
figure
plot(recall,precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))
% ----- Helper functions
% function data = augmentData(data)
% % Randomly flip images and bounding boxes horizontally.
% tform = randomAffine2d('XReflection',true);
% sz = size(data{1});
% rout = affineOutputView(sz,tform);
% data{1} = imwarp(data{1},tform,'OutputView',rout);
%
% % Sanitize box data, if needed.
% data{2} = helperSanitizeBoxes(data{2}, sz);
%
% % Warp boxes.
% data{2} = bboxwarp(data{2},tform,rout);
% end
%
% function data = preprocessData(data,targetSize)
% % Resize image and bounding boxes to targetSize.
% sz = size(data{1},[1 2]);
% scale = targetSize(1:2)./sz;
% data{1} = imresize(data{1},targetSize(1:2));
%
% % Sanitize box data, if needed.
% data{2} = helperSanitizeBoxes(data{2}, sz);
%
% % Resize boxes.
% data{2} = bboxresize(data{2},scale);
% end

@ -0,0 +1,21 @@
function data = augmentData_stepthree(data)
% Randomly flip images and bounding boxes horizontally.
tform = randomAffine2d('XReflection',true);
sz = size(data{1});
rout = affineOutputView(sz,tform);
% jitter
data{1} = jitterColorHSV(data{1},...
Contrast=0.2,...
Hue=0,...
Saturation=0.1,...
Brightness=0.2);
data{1} = imwarp(data{1},tform,'OutputView',rout);
% Sanitize box data, if needed.
data{2} = helperSanitizeBoxes(data{2}, sz);
% Warp boxes.
data{2} = bboxwarp(data{2},tform,rout);
end

@ -0,0 +1,135 @@
function [ ] = func_groundTruthFromLabelPic( dataStorePicturePath, dataStoreLabelPath, outFile )
%
% erstellt us einem Picture-Datastore und einem Label-Datastore
% eine Groundtruth-Tabelle, wie sie z.B. in FasterRCNN.m
% benoetigt wird
%
% file von tas
% adaptiert als func 2022/12/28 vh
%
labelDS = imageDatastore(dataStoreLabelPath, 'IncludeSubfolders', true);
pictureDS = imageDatastore(dataStorePicturePath, 'IncludeSubfolders', true);
labelCount = numel(labelDS.Files)
pictureCount = numel(pictureDS.Files)
if labelCount ~= pictureCount
fprintf("!!!! Error: Die Anzahl der Bilder und Anzahl der LabelPicture sind ungleich -> Abbruch");
return
end
fprintf("-----------------------------------------------------------\n");
fprintf("Picture-Verzeichnis: %s\n", dataStorePicturePath)
fprintf("Anzahl Images: %d\n", pictureCount)
fprintf("Label-Verzeichnis: %s\n", dataStoreLabelPath)
fprintf("Anzahl Images: %d\n", labelCount)
fprintf("BE PATIENT.... \n")
fprintf("-----------------------------------------------------------\n");
% table anlegen
sz = [pictureCount 6];
varTypes = ["cellstr","cell","cell","cell","cell","logical"];
varNames = ["imageFilename","30GBS","50GBS","60GBS","AndereGBS","valid"];
DataSet = table('Size',sz,'VariableTypes',varTypes,'VariableNames',varNames);
rng(0)
shuffledIndices = randperm(pictureCount);
% los gehts
for i = 1:pictureCount
shuffeldIndex = shuffledIndices(i);
[imPic imPic_INFO]= readimage(pictureDS, shuffeldIndex);
[im_path imPic_name im_ext]=fileparts(imPic_INFO.Filename);
[imLabel imLabel_INFO]= readimage(labelDS, shuffeldIndex);
[lbl_path imLabel_name im_ext]=fileparts(imLabel_INFO.Filename);
%fprintf("picture: %s label: %s\n", imPic_name, imLabel_name);
%fprintf("picturepath: %s labelpath: %s\n", im_path, lbl_path);
box = [0,0,0,0]; %default if theres no labelimg
v = true;
if ~strcmp(imPic_name, imLabel_name)
fprintf("!!!! Error: zum Picture gibt es kein entsprechendes LabelPicture -> Abbruch");
imPic_name
imLabel_name
else
% LabelRegion aus Image ausschneiden
bw = imLabel;
s = regionprops(bw, 'BoundingBox');
box = cat(1, s.BoundingBox); % structure to matrix
box = round(box);
% falls mehrere Marker vorhanden sind, ignorieren wir den Datensatz
% das dürfte für das Training Region detection besser sein ??
if (height(box) > 1)
v = false;
end
box = box(1,:);
if (numel(box) ~= 4)
fprintf("Boxkoordinaten nicht ok: %s %s\n", imPic_name, imLabel_name)
box
v = false
end
end
a = [];
b = [];
c = [];
d = [];
if contains(im_path,'30GBS') == 1
a = num2cell(box, 2);
elseif contains(im_path,'50GBS') == 1
b = num2cell(box, 2);
elseif contains(im_path,'60GBS') == 1
c = num2cell(box, 2);
else
d = num2cell(box, 2);
end
%check for boxes which are somewhat wrong
if v
if (box(3) < 8.) ...
|| (box(4) < 8.) ...
|| (abs(box(3) - box(4)) > 2) ...
|| (box(1) + box(3) > 1024) ...
|| (box(2) + box(4) > 768)
fprintf("boxkoordinaten nicht i.o. %s (%d %d %d %d) \n", imLabel_name, box(1),box(2),box(3),box(4));
v = false;
end
end
DataSet(shuffeldIndex,:) = {imPic_INFO.Filename,a,b,c,d,v};
% display one of the training images and box labels.
if (shuffeldIndex == 4)
annotatedImage = insertShape(imPic,'Rectangle',box);
figure
imshow(annotatedImage)
end
end
%Die Daten sind teilweise nicht in Ordnung, am einfachsten ist natürlich
%die Einträge zu löschen, die nicht gut sind.
fprintf("Groundtruth hat %d eintraege vor der Bereinigung \n ", height(DataSet) )
toDelete = DataSet.valid == false;
DataSet(toDelete,:) = [];
DataSet.valid=[];
fprintf("Groundtruth hat %d eintraege nach der Bereinigung \n", height(DataSet) )
save(outFile, 'DataSet' );

@ -0,0 +1,43 @@
function [ ] = func_setupData( dataDir, zippedDataFile, grDataFile )
% script func_setupData:
% - entpackt die trainingsdaten,
% - räumt ein bisschen auf, bzw. liefert Hinweise zum aufräumen
% - erstellt grounddata
% - (alles nur, falls es das nicht schon gibt)
if not(exist(dataDir , 'dir'))
% unzip data
if (not(exist(zippedDataFile , 'file')))
fprintf("Data file is missing please copy %s to script folder !", zippedDataFile);
return;
end
fprintf("unzipping Data");
unzip (zippedDataFile, dataDir)
%rename files correctly (there are two wrong):
fprintf("fix faulty falenames\n");
movefile (append(dataDir, '/Labels_1024_768/60GBS/60GBS_Gruppe05_SS21_Nr49.png'), append(dataDir, '/Labels_1024_768/60GBS/60GBS_Gruppe05_SS21_49.png'));
movefile (append(dataDir, '/Labels_1024_768/60GBS/60GBS_Gruppe05_SS21_Nr50.png'), append(dataDir, '/Labels_1024_768/60GBS/60GBS_Gruppe05_SS21_50.png'));
fprintf("Labeldata is incomplete\n");
fprintf("please compare data in keinGBS and delete what's to much manually\n");
fprintf("otherwise grounddatageneration will fail\n");
frpintf("then restart script \n");
return;
end
% generate Grounddata from Pictures
if (not(exist(grDataFile , 'file')))
dataStorePicturePath = append(pwd,'/', dataDir,'/Pictures_1024_768/');
dataStoreLabelPath = append(pwd,'/', dataDir, '/Labels_1024_768/');
% Die Tabelle wird in einer Funktion erstellt und gespeichert
% dabei werden Datensätze, wo die Label nicht passen entfernt.
func_groundTruthFromLabelPic_stepthree(dataStorePicturePath, dataStoreLabelPath, grDataFile);
end
Loading…
Cancel
Save