% 2023-01-15 % based on the script for task one, this is rcnnscript for direct % classification % Requirements: % data has to be provided as PicturesResizedLabelsResizedSignsCutted.zip % in script location % (this script unzips data and renames two files, but there is unlabeled % image-material, that has to be removed by hand after unzipping) % additional scriptfiles: % - func_setupData_stepthree.m % unpack data etc. % - func_groundTruthFromLabelPic_stepthree.m % (generate groundtruthtablle from image-data) % - augmentData_stepthree.m % (dataaugmentation für RCNN) % - helperSanitizeBoxes.m % (part of augmentation) % - preprocessData.m % (Resize image and bounding boxes to targetSize. % % required add-on(s): % - 'Deep Learning Toolbox Model for ResNet-50 Network' % - 'image processing toolbox' % - 'Computer Vision Toolbox ' % recommended add-on(s) - if gpu is apt for the job.... % - 'Parallel Computing Toolbox' % adjustable parameters % - if there is no trained net, it can be trained with this script: % set doTraining to true % - the training can use augmentation or not: % set doAugmentation accordingly close all; clear; doTraining = true; % first we need the data... dataDir = 'Picturedata'; % Destination-Folder for provided (img) Data zippedDataFile = 'PicturesResizedLabelsResizedSignsCutted.zip'; %Data provided by TA grDataFile = 'signDatasetGroundTruth_stepthree.mat'; func_setupData_stepthree(dataDir, zippedDataFile, grDataFile); %load data data = load(grDataFile); traficSignDataset = data.DataSet; % ----- split the dataset into training, validation, and test sets. % Select 60% of the data for training, 10% for validation, and the % rest for testing the trained detector rng(0) shuffledIndices = randperm(height(traficSignDataset)); idx = floor(0.6 * height(traficSignDataset)); trainingIdx = 1:idx; trainingDataTbl = traficSignDataset(shuffledIndices(trainingIdx),:); validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) ); validationDataTbl = traficSignDataset(shuffledIndices(validationIdx),:); testIdx = validationIdx(end)+1 : length(shuffledIndices); testDataTbl = traficSignDataset(shuffledIndices(testIdx),:); % ----- use imageDatastore and boxLabelDatastore to create datastores % for loading the image and label data during training and evaluation. imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'}); bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:end)) imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'}); bldsValidation = boxLabelDatastore(trainingDataTbl(:,2:end)) imdsTest = imageDatastore(testDataTbl{:,'imageFilename'}); bldsTest = boxLabelDatastore(trainingDataTbl(:,2:end)) % combine image and box label datastores. trainingData = combine(imdsTrain,bldsTrain); % erzeugt 'CombinedDatastore validationData = combine(imdsValidation,bldsValidation); testData = combine(imdsTest,bldsTest); % display one of the training images and box labels. data = read(trainingData); I = data{1}; bbox = data{2}; annotatedImage = insertShape(I,'Rectangle',bbox); annotatedImage = imresize(annotatedImage,4); % nur fuer Darstellung figure imshow(annotatedImage) % ----- Create Faster R-CNN Detection Network inputSize = [224 224 3]; preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize)); % Achtung: dieser DS wird nur zur Ermittlung der BoundingBoxes verwendet % display one of the training images and box labels. while 1==0 %hasdata(preprocessedTrainingData) data = read(preprocessedTrainingData); I = data{1}; bbox = data{2}; annotatedImage = insertShape(I,'Rectangle',bbox); annotatedImage = imresize(annotatedImage,4); % nur fuer Darstellung figure(1) imshow(annotatedImage) pause(0.100) end % Auswahl der anchor boxes % Infos dazu: https://de.mathworks.com/help/vision/ug/estimate-anchor-boxes-from-training-data.html numAnchors = 3; anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors); % und das feature CNN featureExtractionNetwork = resnet50; featureLayer = 'activation_40_relu'; numClasses = width(traficSignDataset)-1; % also hier: 1, es sollen nur Verkehrsschilder erkannt werden lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer); % Netzwerk ansehen % analyzeNetwork(lgraph) augmentedTrainingData = transform(trainingData,@augmentData_stepthree); trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize)); validationData = transform(validationData,@(data)preprocessData(data,inputSize)); options = trainingOptions('sgdm',... 'MaxEpochs',10,... 'MiniBatchSize',2,... 'InitialLearnRate',1e-3,... 'CheckpointPath',tempdir,... 'ValidationData',validationData); netname = "netDetectorResNet50_stepthree.mat" if doTraining % Train the Faster R-CNN detector. % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure % that training samples tightly overlap with ground truth. [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ... 'NegativeOverlapRange',[0 0.3], ... 'PositiveOverlapRange',[0.6 1]); save netname detector; else % Load pretrained detector for the example. load netname detector; end % ----- quick check/test I = imresize(I,inputSize(1:2)); [bboxes,scores] = detect(detector,I); % Display the results. sfigTitle = "" if height(bboxes) > 0 I = insertObjectAnnotation(I,'rectangle',bboxes,scores); sfigTitle = "Detected"; else sfigTitle = "Not Detected"; end figure; imshow(I); annotation('textbox', [0.5, 0.2, 0.1, 0.1], 'String', sfigTitle) % ----- Testing testData = transform(testData,@(data)preprocessData(data,inputSize)); % Run the detector on all the test images. detectionResults = detect(detector,testData,'MinibatchSize',4); % Evaluate the object detector using the average precision metric. [ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData); % The precision/recall (PR) curve highlights how precise a detector is at varying levels of recall. The ideal precision is 1 at all recall levels. The use of more data can help improve the average precision but might require more training time. Plot the PR curve. figure plot(recall,precision) xlabel('Recall') ylabel('Precision') grid on title(sprintf('Average Precision = %.2f', ap)) % ----- Helper functions % function data = augmentData(data) % % Randomly flip images and bounding boxes horizontally. % tform = randomAffine2d('XReflection',true); % sz = size(data{1}); % rout = affineOutputView(sz,tform); % data{1} = imwarp(data{1},tform,'OutputView',rout); % % % Sanitize box data, if needed. % data{2} = helperSanitizeBoxes(data{2}, sz); % % % Warp boxes. % data{2} = bboxwarp(data{2},tform,rout); % end % % function data = preprocessData(data,targetSize) % % Resize image and bounding boxes to targetSize. % sz = size(data{1},[1 2]); % scale = targetSize(1:2)./sz; % data{1} = imresize(data{1},targetSize(1:2)); % % % Sanitize box data, if needed. % data{2} = helperSanitizeBoxes(data{2}, sz); % % % Resize boxes. % data{2} = bboxresize(data{2},scale); % end