data = load("letisko_labels_new.mat");
LabelData = data.gTruth.LabelData;
LabelData.imageFilename = fullfile(LabelData.imageFilename);
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * length(shuffledIndices) );
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:6));
imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:6));
imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,2:6));
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);
className = ["kamera","lietadlo","satelit","stlp","veza"];
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);
augmentedTrainingData = transform(trainingData,@augmentData);
options = trainingOptions("adam",...
GradientDecayFactor=0.9,...
SquaredGradientDecayFactor=0.999,...
InitialLearnRate=0.001,...
LearnRateSchedule="none",...
L2Regularization=0.0005,...
BatchNormalizationStatistics="moving",...
DispatchInBackground=true,...
ResetInputNormalization=false,...
Shuffle="every-epoch",...
ValidationFrequency=1000,...
Plots="training-progress",...
CheckpointPath='C:\BAKALARKA\checkpointYOLO',...
ValidationData=validationData);
[detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
detector = downloadPretrainedYOLOv4Detector();
detectionResults = detect(detector,testData,'MiniBatchSize',4);
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);
recallv = cell2mat(recall);
precisionv = cell2mat(precision);
[r,index] = sort(recallv);
title(sprintf("Average Precision = %.2f",mean(ap)))