Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
ИДЗ_Лебедь.docx
Скачиваний:
3
Добавлен:
04.09.2023
Размер:
225.78 Кб
Скачать

Приложение Г

Код для выполнения работы с классификатором по критерию Фишера:

% Two-stage Fisher classification.
% Stage 1 separates the combined FZH+ZHT training data from NR;
% stage 2 separates ZHT from FZH. Test labels: first class -> 0, second -> 1.
fisher_data.step_1 = StepData( ...
    [FZH.train; ZHT.train], NR.train, ...
    [FZH.test; ZHT.test; NR.test], ...
    [zeros(30, 1); ones(15, 1)], ...
    {'ЖТ+ФЖ', 'НР'}, ...
    1);
fisher_data.step_2 = StepData( ...
    ZHT.train, FZH.train, ...
    [ZHT.test; FZH.test], ...
    [zeros(15, 1); ones(15, 1)], ...
    {'ЖТ', 'ФЖ'}, ...
    2);

fc = FisherClassifier(fisher_data.step_1, fisher_data.step_2);
[results.fisher, plots.fisher] = fc.classify();

% Per-stage field names, tile titles and legend entries shared by both figures.
step_keys = {'step_1', 'step_2'};
step_titles = {'Этап 1', 'Этап 2'};
step_legends = {{'ЖТ+ФЖ', 'НР'}, {'ЖТ', 'ФЖ'}};

% Figure 1: histograms of the projected training classes (saved to disk).
f = figure;
f.Name = "'ЖТ+ФЖ', 'НР'";
tiledlayout(1, 2, 'TileSpacing', 'none', 'Padding', 'none');
for step_idx = 1:numel(step_keys)
    step_plot = plots.fisher.(step_keys{step_idx});
    nexttile;
    hold on;
    histogram(step_plot.h1.Y, round(step_plot.h1.nbins));
    histogram(step_plot.h2.Y, round(step_plot.h2.nbins));
    legend(step_legends{step_idx});
    title(step_titles{step_idx});
end
saveas(gcf, 'fisher hists.png');

% Figure 2: fitted normal densities of both classes plus the decision threshold.
f = figure;
f.Name = "'ЖТ+ФЖ', 'НР'";
tiledlayout(1, 2, 'TileSpacing', 'none', 'Padding', 'none');
for step_idx = 1:numel(step_keys)
    step_plot = plots.fisher.(step_keys{step_idx});
    step_result = results.fisher.(step_keys{step_idx});
    nexttile;
    hold on;
    plot(step_plot.d_x, step_plot.d1.Y);
    plot(step_plot.d_x, step_plot.d2.Y);
    xline(step_result.tr_plot.tr, '--k');
    legend([step_legends{step_idx}, {'Порог'}]);
    title(step_titles{step_idx});
end

% Drop the loop helpers so the script leaves the same variables behind.
clear step_keys step_titles step_legends step_idx step_plot step_result

