% Load the labelled data and split its rows at random into a training table
% (first 60% of the shuffled rows), a validation table (roughly the next 10%)
% and a test table (all remaining rows).
data = load("letisko_labels_new.mat");
LabelData = data.gTruth.LabelData;
% Pass the stored image file names through fullfile before they are used.
LabelData.imageFilename = fullfile(LabelData.imageFilename);

% Random permutation of the row numbers; every split below is a contiguous
% slice of this permutation, so the three tables never share a row.
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * length(shuffledIndices) );

% Training split. trainingIdx was previously referenced without ever being
% assigned, which stopped the script with an "unrecognized variable" error.
trainingIdx = 1:idx;
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);

% Validation split: starts right after the training slice.
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);

% Test split: everything after the validation slice.
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
% Create an image datastore and a box-label datastore for each split.
% The image datastore reads the files listed in the imageFilename variable;
% the box-label datastore is built from table columns 2 to 6.
imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:6));

imdsValidation = imageDatastore(validationDataTbl.imageFilename);
bldsValidation = boxLabelDatastore(validationDataTbl(:, 2:6));

imdsTest = imageDatastore(testDataTbl.imageFilename);
bldsTest = boxLabelDatastore(testDataTbl(:, 2:6));

% Pair images with their labels so that one read returns both.
trainingData   = combine(imdsTrain, bldsTrain);
validationData = combine(imdsValidation, bldsValidation);
testData       = combine(imdsTest, bldsTest);

% Run the input check on every combined datastore before it is used.
% NOTE(review): validateInputData is not defined in the visible part of the
% file - confirm it is available as a local function or on the path.
validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);
% Estimate anchor boxes from the training data and create the YOLO v4
% detector that will be trained on the classes below.
className = ["kamera","lietadlo","satelit","stlp","veza"];

% Network input size as [height width channels] and the total number of
% anchor boxes to estimate. Both were previously used without being defined.
% numAnchors must stay at 9 because the anchors are split into three groups
% of three further down.
% NOTE(review): 608-by-608-by-3 is the value used in the MathWorks YOLO v4
% example - adjust if a different input size was intended.
inputSize = [608 608 3];
numAnchors = 9;

% Estimate the anchors on data that has been passed through preprocessData
% with the network input size.
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);

% Order the anchors from largest to smallest area.
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);

% One group of anchors per detection head. The cell literal was previously
% left unclosed (a syntax error) and held only the first group; the
% "csp-darknet53-coco" base network needs three groups.
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    anchors(7:9,:)
    };

detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);
% Train the detector on augmented training data, or load a pretrained
% detector instead when doTraining is set to false.
augmentedTrainingData = transform(trainingData,@augmentData);

% Training configuration: Adam solver with a constant learning rate,
% validation every 1000 iterations and checkpoints written to CheckpointPath.
options = trainingOptions("adam",...
    GradientDecayFactor=0.9,...
    SquaredGradientDecayFactor=0.999,...
    InitialLearnRate=0.001,...
    LearnRateSchedule="none",...
    L2Regularization=0.0005,...
    BatchNormalizationStatistics="moving",...
    DispatchInBackground=true,...
    ResetInputNormalization=false,...
    Shuffle="every-epoch",...
    ValidationFrequency=1000,...
    Plots="training-progress",...
    CheckpointPath='C:\BAKALARKA\checkpointYOLO',...
    ValidationData=validationData);

% Previously the trained detector was unconditionally replaced by the
% downloaded pretrained one on the very next line, so the training result was
% thrown away. The two statements are now mutually exclusive.
doTraining = true;
if doTraining
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % NOTE(review): downloadPretrainedYOLOv4Detector is not defined in the
    % visible part of the file - confirm that the detector it returns was
    % trained on the same classes before using it for evaluation.
    detector = downloadPretrainedYOLOv4Detector();
end
% Run the detector on the test set, compute average precision per class and
% plot the precision/recall curves.
detectionResults = detect(detector,testData,'MiniBatchSize',4);
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);

% Pooled recall/precision values across all classes, sorted by recall.
% These variables are kept so that any later code relying on them still runs.
recallv = cell2mat(recall);
precisionv = cell2mat(precision);
[r,index] = sort(recallv);

% Previously only title() was called, so the figure contained no curve at
% all. Draw one precision/recall curve per class in a fresh figure.
figure
hold on
for classIdx = 1:numel(recall)
    plot(recall{classIdx},precision{classIdx})
end
hold off
grid on
xlabel("Recall")
ylabel("Precision")
% NOTE(review): the legend assumes the evaluation returns the classes in the
% same order as className - confirm against the label categories.
legend(className,Location="southwest")
title(sprintf("Average Precision = %.2f",mean(ap)))