diff --git a/RCNN_for_traficsigns_stepthree.m b/RCNN_for_traficsigns_stepthree.m
new file mode 100644
index 0000000..ae521fc
--- /dev/null
+++ b/RCNN_for_traficsigns_stepthree.m
@@ -0,0 +1,279 @@
+% 2023-01-15
+
+% Faster R-CNN script for direct traffic-sign classification, based on the
+% script for task one.
+
+% Requirements:
+% data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip
+% in the script location
+% (this script unzips the data and renames two files, but there is unlabeled
+% image material that has to be removed by hand after unzipping)
+
+% additional script files:
+% - func_setupData_stepthree.m
+%   unpacks the data etc.
+% - func_groundTruthFromLabelPic_stepthree.m
+%   (generates the ground-truth table from the image data)
+% - augmentData_stepthree.m
+%   (data augmentation for the R-CNN)
+% - helperSanitizeBoxes.m
+%   (part of the augmentation)
+% - preprocessData.m
+%   (resizes image and bounding boxes to targetSize)
+
+% required add-on(s):
+% - 'Deep Learning Toolbox Model for ResNet-50 Network'
+% - 'Image Processing Toolbox'
+% - 'Computer Vision Toolbox'
+
+% recommended add-on(s) - if the GPU is apt for the job:
+% - 'Parallel Computing Toolbox'
+
+% adjustable parameters
+% - doTraining: true trains the detector and saves it to the file named by
+%   netname; false loads a previously saved detector instead
+% - doAugmentation: true augments the training data (random flip and
+%   colour jitter); false only resizes it
+
+close all;
+clear;
+
+doTraining = true;
+doAugmentation = true;
+
+% first we need the data...
+dataDir = 'Picturedata'; % destination folder for the provided image data
+zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip'; % data provided by TA
+grDataFile = 'signDatasetGroundTruth_stepthree.mat';
+func_setupData_stepthree(dataDir, zippedDataFile, grDataFile);
+
+% The setup function can return without having created the ground-truth file
+% (zip archive missing, or manual clean-up required after unzipping). Stop
+% with a clear message instead of failing inside load().
+if exist(grDataFile, 'file') ~= 2
+    error('RCNN_stepthree:missingGroundTruth', ...
+        'Ground-truth file %s not found - prepare the data, then restart the script.', ...
+        grDataFile);
+end
+
+% load data
+data = load(grDataFile);
+traficSignDataset = data.DataSet;
+
+% ----- split the dataset into training, validation, and test sets.
+% Select 60% of the data for training, 10% for validation, and the
+% rest for testing the trained detector
+
+rng(0)
+numSamples = height(traficSignDataset);
+shuffledIndices = randperm(numSamples);
+numTrain = floor(0.6 * numSamples);
+numValidation = floor(0.1 * numSamples);
+
+trainingIdx = 1:numTrain;
+trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:);
+
+% Colon ranges are inclusive, so this selects exactly numValidation samples.
+validationIdx = numTrain+1 : numTrain+numValidation;
+validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:);
+
+testIdx = numTrain+numValidation+1 : numSamples;
+testDataTbl = traficSignDataset(shuffledIndices(testIdx),:);
+
+% ----- use imageDatastore and boxLabelDatastore to create datastores
+% for loading the image and label data during training and evaluation.
+% Every box label datastore is built from the table of its own subset so
+% that images and labels stay paired.
+
+imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
+bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end));
+
+imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
+bldsValidation = boxLabelDatastore(validationDataTbl(:,2:end));
+
+imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
+bldsTest = boxLabelDatastore(testDataTbl(:,2:end));
+
+% combine image and box label datastores (creates CombinedDatastore objects).
+
+trainingData = combine(imdsTrain,bldsTrain);
+validationData = combine(imdsValidation,bldsValidation);
+testData = combine(imdsTest,bldsTest);
+
+% display one of the training images and box labels.
+
+data = read(trainingData);
+I = data{1};
+bbox = data{2};
+annotatedImage = insertShape(I,'Rectangle',bbox);
+annotatedImage = imresize(annotatedImage,4); % for display only
+figure
+imshow(annotatedImage)
+
+% ----- Create Faster R-CNN Detection Network
+
+inputSize = [224 224 3];
+
+% Caution: this datastore is only used for estimating the anchor boxes
+% (and for the optional preview below).
+preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
+
+% display one of the training images and box labels.
+showPreprocessedSamples = false; % true: step through all preprocessed samples
+if showPreprocessedSamples
+    while hasdata(preprocessedTrainingData)
+        previewData = read(preprocessedTrainingData);
+        previewImage = insertShape(previewData{1},'Rectangle',previewData{2});
+        previewImage = imresize(previewImage,4); % for display only
+        figure(1)
+        imshow(previewImage)
+        pause(0.100)
+    end
+    reset(preprocessedTrainingData);
+end
+
+% selection of the anchor boxes, for details see
+% https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html
+numAnchors = 3;
+anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors);
+
+% and the feature-extraction CNN
+featureExtractionNetwork = resnet50;
+featureLayer = 'activation_40_relu';
+% one class per label column of the ground-truth table (every column except
+% imageFilename)
+numClasses = width(traficSignDataset)-1;
+
+lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
+
+% inspect the network
+% analyzeNetwork(lgraph)
+if doAugmentation
+    trainingData = transform(trainingData,@augmentData_stepthree);
+end
+trainingData = transform(trainingData,@(data)preprocessData(data,inputSize));
+validationData = transform(validationData,@(data)preprocessData(data,inputSize));
+
+
+options = trainingOptions('sgdm',...
+    'MaxEpochs',10,...
+    'MiniBatchSize',2,...
+    'InitialLearnRate',1e-3,...
+    'CheckpointPath',tempdir,...
+    'ValidationData',validationData);
+
+netname = "netDetectorResNet50_stepthree.mat";
+
+if doTraining
+    % Train the Faster R-CNN detector.
+    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
+    %   that training samples tightly overlap with ground truth.
+    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
+        'NegativeOverlapRange',[0 0.3], ...
+        'PositiveOverlapRange',[0.6 1]);
+    % Function syntax is needed to use the value of netname; the command
+    % form "save netname detector" writes to a file called netname.mat.
+    save(netname, 'detector');
+elseif exist(netname, 'file') == 2
+    % Load the previously trained detector.
+    load(netname, 'detector');
+else
+    % Fallback for detectors stored by the command form "save netname",
+    % which ended up in netname.mat.
+    load('netname.mat', 'detector');
+end
+
+% ----- quick check/test
+
+I = imresize(I,inputSize(1:2));
+[bboxes,scores,labels] = detect(detector,I);
+% Display the results.
+
+if ~isempty(bboxes)
+    % annotate every box with its predicted class and score
+    boxCaptions = cellstr(string(labels) + ": " + string(scores));
+    I = insertObjectAnnotation(I,'rectangle',bboxes,boxCaptions);
+    sfigTitle = "Detected";
+else
+    sfigTitle = "Not Detected";
+end
+
+
+figure;
+imshow(I);
+annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle)
+
+% ----- Testing
+
+testData = transform(testData,@(data)preprocessData(data,inputSize));
+
+% Run the detector on all the test images.
+
+detectionResults = detect(detector,testData,'MiniBatchSize',4);
+
+% Evaluate the object detector using the average precision metric.
+% With more than one class, ap is a vector and recall/precision are cell
+% arrays holding one curve per class.
+
+[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);
+
+% The precision/recall (PR) curve highlights how precise a detector is at
+% varying levels of recall. The ideal precision is 1 at all recall levels.
+% The use of more data can help improve the average precision but might
+% require more training time. Plot the PR curve(s).
+
+if ~iscell(recall)
+    % single-class result: wrap it so that one plotting loop serves both cases
+    recall = {recall};
+    precision = {precision};
+end
+
+figure
+hold on
+for classIdx = 1:numel(recall)
+    plot(recall{classIdx},precision{classIdx})
+end
+hold off
+xlabel('Recall')
+ylabel('Precision')
+grid on
+if numel(ap) > 1
+    % NOTE(review): the legend assumes that the classes are reported in the
+    % order of the label columns of the ground-truth table - confirm.
+    legend(traficSignDataset.Properties.VariableNames(2:end),'Location','southwest')
+    title(sprintf('Mean Average Precision = %.2f', mean(ap)))
+else
+    title(sprintf('Average Precision = %.2f', ap))
+end
+
+
+% ----- Helper functions
+% (kept for reference only; the script calls augmentData_stepthree and
+% preprocessData from their own files)
+
+% function data = augmentData(data)
+% % Randomly flip images and bounding boxes horizontally.
+% tform = randomAffine2d('XReflection',true);
+% sz = size(data{1});
+% rout = affineOutputView(sz,tform);
+% data{1} = imwarp(data{1},tform,'OutputView',rout);
+%
+% % Sanitize box data, if needed.
+% data{2} = helperSanitizeBoxes(data{2}, sz);
+%
+% % Warp boxes.
+% data{2} = bboxwarp(data{2},tform,rout);
+% end
+%
+% function data = preprocessData(data,targetSize)
+% % Resize image and bounding boxes to targetSize.
+% sz = size(data{1},[1 2]);
+% scale = targetSize(1:2)./sz;
+% data{1} = imresize(data{1},targetSize(1:2));
+%
+% % Sanitize box data, if needed.
+% data{2} = helperSanitizeBoxes(data{2}, sz);
+%
+% % Resize boxes.
+% data{2} = bboxresize(data{2},scale);
+% end
diff --git a/augmentData_stepthree.m b/augmentData_stepthree.m
new file mode 100644
index 0000000..fcfdd49
--- /dev/null
+++ b/augmentData_stepthree.m
@@ -0,0 +1,31 @@
+function data = augmentData_stepthree(data)
+%AUGMENTDATA_STEPTHREE Augment one training sample (image and boxes).
+%   data is a cell array: data{1} is the image, data{2} holds its bounding
+%   boxes. The image is colour-jittered and the image and its boxes are
+%   randomly flipped horizontally; any further elements of data are
+%   returned unchanged.
+%
+%   NOTE(review): a horizontal flip also mirrors the sign content (e.g. the
+%   numbers of the 30/50/60 classes) - confirm that this is intended.
+
+% Random horizontal reflection and the matching output view.
+tform = randomAffine2d('XReflection',true);
+sz = size(data{1});
+rout = affineOutputView(sz,tform);
+
+% Colour jitter (hue range 0, i.e. the hue is not changed). Classic
+% name-value pairs are used, consistent with the other calls in this file.
+data{1} = jitterColorHSV(data{1}, ...
+    'Contrast',0.2, ...
+    'Hue',0, ...
+    'Saturation',0.1, ...
+    'Brightness',0.2);
+
+data{1} = imwarp(data{1},tform,'OutputView',rout);
+
+% Sanitize box data, if needed.
+data{2} = helperSanitizeBoxes(data{2}, sz);
+
+% Warp boxes with the same transform as the image.
+data{2} = bboxwarp(data{2},tform,rout);
+end
\ No newline at end of file
diff --git a/func_groundTruthFromLabelPic_stepthree.m b/func_groundTruthFromLabelPic_stepthree.m
new file mode 100644
index 0000000..13bfbde
--- /dev/null
+++ b/func_groundTruthFromLabelPic_stepthree.m
@@ -0,0 +1,142 @@
+function [ ] = func_groundTruthFromLabelPic_stepthree( dataStorePicturePath, dataStoreLabelPath, outFile )
+%FUNC_GROUNDTRUTHFROMLABELPIC_STEPTHREE Create the ground-truth table.
+%   Builds, from a picture datastore and a label datastore, a ground-truth
+%   table as it is needed e.g. in FasterRCNN.m and saves it as variable
+%   DataSet in outFile.
+%
+%   dataStorePicturePath - folder (subfolders included) with the pictures
+%   dataStoreLabelPath   - folder (subfolders included) with the label images
+%   outFile              - MAT-file the table is written to
+%
+%   The table has the columns imageFilename, 30GBS, 50GBS, 60GBS and
+%   AndereGBS. The box of a picture is stored in the column whose name is
+%   contained in the picture's folder path (AndereGBS otherwise). Records
+%   without a matching label image, with no or several label regions, or
+%   with implausible box coordinates are removed.
+%
+%   Throws an error if the numbers of pictures and label images differ.
+%
+%   file by tas
+%   adapted as function 2022/12/28 vh
+
+labelDS = imageDatastore(dataStoreLabelPath, 'IncludeSubfolders', true);
+pictureDS = imageDatastore(dataStorePicturePath, 'IncludeSubfolders', true);
+
+labelCount = numel(labelDS.Files);
+pictureCount = numel(pictureDS.Files);
+
+if labelCount ~= pictureCount
+    % Pictures and labels are paired by index below, so nothing useful can
+    % be built; fail loudly instead of returning without an output file.
+    error('func_groundTruthFromLabelPic:countMismatch', ...
+        '!!!! Error: Die Anzahl der Bilder (%d) und Anzahl der LabelPicture (%d) sind ungleich -> Abbruch', ...
+        pictureCount, labelCount);
+end
+
+fprintf("-----------------------------------------------------------\n");
+fprintf("Picture-Verzeichnis: %s\n", dataStorePicturePath)
+fprintf("Anzahl Images: %d\n", pictureCount)
+fprintf("Label-Verzeichnis: %s\n", dataStoreLabelPath)
+fprintf("Anzahl Images: %d\n", labelCount)
+fprintf("BE PATIENT.... \n")
+fprintf("-----------------------------------------------------------\n");
+
+% create the table
+sz = [pictureCount 6];
+varTypes = ["cellstr","cell","cell","cell","cell","logical"];
+varNames = ["imageFilename","30GBS","50GBS","60GBS","AndereGBS","valid"];
+DataSet = table('Size',sz,'VariableTypes',varTypes,'VariableNames',varNames);
+
+% plausibility limits for a box [x y width height]
+minBoxSide = 8;       % minimum width and height
+maxSideDiff = 2;      % boxes have to be (nearly) square
+maxRight = 1024;      % limit for x + width
+maxBottom = 768;      % limit for y + height
+
+% here we go
+for i = 1:pictureCount
+    [imPic, imPic_INFO] = readimage(pictureDS, i);
+    [im_path, imPic_name] = fileparts(imPic_INFO.Filename);
+
+    [imLabel, imLabel_INFO] = readimage(labelDS, i);
+    [~, imLabel_name] = fileparts(imLabel_INFO.Filename);
+
+    box = [0,0,0,0]; % default if there is no usable label image
+    v = true;        % record is valid until a check fails
+
+    if ~strcmp(imPic_name, imLabel_name)
+        fprintf("!!!! Error: zum Picture gibt es kein entsprechendes LabelPicture -> Abbruch\n");
+        fprintf("    %s / %s\n", imPic_name, imLabel_name);
+        v = false;
+    else
+        % bounding boxes of the label regions
+        % NOTE(review): regionprops treats non-logical input as a label
+        % matrix - presumably the label images are read as logical; confirm.
+        s = regionprops(imLabel, 'BoundingBox');
+        boxes = round(cat(1, s.BoundingBox)); % structure to matrix
+
+        if isempty(boxes) || size(boxes,2) ~= 4
+            % no label region at all, or boxes that are not [x y width height]
+            fprintf("Boxkoordinaten nicht ok: %s %s\n", imPic_name, imLabel_name)
+            v = false;
+        else
+            % if there are several markers we ignore the record; that
+            % should be better for training the region detection
+            if size(boxes,1) > 1
+                v = false;
+            end
+            box = boxes(1,:);
+        end
+    end
+
+    % the box goes into the column of the class found in the folder path
+    a = [];
+    b = [];
+    c = [];
+    d = [];
+    if contains(im_path,'30GBS')
+        a = {box};
+    elseif contains(im_path,'50GBS')
+        b = {box};
+    elseif contains(im_path,'60GBS')
+        c = {box};
+    else
+        d = {box};
+    end
+
+    % check for boxes which are somewhat wrong
+    if v
+        if (box(3) < minBoxSide) ...
+                || (box(4) < minBoxSide) ...
+                || (abs(box(3) - box(4)) > maxSideDiff) ...
+                || (box(1) + box(3) > maxRight) ...
+                || (box(2) + box(4) > maxBottom)
+            fprintf("boxkoordinaten nicht i.o. %s (%d %d %d %d) \n", imLabel_name, box(1),box(2),box(3),box(4));
+            v = false;
+        end
+    end
+
+    DataSet(i,:) = {imPic_INFO.Filename,a,b,c,d,v};
+
+    % display one of the training images and box labels.
+    if (i == 4) && v
+        annotatedImage = insertShape(imPic,'Rectangle',box);
+        figure
+        imshow(annotatedImage)
+    end
+
+end
+
+% Some of the data is not in order; the simplest way is of course to
+% delete the entries that are not good.
+fprintf("Groundtruth hat %d eintraege vor der Bereinigung \n ", height(DataSet) )
+
+toDelete = DataSet.valid == false;
+DataSet(toDelete,:) = [];
+
+DataSet.valid=[];
+
+fprintf("Groundtruth hat %d eintraege nach der Bereinigung \n", height(DataSet) )
+
+save(outFile, 'DataSet' );
+end
diff --git a/func_setupData_stepthree.m b/func_setupData_stepthree.m
new file mode 100644
index 0000000..7c1c79f
--- /dev/null
+++ b/func_setupData_stepthree.m
@@ -0,0 +1,58 @@
+function [ ] = func_setupData_stepthree( dataDir, zippedDataFile, grDataFile )
+%FUNC_SETUPDATA_STEPTHREE Prepare the data for the training script.
+%   - unpacks the training data,
+%   - tidies up a little and prints hints for the manual clean-up,
+%   - creates the ground-truth data
+%   (each step only if its result does not exist yet).
+%
+%   dataDir        - destination folder for the unzipped data
+%   zippedDataFile - zip archive with the provided data
+%   grDataFile     - MAT-file for the ground-truth table
+%
+%   The function returns without creating grDataFile when the archive is
+%   missing and directly after unzipping, because the data has to be
+%   cleaned up by hand before the ground truth can be generated.
+
+if not(exist(dataDir , 'dir'))
+    % unzip data
+    if (not(exist(zippedDataFile , 'file')))
+        fprintf("Data file is missing please copy %s to script folder !\n", zippedDataFile);
+        return;
+    end
+
+    fprintf("unzipping Data\n");
+    unzip(zippedDataFile, dataDir)
+
+    % rename files correctly (there are two wrong); files that are not
+    % present are skipped so that a corrected archive does not abort the setup
+    fprintf("fix faulty falenames\n");
+    labelDir = fullfile(dataDir, 'Labels_1024_768', '60GBS');
+    renames = { ...
+        '60GBS_Gruppe05_SS21_Nr49.png', '60GBS_Gruppe05_SS21_49.png'; ...
+        '60GBS_Gruppe05_SS21_Nr50.png', '60GBS_Gruppe05_SS21_50.png'};
+    for k = 1:size(renames,1)
+        wrongFile = fullfile(labelDir, renames{k,1});
+        if exist(wrongFile, 'file')
+            movefile(wrongFile, fullfile(labelDir, renames{k,2}));
+        end
+    end
+
+    fprintf("Labeldata is incomplete\n");
+    fprintf("please compare data in keinGBS and delete what's to much manually\n");
+    fprintf("otherwise grounddatageneration will fail\n");
+    fprintf("then restart script \n");
+    return;
+
+end
+
+% generate ground-truth data from the pictures
+
+if (not(exist(grDataFile , 'file')))
+    dataStorePicturePath = fullfile(pwd, dataDir, 'Pictures_1024_768');
+    dataStoreLabelPath = fullfile(pwd, dataDir, 'Labels_1024_768');
+
+    % The table is created and saved in a function; records whose labels
+    % do not fit are removed there.
+    func_groundTruthFromLabelPic_stepthree(dataStorePicturePath, dataStoreLabelPath, grDataFile);
+end
+end