classdef FisherClassifier
    % FisherClassifier  Two-stage linear classifier based on Fisher's criterion.
    %
    % Each stage is described by a StepData object (sd1, sd2). For each stage
    % the constructor computes a unit-length discriminant direction from the
    % two training classes (get_W). classify() projects the training classes
    % onto that direction, fits a normal distribution to each projection,
    % takes the threshold chosen by StepData, and evaluates that threshold on
    % the stage's test samples (metrics + ROC data).

    properties
        W1    % stage-1 discriminant direction (unit-norm column vector)
        A1    % not used within this class
        W2    % stage-2 discriminant direction (unit-norm column vector)
        A2    % not used within this class
        sd1   % StepData of stage 1
        sd2   % StepData of stage 2
    end

    methods
        function obj = FisherClassifier(step_1_data, step_2_data)
            % Store both stages' data and compute their discriminant directions.
            obj.sd1 = step_1_data;
            obj.W1 = FisherClassifier.get_W(obj.sd1.Xtrain_1, obj.sd1.Xtrain_2);
            obj.sd2 = step_2_data;
            obj.W2 = FisherClassifier.get_W(obj.sd2.Xtrain_1, obj.sd2.Xtrain_2);
        end

        function [results, plots] = classify(obj)
            % Run both stages and collect numeric results and plot data.
            %
            % Returns (k = 1, 2):
            %   results.step_k - tr_plot (threshold .tr plus the metrics added by
            %                    add_classifier_characteristics), W, fitted
            %                    mu1/sigma1/mu2/sigma2, UTF-8 bytes of both
            %                    labels, and roc.sample / roc.gauss T and AUC.
            %   plots.step_k   - fitted pdf curves d1.Y / d2.Y over d_x,
            %                    histogram data h1 / h2 and the full ROC structs.
            [results.step_1, plots.step_1] = FisherClassifier.summarize_step(obj.sd1, obj.W1);
            [results.step_2, plots.step_2] = FisherClassifier.summarize_step(obj.sd2, obj.W2);
        end
    end

    methods (Static)
        function W = get_W(X1, X2)
            % Fisher discriminant direction W ~ inv(cov(X1) + cov(X2)) * (mean(X1) - mean(X2))',
            % normalised to unit Euclidean length.
            % X1, X2: samples of each class, one sample per row.
            % mldivide is used instead of inv() to avoid forming the explicit inverse.
            S = cov(X1) + cov(X2);
            W = S \ (mean(X1) - mean(X2))';
            W = W / norm(W);
        end

        function tr_data = add_classifier_characteristics(proj, tr_data, Ytest)
            % Classify projections against threshold tr_data.tr and attach the
            % metrics returned by classifier_characteristics (acc, spec, sens,
            % tpr, fpr) to tr_data.
            %
            % Projections strictly below the threshold are predicted as 1, all
            % others (ties included) as 0. Ypred always has one entry per
            % projection, so it lines up element-wise with Ytest.
            Ypred = double(proj(:) < tr_data.tr);
            [tr_data.acc, tr_data.spec, tr_data.sens, tr_data.tpr, tr_data.fpr] = ...
                classifier_characteristics(Ypred, Ytest);
        end
    end

    methods (Static, Access = private)
        function [res, plt] = summarize_step(sd, W)
            % Build the results/plots structs for one stage described by StepData sd.
            % Distributions and threshold come from the training projections;
            % metrics and the empirical ROC use the test samples (sd.Xtest),
            % whose labels are sd.Ytest.
            sd = sd.calc_projections(W, false);
            [tr, d1, d2, x] = sd.calc_normal_dist();

            res.tr_plot.tr = tr;
            res.W = W;

            plt.d1.Y = pdf(d1, x);
            plt.d2.Y = pdf(d2, x);
            plt.d_x = x;
            plt.h1.Y = sd.proj1;
            plt.h1.nbins = length(sd.proj1) / 2;
            plt.h2.Y = sd.proj2;
            plt.h2.nbins = length(sd.proj2) / 2;

            res.mu1 = d1.mu;
            res.sigma1 = d1.sigma;
            res.mu2 = d2.mu;
            res.sigma2 = d2.sigma;
            res.labels{1} = unicode2native(sd.labels{1}, 'UTF-8');
            res.labels{2} = unicode2native(sd.labels{2}, 'UTF-8');

            % ACC, SPEC, SENS, TPR, FPR on the test set: Ytest describes the
            % rows of Xtest, so the test samples (not the training ones) are
            % projected and compared against the threshold.
            test_proj = sd.Xtest * W;
            res.tr_plot = FisherClassifier.add_classifier_characteristics( ...
                test_proj, res.tr_plot, sd.Ytest);

            % ROC: empirical on the test projections, analytic on the fitted normals.
            plt.roc.sample = calc_roc('sample', test_proj, sd.Ytest);
            plt.roc.gauss = calc_roc('gauss', x, d1, d2);
            res.roc.sample.T = plt.roc.sample.T;
            res.roc.sample.AUC = plt.roc.sample.AUC;
            res.roc.gauss.T = plt.roc.gauss.T;
            res.roc.gauss.AUC = plt.roc.gauss.AUC;
        end
    end
end

classdef StepData
    % StepData  Data for one stage of the two-stage Fisher classifier.
    %
    % Holds the two training classes, the test samples with their labels,
    % the display labels and the stage number. calc_projections projects the
    % training classes onto a direction W; calc_normal_dist fits a normal
    % distribution to each projected class and picks a decision threshold
    % between the two fitted means.

    properties
        Xtrain_1   % training samples of class 1, one sample per row
        Xtrain_2   % training samples of class 2, one sample per row
        Xtest      % test samples (presumably one per row, like Xtrain_*)
        Ytest      % expected labels for Xtest (presumably one per test row)
        labels     % display names of the two classes
        step_num   % stage number
        proj1      % Xtrain_1 projected onto W (set by calc_projections)
        proj2      % Xtrain_2 projected onto W (set by calc_projections)
    end

    methods
        function obj = StepData(Xtrain_1, Xtrain_2, Xtest, Ytest, labels, step_num)
            % Store the stage data as given; projections stay empty until
            % calc_projections is called.
            obj.Xtrain_1 = Xtrain_1;
            obj.Xtrain_2 = Xtrain_2;
            obj.Xtest = Xtest;
            obj.Ytest = Ytest;
            obj.labels = labels;
            obj.step_num = step_num;
        end

        function obj = calc_projections(obj, W, add_probs)
            % Project both training classes onto W (one entry per feature/column).
            %
            % When add_probs is true, both projections are shifted by the same
            % offset log(P1 / P2) * Sigma^2, where Pk is class k's share of the
            % training samples and Sigma is the mean of the per-feature
            % variances of the raw training data.
            obj.proj1 = obj.Xtrain_1 * W;
            obj.proj2 = obj.Xtrain_2 * W;
            if add_probs
                % NOTE(review): var() already returns variances, so Sigma^2 is a
                % squared variance; Sigma is also taken from the unprojected
                % features rather than from proj1/proj2. Because both classes
                % are shifted by the same amount, a threshold fitted afterwards
                % on the shifted data moves (approximately) with them. Confirm
                % the intended prior-correction formula.
                Sigma1 = var(obj.Xtrain_1);
                Sigma2 = var(obj.Xtrain_2);
                Sigma = mean([Sigma1, Sigma2]);
                assert(isscalar(Sigma));
                L1 = size(obj.Xtrain_1, 1);
                L2 = size(obj.Xtrain_2, 1);
                P1 = L1 / (L1 + L2);
                P2 = L2 / (L1 + L2);
                P = log(P1 / P2) * Sigma ^ 2;
                obj.proj1 = obj.proj1 - P;
                obj.proj2 = obj.proj2 - P;
            end
        end

        function [tr_x, d1, d2, x] = calc_normal_dist(obj, step)
            % Fit a normal distribution to each projected class and choose a threshold.
            %
            % step (optional, default 0.0001): spacing of the evaluation grid x,
            % which runs from the smallest to the largest projection of either class.
            % Returns the threshold tr_x, the fitted distributions d1 / d2 and x.
            if nargin < 2 || isempty(step)
                step = 0.0001;
            end
            all_proj = [obj.proj1; obj.proj2];
            x = min(all_proj) : step : max(all_proj);
            d1 = fitdist(obj.proj1, 'Normal');
            d2 = fitdist(obj.proj2, 'Normal');
            tr_x = StepData.get_tr(x, d1, d2);
        end
    end

    methods (Static)
        function tr_x = get_tr(x, d1, d2)
            % Threshold where the two pdfs are closest, searched between the means.
            %
            % Among grid points x(i) with min(mu) <= x(i) <= max(mu), picks the
            % one with the smallest |pdf(d1, x) - pdf(d2, x)| (lowest index on
            % ties). The result is the midpoint between that point and the next
            % grid point (or the point itself when it is the last one).
            % Errors when no grid point lies between the two means.
            abs_diff = abs(pdf(d1, x) - pdf(d2, x));
            mu_min = min([d1.mu, d2.mu]);
            mu_max = max([d1.mu, d2.mu]);

            candidates = find(mu_min <= x & x <= mu_max);
            if isempty(candidates)
                error('StepData:noThreshold', ...
                    'No grid point lies between the class means %g and %g.', ...
                    mu_min, mu_max);
            end

            % min returns the first occurrence, matching a stable ascending sort.
            [~, k] = min(abs_diff(candidates));
            ind = candidates(k);

            tr_x = x(ind);
            if ind < numel(x)
                tr_x = (tr_x + x(ind + 1)) / 2;
            end
        end
    end
end

Соседние файлы в предмете Технологии и системы принятия решений