1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package jr239.co620;
24
25
26 import weka.classifiers.CostMatrix;
27 import weka.classifiers.Classifier;
28 import weka.classifiers.Sourcable;
29 import weka.classifiers.UpdateableClassifier;
30
31
32 import weka.classifiers.evaluation.NominalPrediction;
33 import weka.classifiers.evaluation.ThresholdCurve;
34 import weka.classifiers.xml.XMLClassifier;
35 import weka.core.Drawable;
36 import weka.core.FastVector;
37 import weka.core.Instance;
38 import weka.core.Instances;
39 import weka.core.Option;
40 import weka.core.OptionHandler;
41 import weka.core.Range;
42 import weka.core.Summarizable;
43 import weka.core.Utils;
44 import weka.core.Version;
45 import weka.core.converters.ConverterUtils.DataSink;
46 import weka.core.converters.ConverterUtils.DataSource;
47 import weka.core.xml.KOML;
48 import weka.core.xml.XMLOptions;
49 import weka.core.xml.XMLSerialization;
50 import weka.estimators.Estimator;
51 import weka.estimators.KernelEstimator;
52
53 import java.io.BufferedInputStream;
54 import java.io.BufferedOutputStream;
55 import java.io.BufferedReader;
56 import java.io.FileInputStream;
57 import java.io.FileOutputStream;
58 import java.io.FileReader;
59 import java.io.InputStream;
60 import java.io.ObjectInputStream;
61 import java.io.ObjectOutputStream;
62 import java.io.OutputStream;
63 import java.io.Reader;
64 import java.util.Date;
65 import java.util.Enumeration;
66 import java.util.Random;
67 import java.util.zip.GZIPInputStream;
68 import java.util.zip.GZIPOutputStream;
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186 class EvaluationACO
187 implements Summarizable {
188
189
  /** The number of classes (size of the confusion matrix when nominal). */
  protected int m_NumClasses;

  /** The number of folds the statistics were accumulated over (1 = single train/test run). */
  protected int m_NumFolds =10;

  /** The weight of all incorrectly classified instances. */
  protected double m_Incorrect;

  /** The weight of all correctly classified instances. */
  protected double m_Correct;

  /** The weight of all unclassified instances (no prediction made). */
  protected double m_Unclassified;

  /** The weight of all instances that had a missing class value. */
  protected double m_MissingClass;

  /** The weight of all instances that had a (non-missing) class value. */
  protected double m_WithClass;

  /** Confusion matrix: [actual class][predicted class] accumulated weights. */
  protected double [][] m_ConfusionMatrix;

  /** The names of the classes, indexed by class value. */
  protected String [] m_ClassNames;

  /** Whether the class attribute is nominal (true) or numeric (false). */
  protected boolean m_ClassIsNominal;

  /** The prior (training-set) weight observed for each class. */
  protected double [] m_ClassPriors;

  /** The sum of all entries in m_ClassPriors. */
  protected double m_ClassPriorsSum;

  /** The cost matrix, if cost-sensitive evaluation was requested; otherwise null. */
  protected CostMatrix m_CostMatrix;

  /** The total cost of predictions, computed from m_CostMatrix. */
  protected double m_TotalCost;

  /** Sum of (signed) errors — used for numeric-class statistics. */
  protected double m_SumErr;

  /** Sum of absolute errors. */
  protected double m_SumAbsErr;

  /** Sum of squared errors. */
  protected double m_SumSqrErr;

  /** Sum of the actual class values. */
  protected double m_SumClass;

  /** Sum of squared actual class values. */
  protected double m_SumSqrClass;

  /** Sum of the predicted values. */
  protected double m_SumPredicted;

  /** Sum of squared predicted values. */
  protected double m_SumSqrPredicted;

  /** Sum of (actual * predicted) — used for the correlation coefficient. */
  protected double m_SumClassPredicted;

  /** Sum of absolute errors of the prior (baseline) predictor. */
  protected double m_SumPriorAbsErr;

  /** Sum of squared errors of the prior (baseline) predictor. */
  protected double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko information score (bits). */
  protected double m_SumKBInfo;

  /** Number of buckets used when accumulating the cumulative margin distribution. */
  protected static int k_MarginResolution = 500;

  /** Cumulative margin distribution counts, one bucket per resolution step. */
  protected double m_MarginCounts [];

  /** Number of training-set class values seen so far (numeric class only). */
  protected int m_NumTrainClassVals;

  /** Buffered training-set class values (numeric class only). */
  protected double [] m_TrainClassVals;

  /** Weights associated with the buffered training-set class values. */
  protected double [] m_TrainClassWeights;

  /** Density estimator for the errors of the prior (baseline) predictor. */
  protected Estimator m_PriorErrorEstimator;

  /** Density estimator for the errors of the scheme under evaluation. */
  protected Estimator m_ErrorEstimator;

  /**
   * Smallest probability accepted when computing entropy-based scores,
   * to avoid taking the log of zero.
   */
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of the prior predictions. */
  protected double m_SumPriorEntropy;

  /** Total entropy of the scheme's predictions. */
  protected double m_SumSchemeEntropy;

  /** The recorded predictions (NominalPrediction elements), or null if not recording. */
  private FastVector m_Predictions;

  /**
   * When true, no priors were set from a training set; statistics that
   * depend on priors are suppressed.
   */
  protected boolean m_NoPriors = false;
303
304
305
306
307
308
309
310
311
312
313
314
315
  /**
   * Initializes all the counters for the evaluation, without a cost matrix.
   *
   * @param data the dataset structure used to set up class names and priors
   * @throws Exception if the class attribute cannot be handled
   */
  public EvaluationACO(Instances data ) throws Exception {

    this(data, null);

  }
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338 public EvaluationACO(Instances data, CostMatrix costMatrix)
339 throws Exception {
340
341 m_NumClasses = data.numClasses();
342 m_NumFolds = 1;
343 m_ClassIsNominal = data.classAttribute().isNominal();
344
345 if (m_ClassIsNominal) {
346 m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
347 m_ClassNames = new String [m_NumClasses];
348 for(int i = 0; i < m_NumClasses; i++) {
349 m_ClassNames[i] = data.classAttribute().value(i);
350 }
351 }
352 m_CostMatrix = costMatrix;
353 if (m_CostMatrix != null) {
354 if (!m_ClassIsNominal) {
355 throw new Exception("Class has to be nominal if cost matrix " +
356 "given!");
357 }
358 if (m_CostMatrix.size() != m_NumClasses) {
359 throw new Exception("Cost matrix not compatible with data!");
360 }
361 }
362 m_ClassPriors = new double [m_NumClasses];
363 setPriors(data);
364 m_MarginCounts = new double [k_MarginResolution + 1];
365 }
366
367
368
369
370
371
372
373
374
375 public double areaUnderROC(int classIndex) {
376
377
378 if (m_Predictions == null) {
379 return Instance.missingValue();
380 } else {
381 ThresholdCurve tc = new ThresholdCurve();
382 Instances result = tc.getCurve(m_Predictions, classIndex);
383 return ThresholdCurve.getROCArea(result);
384 }
385 }
386
387
388
389
390
391
392 public double[][] confusionMatrix() {
393
394 double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
395
396 for (int i = 0; i < m_ConfusionMatrix.length; i++) {
397 newMatrix[i] = new double[m_ConfusionMatrix[i].length];
398 System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
399 m_ConfusionMatrix[i].length);
400 }
401 return newMatrix;
402 }
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419 public void crossValidateModel(Classifier classifier,
420 Instances data, int numFolds, Random random)
421 throws Exception {
422
423
424 data = new Instances(data);
425 data.randomize(random);
426 if (data.classAttribute().isNominal()) {
427 data.stratify(numFolds);
428 }
429
430 for (int i = 0; i < numFolds; i++) {
431 Instances train = data.trainCV(numFolds, i, random);
432 setPriors(train);
433 Classifier copiedClassifier = Classifier.makeCopy(classifier);
434 copiedClassifier.buildClassifier(train);
435 Instances test = data.testCV(numFolds, i);
436 evaluateModel(copiedClassifier, test);
437 }
438 m_NumFolds = numFolds;
439 }
440
441
442
443 public void
444 validateACOmodel(Classifier bc , Instances trainigData,
445 Instances testingData) throws Exception {
446
447 setPriors(trainigData);
448
449 bc.buildClassifier(trainigData);
450
451 evaluateModel(bc, testingData);
452
453
454 }
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473 public void crossValidateModel(String classifierString,
474 Instances data, int numFolds,
475 String[] options, Random random)
476 throws Exception {
477
478 crossValidateModel(Classifier.forName(classifierString, options),
479 data, numFolds, random);
480 }
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573 public static String evaluateModel(String classifierString,
574 String [] options) throws Exception {
575
576 Classifier classifier;
577
578
579 try {
580 classifier =
581 (Classifier)Class.forName(classifierString).newInstance();
582 } catch (Exception e) {
583 throw new Exception("Can't find class with name "
584 + classifierString + '.');
585 }
586 return evaluateModel(classifier, options);
587 }
588
589
590
591
592
593
594
595 public static void main(String [] args) {
596
597 try {
598 if (args.length == 0) {
599 throw new Exception("The first argument must be the class name"
600 + " of a classifier");
601 }
602 String classifier = args[0];
603 args[0] = "";
604 System.out.println(evaluateModel(classifier, args));
605 } catch (Exception ex) {
606 ex.printStackTrace();
607 System.err.println(ex.getMessage());
608 }
609 }
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
  /**
   * Evaluates the given classifier according to the supplied command-line
   * options and returns the textual results.
   *
   * Handles: -t/-T train/test files, -l/-d model load/save (binary, gzip,
   * XML, KOML), -c class index, -x folds, -s seed, percentage split,
   * cost-sensitive evaluation (-m), prediction output (-p/-distribution),
   * margins (-r), graph (-g), source output (-z), threshold files.
   *
   * NOTE: Utils.getOption/getFlag MUTATE the options array (they blank out
   * consumed entries), so the parsing order below is significant.
   *
   * @param classifier the scheme to evaluate
   * @param options    the command-line options (consumed)
   * @return the evaluation results as a string
   * @throws Exception if evaluation fails or the options are invalid
   */
  public static String evaluateModel(Classifier classifier,
                                     String [] options) throws Exception {

    Instances train = null, tempTrain, test = null, template = null;
    int seed = 1, folds = 10, classIndex = -1;
    boolean noCrossValidation = false;
    String trainFileName, testFileName, sourceClass,
      classIndexString, seedString, foldsString, objectInputFileName,
      objectOutputFileName, attributeRangeString;
    boolean noOutput = false,
      printClassifications = false, trainStatistics = true,
      printMargins = false, printComplexityStatistics = false,
      printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0,
      testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    boolean printDistribution = false;
    int actualClassIndex = -1;  // 0-based class index actually in effect
    String splitPercentageString = "";
    int splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;

    // Help request short-circuits everything.
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {
      throw new Exception("\nHelp requested." + makeOptionString(classifier));
    }

    try {
      // An -xml option replaces the entire option array with the options
      // parsed from the XML document.
      xml = Utils.getOption("xml", options);
      if (!xml.equals(""))
        options = new XMLOptions(xml).toArray();

      // Peek (on a copy, so nothing is consumed yet) whether -l points at
      // an XML-serialized model; if so, merge the model's own options in
      // front of the command-line options.
      optionsTmp = new String[options.length];
      for (int i = 0; i < options.length; i++)
        optionsTmp[i] = options[i];

      if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
        // Load the classifier from XML to recover its stored options.
        XMLClassifier xmlserial = new XMLClassifier();
        Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));

        optionsTmp = new String[options.length + cl.getOptions().length];
        System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
        System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
        options = optionsTmp;
      }

      noCrossValidation = Utils.getFlag("no-cv", options);

      // -c: class index; "first"/"last" keywords or a 1-based number
      // (-1 means "last", resolved when the data structure is read).
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first"))
          classIndex = 1;
        else if (classIndexString.equals("last"))
          classIndex = -1;
        else
          classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }

      // Validate the train/model/test file combinations.
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object "+
                              "input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test "+
                              "file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
                 ((!(classifier instanceof UpdateableClassifier)) ||
                  (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no " +
                            "test file provided: can't "+
                            "use both train and model file.");
      }

      // Open the data sources and, if requested, the serialized model
      // (gzip-wrapped, KOML, or plain Java serialization).
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          InputStream is = new FileInputStream(objectInputFileName);
          if (objectInputFileName.endsWith(".gz")) {
            is = new GZIPInputStream(is);
          }
          // Plain serialization unless it's a .koml file and KOML support
          // is available.
          if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
            objectInputStream = new ObjectInputStream(is);
            xmlInputStream = null;
          }
          else {
            objectInputStream = null;
            xmlInputStream = new BufferedInputStream(is);
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }

      // Determine the dataset template and the effective class index
      // (explicit -c wins; otherwise default to the last attribute).
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
            test.setClassIndex(test.numAttributes() - 1);
        }
        actualClassIndex = test.classIndex();
      }
      else {
        // No test file: check for a percentage split of the training data.
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0)
            throw new Exception(
                "Percentage split cannot be used in conjunction with "
                + "cross-validation ('-x').");
          splitPercentage = Integer.parseInt(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
            throw new Exception("Percentage split value needs be >0 and <100.");
        }
        else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1)
            throw new Exception("Percentage split ('-percentage-split') is missing.");
        }

        // Carve the training data into a train/test split in memory.
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder)
            tmpInst.randomize(new Random(seed));
          int trainSize = tmpInst.numInstances() * splitPercentage / 100;
          int testSize = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
              test.setClassIndex(test.numAttributes() - 1);
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
            train.setClassIndex(train.numAttributes() - 1);
        }
        actualClassIndex = train.classIndex();
        if ((testSetPresent) && !test.equalHeaders(train)) {
          throw new IllegalArgumentException("Train and test file not compatible!");
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      // Output/reporting flags.
      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      printDistribution = Utils.getFlag("distribution", options);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      // -p expects an attribute range ('0' for none) — report the changed
      // semantics if parsing fails.
      try {
        attributeRangeString = Utils.getOption('p', options);
      }
      catch (Exception e) {
        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
                            "It now expects a parameter specifying a range of attributes " +
                            "to list with the predictions. Use '-p 0' for none.");
      }
      if (attributeRangeString.length() != 0) {
        printClassifications = true;
        if (!attributeRangeString.equals("0"))
          attributesToOutput = new Range(attributeRangeString);
      }

      if (!printClassifications && printDistribution)
        throw new Exception("Cannot print distribution without '-p' option!");

      // Complexity statistics need training-set priors.
      if ( (!trainSetPresent) && (printComplexityStatistics) )
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // Remaining options belong to the scheme — unless a model is being
      // loaded, in which case none may remain.
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {
        // Collect the scheme options verbatim for the report (quoting
        // options that contain spaces), then hand them to the scheme.
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler)classifier).setOptions(options);
        }
      }
      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
                          + makeOptionString(classifier));
    }

    // Two accumulators: one for error-on-training, one for test/CV results.
    EvaluationACO trainingEvaluationACO = new EvaluationACO(new Instances(template, 0), costMatrix);
    EvaluationACO testingEvaluationACO = new EvaluationACO(new Instances(template, 0), costMatrix);

    // Without a training set there are no priors to use.
    if (!trainSetPresent)
      testingEvaluationACO.useNoPriors();

    // Load a previously saved model (serialized or KOML), verifying header
    // compatibility if the save included the dataset structure.
    if (objectInputFileName.length() != 0) {
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // The stream may optionally contain the training header next.
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // Intentionally ignored: older model files have no saved header.
        }
        if (savedStructure != null) {
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible");
          }
        }
        objectInputStream.close();
      }
      else {
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }

    // Keep a pristine copy for the cross-validation run later.
    classifierBackup = Classifier.makeCopy(classifier);

    // Train: incrementally if possible (updateable scheme, test set given,
    // no cost matrix), otherwise in batch.
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent) &&
        (costMatrix == null) &&
        (trainSetPresent)) {

      trainingEvaluationACO.setPriors(train);
      testingEvaluationACO.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluationACO.updatePriors(trainInst);
        testingEvaluationACO.updatePriors(trainInst);
        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Batch training on the full training set.
      tempTrain = trainSource.getDataSet(actualClassIndex);
      trainingEvaluationACO.setPriors(tempTrain);
      testingEvaluationACO.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // Snapshot the trained model for prediction printing.
    if (printClassifications)
      classifierClassifications = Classifier.makeCopy(classifier);

    // Save the trained model if requested (binary/gzip, XML, or KOML).
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        // Also store the header so later loads can verify compatibility.
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      }
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        }
        else
          if (objectOutputFileName.endsWith(".koml")) {
            KOML.write(xmlOutputStream, classifier);
          }
        xmlOutputStream.close();
      }
    }

    // -g: return the model graph instead of evaluation output.
    if ((classifier instanceof Drawable) && (printGraph)){
      return ((Drawable)classifier).graph();
    }

    // -z: return the model as generated source code.
    if ((classifier instanceof Sourcable) && (printSource)){
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Standard output: scheme options and model description.
    if (!(noOutput || printMargins)) {
      if (classifier instanceof OptionHandler) {
        if (schemeOptionsText != null) {
          text.append("\nOptions: "+schemeOptionsText);
          text.append("\n");
        }
      }
      text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
      text.append("\n=== EvaluationACO Cost Matrix ===\n\n");
      text.append(costMatrix.toString());
    }

    // -p: print per-instance predictions and return immediately.
    if (printClassifications) {
      DataSource source = testSource;
      // No test set? Print predictions on the training data instead.
      if (source == null)
        source = trainSource;
      return printClassifications(classifierClassifications, new Instances(template, 0),
                                  source, actualClassIndex + 1, attributesToOutput,
                                  printDistribution);
    }

    // Error on the training data (or training split).
    if ((trainStatistics) && (trainSetPresent)) {
      if ((classifier instanceof UpdateableClassifier) &&
          (testSetPresent) &&
          (costMatrix == null)) {

        // Incremental case: re-read the training stream instance by
        // instance to score the model.
        trainSource.reset();
        train = trainSource.getStructure(actualClassIndex);
        testTimeStart = System.currentTimeMillis();
        Instance trainInst;
        while (trainSource.hasMoreElements(train)) {
          trainInst = trainSource.nextElement(train);
          trainingEvaluationACO.evaluateModelOnce((Classifier)classifier, trainInst);
        }
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      } else {
        testTimeStart = System.currentTimeMillis();
        trainingEvaluationACO.evaluateModel(
            classifier, trainSource.getDataSet(actualClassIndex));
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // -r: return only the cumulative margin distribution.
      if (printMargins) {
        return trainingEvaluationACO.toCumulativeMarginDistributionString();
      } else {
        text.append("\nTime taken to build model: "
                    + Utils.doubleToString(trainTimeElapsed / 1000.0,2)
                    + " seconds");

        if (splitPercentage > 0)
          text.append("\nTime taken to test model on training split: ");
        else
          text.append("\nTime taken to test model on training data: ");
        text.append(Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");

        if (splitPercentage > 0)
          text.append(trainingEvaluationACO.toSummaryString("\n\n=== Error on training"
                                                            + " split ===\n", printComplexityStatistics));
        else
          text.append(trainingEvaluationACO.toSummaryString("\n\n=== Error on training"
                                                            + " data ===\n", printComplexityStatistics));

        if (template.classAttribute().isNominal()) {
          if (classStatistics) {
            text.append("\n\n" + trainingEvaluationACO.toClassDetailsString());
          }
          if (!noCrossValidation)
            text.append("\n\n" + trainingEvaluationACO.toMatrixString());
        }
      }
    }

    // Evaluate on the test set if there is one; otherwise cross-validate.
    if (testSource != null) {
      Instance testInst;
      while (testSource.hasMoreElements(test)) {
        testInst = testSource.nextElement(test);
        testingEvaluationACO.evaluateModelOnceAndRecordPrediction(
            (Classifier)classifier, testInst);
      }

      if (splitPercentage > 0)
        text.append("\n\n" + testingEvaluationACO.
                    toSummaryString("=== Error on test split ===\n",
                                    printComplexityStatistics));
      else
        text.append("\n\n" + testingEvaluationACO.
                    toSummaryString("=== Error on test data ===\n",
                                    printComplexityStatistics));

    } else if (trainSource != null) {
      if (!noCrossValidation) {
        // Cross-validate a fresh copy of the untrained scheme.
        Random random = new Random(seed);
        classifier = Classifier.makeCopy(classifierBackup);
        testingEvaluationACO.crossValidateModel(
            classifier, trainSource.getDataSet(actualClassIndex), folds, random);
        if (template.classAttribute().isNumeric()) {
          text.append("\n\n\n" + testingEvaluationACO.
                      toSummaryString("=== Cross-validation ===\n",
                                      printComplexityStatistics));
        } else {
          text.append("\n\n\n" + testingEvaluationACO.
                      toSummaryString("=== Stratified " +
                                      "cross-validation ===\n",
                                      printComplexityStatistics));
        }
      }
    }
    if (template.classAttribute().isNominal()) {
      if (classStatistics) {
        text.append("\n\n" + testingEvaluationACO.toClassDetailsString());
      }
      if (!noCrossValidation)
        text.append("\n\n" + testingEvaluationACO.toMatrixString());
    }

    // Optionally write the threshold curve for one class label.
    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
      int labelIndex = 0;
      if (thresholdLabel.length() != 0)
        labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
      if (labelIndex == -1)
        throw new IllegalArgumentException(
            "Class label '" + thresholdLabel + "' is unknown!");
      ThresholdCurve tc = new ThresholdCurve();
      Instances result = tc.getCurve(testingEvaluationACO.predictions(), labelIndex);
      DataSink.write(thresholdFile, result);
    }

    return text.toString();
  }
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
  /**
   * Reads a cost matrix from the file named by the -m option.
   *
   * Tries the current matrix file format first; on failure, reopens the
   * file and falls back to the old (pre-3.1) format. If both fail, the
   * exception from the first (current-format) attempt is rethrown.
   *
   * @param costFileName the cost matrix file name ("" or null means none)
   * @param numClasses   the number of classes (size for the old format)
   * @return the cost matrix, or null if no file name was given
   * @throws Exception if the file cannot be opened or parsed
   */
  protected static CostMatrix handleCostOption(String costFileName,
                                               int numClasses)
    throws Exception {

    if ((costFileName != null) && (costFileName.length() != 0)) {
      System.out.println(
          "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
          +" and WEKA 3.1. -m now carries out cost-sensitive *EvaluationACO*"
          +" only. For cost-sensitive *prediction*, use one of the"
          +" cost-sensitive metaschemes such as"
          +" weka.classifiers.meta.CostSensitiveClassifier or"
          +" weka.classifiers.meta.MetaCost");

      Reader costReader = null;
      try {
        costReader = new BufferedReader(new FileReader(costFileName));
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      try {
        // Attempt the current cost-matrix file format.
        return new CostMatrix(costReader);
      } catch (Exception ex) {
        try {
          // The reader is partially consumed; reopen the file before
          // retrying with the old format.
          try {
            costReader.close();
            costReader = new BufferedReader(new FileReader(costFileName));
          } catch (Exception e) {
            throw new Exception("Can't open file " + e.getMessage() + '.');
          }
          CostMatrix costMatrix = new CostMatrix(numClasses);

          costMatrix.readOldFormat(costReader);
          return costMatrix;

        } catch (Exception e2) {
          // Neither format worked: report the current-format failure,
          // which is the more informative of the two.
          throw ex;
        }
      }
    } else {
      return null;
    }
  }
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288 public double[] evaluateModel(Classifier classifier,
1289 Instances data) throws Exception {
1290
1291 double predictions[] = new double[data.numInstances()];
1292
1293
1294
1295 for (int i = 0; i < data.numInstances(); i++) {
1296 predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier,
1297 data.instance(i));
1298 }
1299
1300 return predictions;
1301 }
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313 public double evaluateModelOnceAndRecordPrediction(Classifier classifier,
1314 Instance instance) throws Exception {
1315
1316 Instance classMissing = (Instance)instance.copy();
1317 double pred = 0;
1318 classMissing.setDataset(instance.dataset());
1319 classMissing.setClassMissing();
1320 if (m_ClassIsNominal) {
1321 if (m_Predictions == null) {
1322 m_Predictions = new FastVector();
1323 }
1324 double [] dist = classifier.distributionForInstance(classMissing);
1325 pred = Utils.maxIndex(dist);
1326 if (dist[(int)pred] <= 0) {
1327 pred = Instance.missingValue();
1328 }
1329 updateStatsForClassifier(dist, instance);
1330 m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
1331 instance.weight()));
1332 } else {
1333 pred = classifier.classifyInstance(classMissing);
1334 updateStatsForPredictor(pred, instance);
1335 }
1336 return pred;
1337 }
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348 public double evaluateModelOnce(Classifier classifier,
1349 Instance instance) throws Exception {
1350
1351 Instance classMissing = (Instance)instance.copy();
1352 double pred = 0;
1353 classMissing.setDataset(instance.dataset());
1354 classMissing.setClassMissing();
1355 if (m_ClassIsNominal) {
1356 double [] dist = classifier.distributionForInstance(classMissing);
1357 pred = Utils.maxIndex(dist);
1358 if (dist[(int)pred] <= 0) {
1359 pred = Instance.missingValue();
1360 }
1361 updateStatsForClassifier(dist, instance);
1362 } else {
1363 pred = classifier.classifyInstance(classMissing);
1364 updateStatsForPredictor(pred, instance);
1365 }
1366 return pred;
1367 }
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378 public double evaluateModelOnce(double [] dist,
1379 Instance instance) throws Exception {
1380 double pred;
1381 if (m_ClassIsNominal) {
1382 pred = Utils.maxIndex(dist);
1383 if (dist[(int)pred] <= 0) {
1384 pred = Instance.missingValue();
1385 }
1386 updateStatsForClassifier(dist, instance);
1387 } else {
1388 pred = dist[0];
1389 updateStatsForPredictor(pred, instance);
1390 }
1391 return pred;
1392 }
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403 public double evaluateModelOnceAndRecordPrediction(double [] dist,
1404 Instance instance) throws Exception {
1405 double pred;
1406 if (m_ClassIsNominal) {
1407 if (m_Predictions == null) {
1408 m_Predictions = new FastVector();
1409 }
1410 pred = Utils.maxIndex(dist);
1411 if (dist[(int)pred] <= 0) {
1412 pred = Instance.missingValue();
1413 }
1414 updateStatsForClassifier(dist, instance);
1415 m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
1416 instance.weight()));
1417 } else {
1418 pred = dist[0];
1419 updateStatsForPredictor(pred, instance);
1420 }
1421 return pred;
1422 }
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432 public void evaluateModelOnce(double prediction,
1433 Instance instance) throws Exception {
1434
1435 if (m_ClassIsNominal) {
1436 updateStatsForClassifier(makeDistribution(prediction),
1437 instance);
1438 } else {
1439 updateStatsForPredictor(prediction, instance);
1440 }
1441 }
1442
1443
1444
1445
1446
1447
1448
1449
  /**
   * Returns the predictions recorded so far (NominalPrediction elements),
   * or null if no predictions have been recorded.
   *
   * @return the recorded predictions, or null
   */
  public FastVector predictions() {

    return m_Predictions;
  }
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465 public static String wekaStaticWrapper(Sourcable classifier, String className)
1466 throws Exception {
1467
1468 StringBuffer result = new StringBuffer();
1469 String staticClassifier = classifier.toSource(className);
1470
1471 result.append("// Generated with Weka " + Version.VERSION + "\n");
1472 result.append("//\n");
1473 result.append("// This code is public domain and comes with no warranty.\n");
1474 result.append("//\n");
1475 result.append("// Timestamp: " + new Date() + "\n");
1476 result.append("\n");
1477 result.append("package weka.classifiers;\n");
1478 result.append("\n");
1479 result.append("import weka.core.Attribute;\n");
1480 result.append("import weka.core.Capabilities;\n");
1481 result.append("import weka.core.Capabilities.Capability;\n");
1482 result.append("import weka.core.Instance;\n");
1483 result.append("import weka.core.Instances;\n");
1484 result.append("import weka.classifiers.Classifier;\n");
1485 result.append("\n");
1486 result.append("public class WekaWrapper\n");
1487 result.append(" extends Classifier {\n");
1488
1489
1490 result.append("\n");
1491 result.append(" /**\n");
1492 result.append(" * Returns only the toString() method.\n");
1493 result.append(" *\n");
1494 result.append(" * @return a string describing the classifier\n");
1495 result.append(" */\n");
1496 result.append(" public String globalInfo() {\n");
1497 result.append(" return toString();\n");
1498 result.append(" }\n");
1499
1500
1501 result.append("\n");
1502 result.append(" /**\n");
1503 result.append(" * Returns the capabilities of this classifier.\n");
1504 result.append(" *\n");
1505 result.append(" * @return the capabilities\n");
1506 result.append(" */\n");
1507 result.append(" public Capabilities getCapabilities() {\n");
1508 result.append(((Classifier) classifier).getCapabilities().toSource("result", 4));
1509 result.append(" return result;\n");
1510 result.append(" }\n");
1511
1512
1513 result.append("\n");
1514 result.append(" /**\n");
1515 result.append(" * only checks the data against its capabilities.\n");
1516 result.append(" *\n");
1517 result.append(" * @param i the training data\n");
1518 result.append(" */\n");
1519 result.append(" public void buildClassifier(Instances i) throws Exception {\n");
1520 result.append(" // can classifier handle the data?\n");
1521 result.append(" getCapabilities().testWithFail(i);\n");
1522 result.append(" }\n");
1523
1524
1525 result.append("\n");
1526 result.append(" /**\n");
1527 result.append(" * Classifies the given instance.\n");
1528 result.append(" *\n");
1529 result.append(" * @param i the instance to classify\n");
1530 result.append(" * @return the classification result\n");
1531 result.append(" */\n");
1532 result.append(" public double classifyInstance(Instance i) throws Exception {\n");
1533 result.append(" Object[] s = new Object[i.numAttributes()];\n");
1534 result.append(" \n");
1535 result.append(" for (int j = 0; j < s.length; j++) {\n");
1536 result.append(" if (!i.isMissing(j)) {\n");
1537 result.append(" if (i.attribute(j).isNominal())\n");
1538 result.append(" s[j] = new String(i.stringValue(j));\n");
1539 result.append(" else if (i.attribute(j).isNumeric())\n");
1540 result.append(" s[j] = new Double(i.value(j));\n");
1541 result.append(" }\n");
1542 result.append(" }\n");
1543 result.append(" \n");
1544 result.append(" // set class value to missing\n");
1545 result.append(" s[i.classIndex()] = null;\n");
1546 result.append(" \n");
1547 result.append(" return " + className + ".classify(s);\n");
1548 result.append(" }\n");
1549
1550
1551 result.append("\n");
1552 result.append(" /**\n");
1553 result.append(" * Returns only the classnames and what classifier it is based on.\n");
1554 result.append(" *\n");
1555 result.append(" * @return a short description\n");
1556 result.append(" */\n");
1557 result.append(" public String toString() {\n");
1558 result.append(" return \"Auto-generated classifier wrapper, based on "
1559 + classifier.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n"
1560 + "\" + this.getClass().getName() + \"/" + className + "\";\n");
1561 result.append(" }\n");
1562
1563
1564 result.append("\n");
1565 result.append(" /**\n");
1566 result.append(" * Runs the classfier from commandline.\n");
1567 result.append(" *\n");
1568 result.append(" * @param args the commandline arguments\n");
1569 result.append(" */\n");
1570 result.append(" public static void main(String args[]) {\n");
1571 result.append(" runClassifier(new WekaWrapper(), args);\n");
1572 result.append(" }\n");
1573 result.append("}\n");
1574
1575
1576 result.append("\n");
1577 result.append(staticClassifier);
1578
1579 return result.toString();
1580 }
1581
1582
1583
1584
1585
1586
1587
1588
/**
 * Gets the number of test instances that had a known class value
 * (the weighted count accumulated in m_WithClass).
 *
 * @return the (weighted) number of instances with a known class
 */
public final double numInstances() {

  return m_WithClass;
}
1593
1594
1595
1596
1597
1598
1599
1600
/**
 * Gets the (weighted) number of instances that were classified incorrectly.
 *
 * @return the (weighted) number of incorrectly classified instances
 */
public final double incorrect() {

  return m_Incorrect;
}
1605
1606
1607
1608
1609
1610
1611
1612
/**
 * Gets the percentage of instances incorrectly classified (that is, for
 * which an incorrect prediction was made), relative to all instances with
 * a known class.
 *
 * @return the percentage of incorrectly classified instances
 */
public final double pctIncorrect() {

  return 100 * m_Incorrect / m_WithClass;
}
1617
1618
1619
1620
1621
1622
1623
/**
 * Gets the total cost, that is, the cost of each prediction times the weight
 * of the instance, summed over all instances.
 *
 * @return the total cost
 */
public final double totalCost() {

  return m_TotalCost;
}
1628
1629
1630
1631
1632
1633
1634
/**
 * Gets the average cost, that is, the total cost of misclassifications
 * (incorrect plus unclassified) over the total number of instances with a
 * known class.
 *
 * @return the average cost
 */
public final double avgCost() {

  return m_TotalCost / m_WithClass;
}
1639
1640
1641
1642
1643
1644
1645
1646
/**
 * Gets the (weighted) number of instances classified correctly.
 *
 * @return the (weighted) number of correctly classified instances
 */
public final double correct() {

  return m_Correct;
}
1651
1652
1653
1654
1655
1656
1657
/**
 * Gets the percentage of instances correctly classified (that is, for
 * which a correct prediction was made), relative to all instances with a
 * known class.
 *
 * @return the percentage of correctly classified instances
 */
public final double pctCorrect() {

  return 100 * m_Correct / m_WithClass;
}
1662
1663
1664
1665
1666
1667
1668
1669
/**
 * Gets the (weighted) number of instances not classified (that is, for
 * which no prediction was made by the classifier).
 *
 * @return the (weighted) number of unclassified instances
 */
public final double unclassified() {

  return m_Unclassified;
}
1674
1675
1676
1677
1678
1679
1680
/**
 * Gets the percentage of instances not classified (that is, for which no
 * prediction was made by the classifier), relative to all instances with a
 * known class.
 *
 * @return the percentage of unclassified instances
 */
public final double pctUnclassified() {

  return 100 * m_Unclassified / m_WithClass;
}
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694 public final double errorRate() {
1695
1696 if (!m_ClassIsNominal) {
1697 return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
1698 }
1699 if (m_CostMatrix == null) {
1700 return m_Incorrect / m_WithClass;
1701 } else {
1702 return avgCost();
1703 }
1704 }
1705
1706
1707
1708
1709
1710
1711 public final double kappa() {
1712
1713
1714 double[] sumRows = new double[m_ConfusionMatrix.length];
1715 double[] sumColumns = new double[m_ConfusionMatrix.length];
1716 double sumOfWeights = 0;
1717 for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1718 for (int j = 0; j < m_ConfusionMatrix.length; j++) {
1719 sumRows[i] += m_ConfusionMatrix[i][j];
1720 sumColumns[j] += m_ConfusionMatrix[i][j];
1721 sumOfWeights += m_ConfusionMatrix[i][j];
1722 }
1723 }
1724 double correct = 0, chanceAgreement = 0;
1725 for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1726 chanceAgreement += (sumRows[i] * sumColumns[i]);
1727 correct += m_ConfusionMatrix[i][i];
1728 }
1729 chanceAgreement /= (sumOfWeights * sumOfWeights);
1730 correct /= sumOfWeights;
1731
1732 if (chanceAgreement < 1) {
1733 return (correct - chanceAgreement) / (1 - chanceAgreement);
1734 } else {
1735 return 1;
1736 }
1737 }
1738
1739
1740
1741
1742
1743
1744
/**
 * Returns the correlation coefficient between the actual and predicted
 * class values; only defined for a numeric class.
 *
 * @return the correlation coefficient
 * @throws Exception if the class is nominal
 */
public final double correlationCoefficient() throws Exception {

  if (m_ClassIsNominal) {
    throw
      new Exception("Can't compute correlation coefficient: " +
                    "class is nominal!");
  }

  double correlation = 0;
  // (co)variances scaled by n (the shared factor cancels in the ratio below);
  // n excludes unclassified instances
  double varActual =
    m_SumSqrClass - m_SumClass * m_SumClass /
    (m_WithClass - m_Unclassified);
  double varPredicted =
    m_SumSqrPredicted - m_SumPredicted * m_SumPredicted /
    (m_WithClass - m_Unclassified);
  double varProd =
    m_SumClassPredicted - m_SumClass * m_SumPredicted /
    (m_WithClass - m_Unclassified);

  // guard against zero (or numerically negative) variance
  if (varActual * varPredicted <= 0) {
    correlation = 0.0;
  } else {
    correlation = varProd / Math.sqrt(varActual * varPredicted);
  }

  return correlation;
}
1772
1773
1774
1775
1776
1777
1778
1779
/**
 * Returns the mean absolute error. Refers to the error of the predicted
 * values for numeric classes, and the error of the predicted probability
 * distribution for nominal classes. Unclassified instances are excluded.
 *
 * @return the mean absolute error
 */
public final double meanAbsoluteError() {

  return m_SumAbsErr / (m_WithClass - m_Unclassified);
}
1784
1785
1786
1787
1788
1789
/**
 * Returns the mean absolute error of the prior (i.e. of always predicting
 * the class prior).
 *
 * @return the mean absolute error of the prior, or NaN if priors are
 *         disabled
 */
public final double meanPriorAbsoluteError() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumPriorAbsErr / m_WithClass;
}
1797
1798
1799
1800
1801
1802
1803
/**
 * Returns the relative absolute error: the mean absolute error as a
 * percentage of the prior's mean absolute error.
 *
 * @return the relative absolute error, or NaN if priors are disabled
 * @throws Exception declared for interface compatibility
 */
public final double relativeAbsoluteError() throws Exception {

  if (m_NoPriors)
    return Double.NaN;

  return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
}
1811
1812
1813
1814
1815
1816
/**
 * Returns the root mean squared error, excluding unclassified instances.
 *
 * @return the root mean squared error
 */
public final double rootMeanSquaredError() {

  return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
}
1821
1822
1823
1824
1825
1826
/**
 * Returns the root mean prior squared error (i.e. of always predicting
 * the class prior).
 *
 * @return the root mean prior squared error, or NaN if priors are disabled
 */
public final double rootMeanPriorSquaredError() {

  if (m_NoPriors)
    return Double.NaN;

  return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
}
1834
1835
1836
1837
1838
1839
/**
 * Returns the root relative squared error: the root mean squared error as
 * a percentage of the prior's root mean squared error.
 *
 * @return the root relative squared error, or NaN if priors are disabled
 */
public final double rootRelativeSquaredError() {

  if (m_NoPriors)
    return Double.NaN;

  return 100.0 * rootMeanSquaredError() /
    rootMeanPriorSquaredError();
}
1848
1849
1850
1851
1852
1853
1854
1855 public final double priorEntropy() throws Exception {
1856
1857 if (!m_ClassIsNominal) {
1858 throw
1859 new Exception("Can't compute entropy of class prior: " +
1860 "class numeric!");
1861 }
1862
1863 if (m_NoPriors)
1864 return Double.NaN;
1865
1866 double entropy = 0;
1867 for(int i = 0; i < m_NumClasses; i++) {
1868 entropy -= m_ClassPriors[i] / m_ClassPriorsSum
1869 * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
1870 }
1871 return entropy;
1872 }
1873
1874
1875
1876
1877
1878
1879
1880 public final double KBInformation() throws Exception {
1881
1882 if (!m_ClassIsNominal) {
1883 throw
1884 new Exception("Can't compute K&B Info score: " +
1885 "class numeric!");
1886 }
1887
1888 if (m_NoPriors)
1889 return Double.NaN;
1890
1891 return m_SumKBInfo;
1892 }
1893
1894
1895
1896
1897
1898
1899
1900
1901 public final double KBMeanInformation() throws Exception {
1902
1903 if (!m_ClassIsNominal) {
1904 throw
1905 new Exception("Can't compute K&B Info score: "
1906 + "class numeric!");
1907 }
1908
1909 if (m_NoPriors)
1910 return Double.NaN;
1911
1912 return m_SumKBInfo / (m_WithClass - m_Unclassified);
1913 }
1914
1915
1916
1917
1918
1919
1920
1921 public final double KBRelativeInformation() throws Exception {
1922
1923 if (!m_ClassIsNominal) {
1924 throw
1925 new Exception("Can't compute K&B Info score: " +
1926 "class numeric!");
1927 }
1928
1929 if (m_NoPriors)
1930 return Double.NaN;
1931
1932 return 100.0 * KBInformation() / priorEntropy();
1933 }
1934
1935
1936
1937
1938
1939
/**
 * Returns the total entropy of the prior predictions, in bits.
 *
 * @return the total prior entropy, or NaN if priors are disabled
 */
public final double SFPriorEntropy() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumPriorEntropy;
}
1947
1948
1949
1950
1951
1952
/**
 * Returns the entropy of the prior predictions per instance, in bits per
 * instance.
 *
 * @return the mean prior entropy, or NaN if priors are disabled
 */
public final double SFMeanPriorEntropy() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumPriorEntropy / m_WithClass;
}
1960
1961
1962
1963
1964
1965
/**
 * Returns the total entropy of the scheme's predictions, in bits.
 *
 * @return the total scheme entropy, or NaN if priors are disabled
 */
public final double SFSchemeEntropy() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumSchemeEntropy;
}
1973
1974
1975
1976
1977
1978
/**
 * Returns the entropy of the scheme's predictions per classified instance,
 * in bits per instance.
 *
 * @return the mean scheme entropy, or NaN if priors are disabled
 */
public final double SFMeanSchemeEntropy() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumSchemeEntropy / (m_WithClass - m_Unclassified);
}
1986
1987
1988
1989
1990
1991
1992
/**
 * Returns the total SF entropy gain: prior entropy minus scheme entropy,
 * in bits.
 *
 * @return the total SF entropy gain, or NaN if priors are disabled
 */
public final double SFEntropyGain() {

  if (m_NoPriors)
    return Double.NaN;

  return m_SumPriorEntropy - m_SumSchemeEntropy;
}
2000
2001
2002
2003
2004
2005
2006
/**
 * Returns the SF entropy gain per classified instance, in bits per
 * instance.
 *
 * @return the mean SF entropy gain, or NaN if priors are disabled
 */
public final double SFMeanEntropyGain() {

  if (m_NoPriors)
    return Double.NaN;

  return (m_SumPriorEntropy - m_SumSchemeEntropy) /
    (m_WithClass - m_Unclassified);
}
2015
2016
2017
2018
2019
2020
2021
2022
/**
 * Outputs the cumulative margin distribution as a string of margin/count
 * pairs (one per line), suitable for input to gnuplot or similar.
 *
 * @return the cumulative margin distribution
 * @throws Exception if the class attribute is not nominal
 */
public String toCumulativeMarginDistributionString() throws Exception {

  if (!m_ClassIsNominal) {
    throw new Exception("Class must be nominal for margin distributions");
  }
  String result = "";
  double cumulativeCount = 0;
  double margin;
  for(int i = 0; i <= k_MarginResolution; i++) {
    if (m_MarginCounts[i] != 0) {
      cumulativeCount += m_MarginCounts[i];
      // map bin index i back to a margin value in [-1, 1]
      margin = (double)i * 2.0 / k_MarginResolution - 1.0;
      result = result + Utils.doubleToString(margin, 7, 3) + ' '
        + Utils.doubleToString(cumulativeCount * 100
                               / m_WithClass, 7, 3) + '\n';
    } else if (i == 0) {
      // seed the output with the (-1, 0) point when the first bin is empty;
      // plain assignment is fine here since result is still empty at i == 0
      result = Utils.doubleToString(-1.0, 7, 3) + ' '
        + Utils.doubleToString(0, 7, 3) + '\n';
    }
  }
  return result;
}
2045
2046
2047
2048
2049
2050
2051
/**
 * Calls toSummaryString() with no title and no complexity statistics.
 *
 * @return a summary description of the classifier evaluation
 */
public String toSummaryString() {

  return toSummaryString("", false);
}
2056
2057
2058
2059
2060
2061
2062
2063
/**
 * Calls toSummaryString() with a default title.
 *
 * @param printComplexityStatistics if true, complexity statistics are
 *                                  returned as well
 * @return a summary description of the classifier evaluation
 */
public String toSummaryString(boolean printComplexityStatistics) {

  return toSummaryString("=== Summary ===\n", printComplexityStatistics);
}
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
/**
 * Outputs the performance statistics in summary form. Lists number (and
 * percentage) of instances classified correctly, incorrectly and
 * unclassified, plus error measures; complexity statistics (K&B info,
 * class complexity, SF gain) are optional and require priors.
 *
 * @param title the title for the statistics block
 * @param printComplexityStatistics if true, complexity statistics are
 *                                  returned as well (silently disabled when
 *                                  priors are disabled)
 * @return the summary as a String
 */
public String toSummaryString(String title,
                              boolean printComplexityStatistics) {

  StringBuffer text = new StringBuffer();

  // complexity statistics need class priors; disable them if priors are off
  if (printComplexityStatistics && m_NoPriors) {
    printComplexityStatistics = false;
    System.err.println("Priors disabled, cannot print complexity statistics!");
  }

  text.append(title + "\n");
  try {
    if (m_WithClass > 0) {
      if (m_ClassIsNominal) {

        text.append("Correctly Classified Instances ");
        text.append(Utils.doubleToString(correct(), 12, 4) + " " +
                    Utils.doubleToString(pctCorrect(),
                                         12, 4) + " %\n");
        text.append("Incorrectly Classified Instances ");
        text.append(Utils.doubleToString(incorrect(), 12, 4) + " " +
                    Utils.doubleToString(pctIncorrect(),
                                         12, 4) + " %\n");
        text.append("Kappa statistic ");
        text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");

        // cost-sensitive figures only make sense with a cost matrix
        if (m_CostMatrix != null) {
          text.append("Total Cost ");
          text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
          text.append("Average Cost ");
          text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
        }
        if (printComplexityStatistics) {
          text.append("K&B Relative Info Score ");
          text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4)
                      + " %\n");
          text.append("K&B Information Score ");
          text.append(Utils.doubleToString(KBInformation(), 12, 4)
                      + " bits");
          text.append(Utils.doubleToString(KBMeanInformation(), 12, 4)
                      + " bits/instance\n");
        }
      } else {
        // numeric class: correlation instead of accuracy
        text.append("Correlation coefficient ");
        text.append(Utils.doubleToString(correlationCoefficient(), 12 , 4) +
                    "\n");
      }
      if (printComplexityStatistics) {
        text.append("Class complexity | order 0 ");
        text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4)
                    + " bits");
        text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4)
                    + " bits/instance\n");
        text.append("Class complexity | scheme ");
        text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4)
                    + " bits");
        text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4)
                    + " bits/instance\n");
        text.append("Complexity improvement (Sf) ");
        text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
        text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4)
                    + " bits/instance\n");
      }

      text.append("Mean absolute error ");
      text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4)
                  + "\n");
      text.append("Root mean squared error ");
      text.append(Utils.
                  doubleToString(rootMeanSquaredError(), 12, 4)
                  + "\n");
      if (!m_NoPriors) {
        text.append("Relative absolute error ");
        text.append(Utils.doubleToString(relativeAbsoluteError(),
                                         12, 4) + " %\n");
        text.append("Root relative squared error ");
        text.append(Utils.doubleToString(rootRelativeSquaredError(),
                                         12, 4) + " %\n");
      }
    }
    if (Utils.gr(unclassified(), 0)) {
      text.append("UnClassified Instances ");
      text.append(Utils.doubleToString(unclassified(), 12,4) + " " +
                  Utils.doubleToString(pctUnclassified(),
                                       12, 4) + " %\n");
    }
    text.append("Total Number of Instances ");
    text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
    if (m_MissingClass > 0) {
      text.append("Ignored Class Unknown Instances ");
      text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
    }
  } catch (Exception ex) {
    // should never happen: all the stat methods above are known-safe once
    // m_WithClass > 0, so reaching here indicates an internal bug
    System.err.println("Arggh - Must be a bug in EvaluationACO class");
  }

  return text.toString();
}
2181
2182
2183
2184
2185
2186
2187
/**
 * Calls toMatrixString() with a default title.
 *
 * @return the confusion matrix as a string
 * @throws Exception if the class is numeric
 */
public String toMatrixString() throws Exception {

  return toMatrixString("=== Confusion Matrix ===\n");
}
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
/**
 * Outputs the performance statistics as a classification confusion matrix.
 * For each class value shows the distribution of predicted class values,
 * each class labeled with a short letter ID.
 *
 * @param title the title for the confusion matrix
 * @return the confusion matrix as a String
 * @throws Exception if the class is numeric (no confusion matrix exists)
 */
public String toMatrixString(String title) throws Exception {

  StringBuffer text = new StringBuffer();
  char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
                     'k','l','m','n','o','p','q','r','s','t',
                     'u','v','w','x','y','z'};
  int IDWidth;
  boolean fractional = false;

  if (!m_ClassIsNominal) {
    throw new Exception("EvaluationACO: No confusion matrix possible!");
  }

  // Find the maximum cell magnitude and whether any cell has a significant
  // fractional part (weighted instances), to size the output columns.
  double maxval = 0;
  for(int i = 0; i < m_NumClasses; i++) {
    for(int j = 0; j < m_NumClasses; j++) {
      double current = m_ConfusionMatrix[i][j];
      if (current < 0) {
        current *= -10;
      }
      if (current > maxval) {
        maxval = current;
      }
      // a fractional part >= 0.01 forces fractional formatting
      double fract = current - Math.rint(current);
      if (!fractional
          && ((Math.log(fract) / Math.log(10)) >= -2)) {
        fractional = true;
      }
    }
  }

  // column width: enough digits for the largest value (plus ".xx" when
  // fractional) or for the widest letter ID, whichever is larger
  IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)
                               + (fractional ? 3 : 0)),
                         (int)(Math.log(m_NumClasses) /
                               Math.log(IDChars.length)));
  text.append(title).append("\n");
  for(int i = 0; i < m_NumClasses; i++) {
    if (fractional) {
      text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
        .append(" ");
    } else {
      text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
    }
  }
  text.append(" <-- classified as\n");
  for(int i = 0; i< m_NumClasses; i++) {
    for(int j = 0; j < m_NumClasses; j++) {
      text.append(" ").append(
        Utils.doubleToString(m_ConfusionMatrix[i][j],
                             IDWidth,
                             (fractional ? 2 : 0)));
    }
    text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
      .append(" = ").append(m_ClassNames[i]).append("\n");
  }
  return text.toString();
}
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
/**
 * Calls toClassDetailsString() with a default title.
 *
 * @return the per-class accuracy statistics as a string
 * @throws Exception if the class is numeric
 */
public String toClassDetailsString() throws Exception {

  return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
}
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
/**
 * Generates a breakdown of the accuracy for each class: TP/FP rate,
 * precision, recall, F-measure and ROC area, one row per class.
 *
 * @param title the title for the statistics block
 * @return the per-class statistics as a String
 * @throws Exception if the class is numeric (statistics are undefined)
 */
public String toClassDetailsString(String title) throws Exception {

  if (!m_ClassIsNominal) {
    throw new Exception("EvaluationACO: No confusion matrix possible!");
  }
  StringBuffer text = new StringBuffer(title
                                       + "\nTP Rate FP Rate"
                                       + " Precision Recall"
                                       + " F-Measure ROC Area Class\n");
  for(int i = 0; i < m_NumClasses; i++) {
    text.append(Utils.doubleToString(truePositiveRate(i), 7, 3))
      .append(" ");
    text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
      .append(" ");
    text.append(Utils.doubleToString(precision(i), 7, 3))
      .append(" ");
    text.append(Utils.doubleToString(recall(i), 7, 3))
      .append(" ");
    text.append(Utils.doubleToString(fMeasure(i), 7, 3))
      .append(" ");
    // ROC area can be undefined (missing) when no predictions were recorded
    double rocVal = areaUnderROC(i);
    if (Instance.isMissingValue(rocVal)) {
      text.append(" ? ")
        .append(" ");
    } else {
      text.append(Utils.doubleToString(rocVal, 7, 3))
        .append(" ");
    }
    text.append(m_ClassNames[i]).append('\n');
  }
  return text.toString();
}
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329 public double numTruePositives(int classIndex) {
2330
2331 double correct = 0;
2332 for (int j = 0; j < m_NumClasses; j++) {
2333 if (j == classIndex) {
2334 correct += m_ConfusionMatrix[classIndex][j];
2335 }
2336 }
2337 return correct;
2338 }
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352 public double truePositiveRate(int classIndex) {
2353
2354 double correct = 0, total = 0;
2355 for (int j = 0; j < m_NumClasses; j++) {
2356 if (j == classIndex) {
2357 correct += m_ConfusionMatrix[classIndex][j];
2358 }
2359 total += m_ConfusionMatrix[classIndex][j];
2360 }
2361 if (total == 0) {
2362 return 0;
2363 }
2364 return correct / total;
2365 }
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377 public double numTrueNegatives(int classIndex) {
2378
2379 double correct = 0;
2380 for (int i = 0; i < m_NumClasses; i++) {
2381 if (i != classIndex) {
2382 for (int j = 0; j < m_NumClasses; j++) {
2383 if (j != classIndex) {
2384 correct += m_ConfusionMatrix[i][j];
2385 }
2386 }
2387 }
2388 }
2389 return correct;
2390 }
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404 public double trueNegativeRate(int classIndex) {
2405
2406 double correct = 0, total = 0;
2407 for (int i = 0; i < m_NumClasses; i++) {
2408 if (i != classIndex) {
2409 for (int j = 0; j < m_NumClasses; j++) {
2410 if (j != classIndex) {
2411 correct += m_ConfusionMatrix[i][j];
2412 }
2413 total += m_ConfusionMatrix[i][j];
2414 }
2415 }
2416 }
2417 if (total == 0) {
2418 return 0;
2419 }
2420 return correct / total;
2421 }
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433 public double numFalsePositives(int classIndex) {
2434
2435 double incorrect = 0;
2436 for (int i = 0; i < m_NumClasses; i++) {
2437 if (i != classIndex) {
2438 for (int j = 0; j < m_NumClasses; j++) {
2439 if (j == classIndex) {
2440 incorrect += m_ConfusionMatrix[i][j];
2441 }
2442 }
2443 }
2444 }
2445 return incorrect;
2446 }
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460 public double falsePositiveRate(int classIndex) {
2461
2462 double incorrect = 0, total = 0;
2463 for (int i = 0; i < m_NumClasses; i++) {
2464 if (i != classIndex) {
2465 for (int j = 0; j < m_NumClasses; j++) {
2466 if (j == classIndex) {
2467 incorrect += m_ConfusionMatrix[i][j];
2468 }
2469 total += m_ConfusionMatrix[i][j];
2470 }
2471 }
2472 }
2473 if (total == 0) {
2474 return 0;
2475 }
2476 return incorrect / total;
2477 }
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489 public double numFalseNegatives(int classIndex) {
2490
2491 double incorrect = 0;
2492 for (int i = 0; i < m_NumClasses; i++) {
2493 if (i == classIndex) {
2494 for (int j = 0; j < m_NumClasses; j++) {
2495 if (j != classIndex) {
2496 incorrect += m_ConfusionMatrix[i][j];
2497 }
2498 }
2499 }
2500 }
2501 return incorrect;
2502 }
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516 public double falseNegativeRate(int classIndex) {
2517
2518 double incorrect = 0, total = 0;
2519 for (int i = 0; i < m_NumClasses; i++) {
2520 if (i == classIndex) {
2521 for (int j = 0; j < m_NumClasses; j++) {
2522 if (j != classIndex) {
2523 incorrect += m_ConfusionMatrix[i][j];
2524 }
2525 total += m_ConfusionMatrix[i][j];
2526 }
2527 }
2528 }
2529 if (total == 0) {
2530 return 0;
2531 }
2532 return incorrect / total;
2533 }
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
/**
 * Calculates the recall with respect to a particular class; this is
 * identical to the true positive rate.
 *
 * @param classIndex the index of the class to consider as "positive"
 * @return the recall
 */
public double recall(int classIndex) {

  return truePositiveRate(classIndex);
}
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565 public double precision(int classIndex) {
2566
2567 double correct = 0, total = 0;
2568 for (int i = 0; i < m_NumClasses; i++) {
2569 if (i == classIndex) {
2570 correct += m_ConfusionMatrix[i][classIndex];
2571 }
2572 total += m_ConfusionMatrix[i][classIndex];
2573 }
2574 if (total == 0) {
2575 return 0;
2576 }
2577 return correct / total;
2578 }
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592 public double fMeasure(int classIndex) {
2593
2594 double precision = precision(classIndex);
2595 double recall = recall(classIndex);
2596 if ((precision + recall) == 0) {
2597 return 0;
2598 }
2599 return 2 * precision * recall / (precision + recall);
2600 }
2601
2602
2603
2604
2605
2606
2607
2608
2609
/**
 * Sets the class prior probabilities from the supplied training data and
 * re-enables priors.
 *
 * For a nominal class each prior starts at 1 (Laplace correction) and is
 * incremented by the weight of each training instance of that class. For a
 * numeric class the training class values/weights are collected via
 * addNumericTrainClass and any previous estimators are reset.
 *
 * @param train the training data used to estimate the priors
 * @throws Exception if the class attribute of the instances is not set
 */
public void setPriors(Instances train) throws Exception {
  m_NoPriors = false;

  if (!m_ClassIsNominal) {

    // numeric class: discard previous state and re-collect class values
    m_NumTrainClassVals = 0;
    m_TrainClassVals = null;
    m_TrainClassWeights = null;
    m_PriorErrorEstimator = null;
    m_ErrorEstimator = null;

    for (int i = 0; i < train.numInstances(); i++) {
      Instance currentInst = train.instance(i);
      if (!currentInst.classIsMissing()) {
        addNumericTrainClass(currentInst.classValue(),
                             currentInst.weight());
      }
    }

  } else {
    // nominal class: Laplace-corrected weighted class counts
    for (int i = 0; i < m_NumClasses; i++) {
      m_ClassPriors[i] = 1;
    }
    m_ClassPriorsSum = m_NumClasses;
    for (int i = 0; i < train.numInstances(); i++) {
      if (!train.instance(i).classIsMissing()) {
        m_ClassPriors[(int)train.instance(i).classValue()] +=
          train.instance(i).weight();
        m_ClassPriorsSum += train.instance(i).weight();
      }
    }
  }
}
2643
2644
2645
2646
2647
2648
/**
 * Gets the current weighted class counts used as priors.
 *
 * NOTE(review): this returns the internal array itself, not a copy, so
 * callers can mutate the evaluation's priors — confirm whether that is
 * intended before changing it.
 *
 * @return the weighted class counts
 */
public double [] getClassPriors() {
  return m_ClassPriors;
}
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661 public void updatePriors(Instance instance) throws Exception {
2662 if (!instance.classIsMissing()) {
2663 if (!m_ClassIsNominal) {
2664 if (!instance.classIsMissing()) {
2665 addNumericTrainClass(instance.classValue(),
2666 instance.weight());
2667 }
2668 } else {
2669 m_ClassPriors[(int)instance.classValue()] +=
2670 instance.weight();
2671 m_ClassPriorsSum += instance.weight();
2672 }
2673 }
2674 }
2675
2676
2677
2678
2679
2680
/**
 * Disables the use of priors. Statistics that depend on priors (e.g. the
 * relative error measures and complexity statistics) will return NaN.
 */
public void useNoPriors() {
  m_NoPriors = true;
}
2684
2685
2686
2687
2688
2689
2690
2691
/**
 * Tests whether the current evaluation object is equal to another one:
 * same class-type flags, same counts, same error sums and (for a nominal
 * class) the same confusion matrix.
 *
 * NOTE(review): equals is overridden here but no matching hashCode
 * override is visible in this chunk — confirm one exists elsewhere in the
 * class, or objects of this class will misbehave in hash-based
 * collections.
 *
 * @param obj the object to compare against
 * @return true if the two objects are equal
 */
public boolean equals(Object obj) {

  if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
    return false;
  }
  EvaluationACO cmp = (EvaluationACO) obj;
  if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
  if (m_NumClasses != cmp.m_NumClasses) return false;

  if (m_Incorrect != cmp.m_Incorrect) return false;
  if (m_Correct != cmp.m_Correct) return false;
  if (m_Unclassified != cmp.m_Unclassified) return false;
  if (m_MissingClass != cmp.m_MissingClass) return false;
  if (m_WithClass != cmp.m_WithClass) return false;

  if (m_SumErr != cmp.m_SumErr) return false;
  if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
  if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
  if (m_SumClass != cmp.m_SumClass) return false;
  if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
  if (m_SumPredicted != cmp.m_SumPredicted) return false;
  if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
  if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;

  // confusion matrix only exists for a nominal class
  if (m_ClassIsNominal) {
    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
          return false;
        }
      }
    }
  }

  return true;
}
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
/**
 * Prints the predictions for the given dataset into a String variable;
 * convenience overload that does not print the class distributions.
 *
 * @param classifier the classifier to use
 * @param train the training data
 * @param testSource the test set
 * @param classIndex the class index (1-based, -1 for last)
 * @param attributesToOutput the attributes to output, or null for none
 * @return the generated predictions for the attribute range
 * @throws Exception if test file cannot be opened
 */
protected static String printClassifications(Classifier classifier,
                                             Instances train,
                                             DataSource testSource,
                                             int classIndex,
                                             Range attributesToOutput) throws Exception {

  return printClassifications(
    classifier, train, testSource, classIndex, attributesToOutput, false);
}
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
/**
 * Prints the classifier's predictions for the given test source into a
 * String: a header line followed by one line per test instance (via
 * predictionText), optionally including the predicted class distribution
 * and a selected range of attribute values.
 *
 * @param classifier the (already built) classifier to use
 * @param train the training data (unused here beyond the signature —
 *              presumably kept for interface consistency)
 * @param testSource the test set source
 * @param classIndex the class index (1-based), or -1 to use the test set's
 *                   class index / the last attribute
 * @param attributesToOutput the attributes to output, or null for none
 * @param printDistribution whether to print the full class distribution
 * @return the generated predictions
 * @throws Exception if the test data cannot be read
 */
protected static String printClassifications(Classifier classifier,
                                             Instances train,
                                             DataSource testSource,
                                             int classIndex,
                                             Range attributesToOutput,
                                             boolean printDistribution) throws Exception {

  StringBuffer text = new StringBuffer();
  if (testSource != null) {
    Instances test = testSource.getStructure();
    if (classIndex != -1) {
      // classIndex is 1-based on the command line, 0-based internally
      test.setClassIndex(classIndex - 1);
    } else {
      if (test.classIndex() == -1)
        test.setClassIndex(test.numAttributes() - 1);
    }

    // header line, depending on class type and distribution flag
    if (test.classAttribute().isNominal())
      if (printDistribution)
        text.append(" inst# actual predicted error distribution");
      else
        text.append(" inst# actual predicted error prediction");
    else
      text.append(" inst# actual predicted error");
    if (attributesToOutput != null) {
      attributesToOutput.setUpper(test.numAttributes() - 1);
      text.append(" (");
      boolean first = true;
      for (int i = 0; i < test.numAttributes(); i++) {
        if (i == test.classIndex())
          continue;

        if (attributesToOutput.isInRange(i)) {
          if (!first)
            text.append(",");
          text.append(test.attribute(i).name());
          first = false;
        }
      }
      text.append(")");
    }
    text.append("\n");

    // iterate over the test instances incrementally
    int i = 0;
    testSource.reset();
    test = testSource.getStructure(test.classIndex());
    while (testSource.hasMoreElements(test)) {
      Instance inst = testSource.nextElement(test);
      text.append(
        predictionText(
          classifier, inst, i, attributesToOutput, printDistribution));
      i++;
    }
  }
  return text.toString();
}
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
/**
 * Formats one prediction line for printClassifications: instance number,
 * actual value, predicted value, error/error-flag, optionally the class
 * distribution, and optionally a range of attribute values.
 *
 * @param classifier the classifier to make the prediction with
 * @param inst the instance to predict
 * @param instNum the 0-based index of the instance (printed 1-based)
 * @param attributesToOutput the attributes to append, or null for none
 * @param printDistribution whether to print the full class distribution
 * @return the formatted prediction line (newline-terminated)
 * @throws Exception if the prediction cannot be made
 */
protected static String predictionText(Classifier classifier,
                                       Instance inst,
                                       int instNum,
                                       Range attributesToOutput,
                                       boolean printDistribution)
  throws Exception {

  StringBuffer result = new StringBuffer();
  int width = 10;
  int prec = 3;

  // predict on a copy so the original instance is untouched
  Instance withMissing = (Instance)inst.copy();
  withMissing.setDataset(inst.dataset());
  double predValue = ((Classifier)classifier).classifyInstance(withMissing);

  // 1-based instance number
  result.append(Utils.padLeft("" + (instNum+1), 6));

  if (inst.dataset().classAttribute().isNumeric()) {
    // numeric class: actual, predicted, signed error ("?" for missing)
    if (inst.classIsMissing())
      result.append(" " + Utils.padLeft("?", width));
    else
      result.append(" " + Utils.doubleToString(inst.classValue(), width, prec));

    if (Instance.isMissingValue(predValue))
      result.append(" " + Utils.padLeft("?", width));
    else
      result.append(" " + Utils.doubleToString(predValue, width, prec));

    if (Instance.isMissingValue(predValue) || inst.classIsMissing())
      result.append(" " + Utils.padLeft("?", width));
    else
      result.append(" " + Utils.doubleToString(predValue - inst.classValue(), width, prec));
  } else {
    // nominal class: actual and predicted as "index:label"
    result.append(" " + Utils.padLeft(((int) inst.classValue()+1) + ":" + inst.toString(inst.classIndex()), width));

    if (Instance.isMissingValue(predValue))
      result.append(" " + Utils.padLeft("?", width));
    else
      result.append(" " + Utils.padLeft(((int) predValue+1) + ":" + inst.dataset().classAttribute().value((int)predValue), width));

    // "+" marks a misclassification
    if ((int) predValue+1 != (int) inst.classValue()+1)
      result.append(" " + " + ");
    else
      result.append(" " + " ");

    if (printDistribution) {
      if (Instance.isMissingValue(predValue)) {
        result.append(" " + "?");
      }
      else {
        result.append(" ");
        // full distribution, "*" marking the predicted class
        double[] dist = classifier.distributionForInstance(withMissing);
        for (int n = 0; n < dist.length; n++) {
          if (n > 0)
            result.append(",");
          if (n == (int) predValue)
            result.append("*");
          result.append(Utils.doubleToString(dist[n], prec));
        }
      }
    }
    else {
      // only the probability of the predicted class
      if (Instance.isMissingValue(predValue))
        result.append(" " + "?");
      else
        result.append(" " + Utils.doubleToString(classifier.distributionForInstance(withMissing) [(int)predValue], prec));
    }
  }

  // optional attribute values
  result.append(" " + attributeValuesString(withMissing, attributesToOutput) + "\n");

  return result.toString();
}
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925 protected static String attributeValuesString(Instance instance, Range attRange) {
2926 StringBuffer text = new StringBuffer();
2927 if (attRange != null) {
2928 boolean firstOutput = true;
2929 attRange.setUpper(instance.numAttributes() - 1);
2930 for (int i=0; i<instance.numAttributes(); i++)
2931 if (attRange.isInRange(i) && i != instance.classIndex()) {
2932 if (firstOutput) text.append("(");
2933 else text.append(",");
2934 text.append(instance.toString(i));
2935 firstOutput = false;
2936 }
2937 if (!firstOutput) text.append(")");
2938 }
2939 return text.toString();
2940 }
2941
2942
2943
2944
2945
2946
2947
2948 protected static String makeOptionString(Classifier classifier) {
2949
2950 StringBuffer optionsText = new StringBuffer("");
2951
2952
2953 optionsText.append("\n\nGeneral options:\n\n");
2954 optionsText.append("-t <name of training file>\n");
2955 optionsText.append("\tSets training file.\n");
2956 optionsText.append("-T <name of test file>\n");
2957 optionsText.append("\tSets test file. If missing, a cross-validation will be performed\n");
2958 optionsText.append("\ton the training data.\n");
2959 optionsText.append("-c <class index>\n");
2960 optionsText.append("\tSets index of class attribute (default: last).\n");
2961 optionsText.append("-x <number of folds>\n");
2962 optionsText.append("\tSets number of folds for cross-validation (default: 10).\n");
2963 optionsText.append("-no-cv\n");
2964 optionsText.append("\tDo not perform any cross validation.\n");
2965 optionsText.append("-split-percentage <percentage>\n");
2966 optionsText.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
2967 optionsText.append("-preserve-order\n");
2968 optionsText.append("\tPreserves the order in the percentage split.\n");
2969 optionsText.append("-s <random number seed>\n");
2970 optionsText.append("\tSets random number seed for cross-validation or percentage split\n");
2971 optionsText.append("\t(default: 1).\n");
2972 optionsText.append("-m <name of file with cost matrix>\n");
2973 optionsText.append("\tSets file with cost matrix.\n");
2974 optionsText.append("-l <name of input file>\n");
2975 optionsText.append("\tSets model input file. In case the filename ends with '.xml',\n");
2976 optionsText.append("\tthe options are loaded from the XML file.\n");
2977 optionsText.append("-d <name of output file>\n");
2978 optionsText.append("\tSets model output file. In case the filename ends with '.xml',\n");
2979 optionsText.append("\tonly the options are saved to the XML file, not the model.\n");
2980 optionsText.append("-v\n");
2981 optionsText.append("\tOutputs no statistics for training data.\n");
2982 optionsText.append("-o\n");
2983 optionsText.append("\tOutputs statistics only, not the classifier.\n");
2984 optionsText.append("-i\n");
2985 optionsText.append("\tOutputs detailed information-retrieval");
2986 optionsText.append(" statistics for each class.\n");
2987 optionsText.append("-k\n");
2988 optionsText.append("\tOutputs information-theoretic statistics.\n");
2989 optionsText.append("-p <attribute range>\n");
2990 optionsText.append("\tOnly outputs predictions for test instances (or the train\n"
2991 + "\tinstances if no test instances provided), along with attributes\n"
2992 + "\t(0 for none).\n");
2993 optionsText.append("-distribution\n");
2994 optionsText.append("\tOutputs the distribution instead of only the prediction\n");
2995 optionsText.append("\tin conjunction with the '-p' option (only nominal classes).\n");
2996 optionsText.append("-r\n");
2997 optionsText.append("\tOnly outputs cumulative margin distribution.\n");
2998 if (classifier instanceof Sourcable) {
2999 optionsText.append("-z <class name>\n");
3000 optionsText.append("\tOnly outputs the source representation"
3001 + " of the classifier,\n\tgiving it the supplied"
3002 + " name.\n");
3003 }
3004 if (classifier instanceof Drawable) {
3005 optionsText.append("-g\n");
3006 optionsText.append("\tOnly outputs the graph representation"
3007 + " of the classifier.\n");
3008 }
3009 optionsText.append("-xml filename | xml-string\n");
3010 optionsText.append("\tRetrieves the options from the XML-data instead of the "
3011 + "command line.\n");
3012 optionsText.append("-threshold-file <file>\n");
3013 optionsText.append("\tThe file to save the threshold data to.\n"
3014 + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
3015 + "\tformat or '.csv' for CSV.\n");
3016 optionsText.append("-threshold-label <label>\n");
3017 optionsText.append("\tThe class label to determine the threshold data for\n"
3018 + "\t(default is the first label)\n");
3019
3020
3021 if (classifier instanceof OptionHandler) {
3022 optionsText.append("\nOptions specific to "
3023 + classifier.getClass().getName()
3024 + ":\n\n");
3025 Enumeration enu = ((OptionHandler)classifier).listOptions();
3026 while (enu.hasMoreElements()) {
3027 Option option = (Option) enu.nextElement();
3028 optionsText.append(option.synopsis() + '\n');
3029 optionsText.append(option.description() + "\n");
3030 }
3031 }
3032 return optionsText.toString();
3033 }
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043 protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
3044
3045 char ID [] = new char [IDWidth];
3046 int i;
3047
3048 for(i = IDWidth - 1; i >=0; i--) {
3049 ID[i] = IDChars[num % IDChars.length];
3050 num = num / IDChars.length - 1;
3051 if (num < 0) {
3052 break;
3053 }
3054 }
3055 for(i--; i >= 0; i--) {
3056 ID[i] = ' ';
3057 }
3058
3059 return new String(ID);
3060 }
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070 protected double [] makeDistribution(double predictedClass) {
3071
3072 double [] result = new double [m_NumClasses];
3073 if (Instance.isMissingValue(predictedClass)) {
3074 return result;
3075 }
3076 if (m_ClassIsNominal) {
3077 result[(int)predictedClass] = 1.0;
3078 } else {
3079 result[0] = predictedClass;
3080 }
3081 return result;
3082 }
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
  /**
   * Updates all the statistics about a classifier's performance for the
   * current test instance (nominal class). The predicted class is the index
   * with the largest probability; an all-zero distribution leaves the
   * instance unclassified.
   *
   * @param predictedDistribution the probabilities assigned to each class
   * @param instance the instance to be classified
   * @throws Exception if the class of the instance is not set
   */
  protected void updateStatsForClassifier(double [] predictedDistribution,
					  Instance instance)
       throws Exception {

    int actualClass = (int)instance.classValue();

    if (!instance.classIsMissing()) {
      updateMargins(predictedDistribution, actualClass, instance.weight());

      // Determine the predicted class: first index with the strictly
      // highest probability. bestProb starts at 0.0, so an all-zero
      // distribution leaves predictedClass at -1 (unclassified).
      int predictedClass = -1;
      double bestProb = 0.0;
      for(int i = 0; i < m_NumClasses; i++) {
	if (predictedDistribution[i] > bestProb) {
	  predictedClass = i;
	  bestProb = predictedDistribution[i];
	}
      }

      m_WithClass += instance.weight();

      // Accumulate misclassification cost when a cost matrix is in use
      if (m_CostMatrix != null) {
	if (predictedClass < 0) {
	  // No prediction was made: charge the worst possible cost for
	  // the actual class (deliberately pessimistic).
	  m_TotalCost += instance.weight()
	    * m_CostMatrix.getMaxCost(actualClass, instance);
	} else {
	  m_TotalCost += instance.weight()
	    * m_CostMatrix.getElement(actualClass, predictedClass,
		instance);
	}
      }

      // Count unclassified instances and stop: the remaining statistics
      // are only defined for an actual prediction
      if (predictedClass < 0) {
	m_Unclassified += instance.weight();
	return;
      }

      // KB information score: reward predictions that beat the class
      // prior, penalize those below it. Probabilities are floored at
      // MIN_SF_PROB so the log2 calls stay finite.
      double predictedProb = Math.max(MIN_SF_PROB,
				      predictedDistribution[actualClass]);
      double priorProb = Math.max(MIN_SF_PROB,
				  m_ClassPriors[actualClass]
				  / m_ClassPriorsSum);
      if (predictedProb >= priorProb) {
	m_SumKBInfo += (Utils.log2(predictedProb) -
			Utils.log2(priorProb))
	  * instance.weight();
      } else {
	m_SumKBInfo -= (Utils.log2(1.0-predictedProb) -
			Utils.log2(1.0-priorProb))
	  * instance.weight();
      }

      // Entropy contributions of the scheme's prediction and of the prior
      m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
      m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();

      // Fold the full distribution into the numeric error statistics
      updateNumericScores(predictedDistribution,
			  makeDistribution(instance.classValue()),
			  instance.weight());

      // Confusion matrix and correct/incorrect counts
      m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
      if (predictedClass != actualClass) {
	m_Incorrect += instance.weight();
      } else {
	m_Correct += instance.weight();
      }
    } else {
      // Instances with a missing class are only counted, not scored
      m_MissingClass += instance.weight();
    }
  }
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
  /**
   * Updates all the statistics about a predictor's performance for the
   * current test instance (numeric class).
   *
   * @param predictedValue the numeric value the classifier predicts
   * @param instance the instance to be classified
   * @throws Exception if the class of the instance is not set
   */
  protected void updateStatsForPredictor(double predictedValue,
					 Instance instance)
       throws Exception {

    if (!instance.classIsMissing()){

      // Sums and squared sums feeding correlation / error statistics
      m_WithClass += instance.weight();
      if (Instance.isMissingValue(predictedValue)) {
	m_Unclassified += instance.weight();
	return;
      }
      m_SumClass += instance.weight() * instance.classValue();
      m_SumSqrClass += instance.weight() * instance.classValue()
	* instance.classValue();
      m_SumClassPredicted += instance.weight()
	* instance.classValue() * predictedValue;
      m_SumPredicted += instance.weight() * predictedValue;
      m_SumSqrPredicted += instance.weight() * predictedValue * predictedValue;

      // Lazily initialize the kernel error estimators from the buffered
      // training class values on first use
      if (m_ErrorEstimator == null) {
	setNumericPriorsFromBuffer();
      }
      // Density of this prediction's error under the scheme's error
      // estimator, and of the class value under the prior estimator;
      // both floored at MIN_SF_PROB so the log2 calls stay finite
      double predictedProb = Math.max(m_ErrorEstimator.getProbability(
	  predictedValue
	  - instance.classValue()),
	  MIN_SF_PROB);
      double priorProb = Math.max(m_PriorErrorEstimator.getProbability(
	  instance.classValue()),
	  MIN_SF_PROB);

      m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
      m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
      // NOTE: the error is added to the estimator only AFTER the
      // probabilities above were read — keep this ordering
      m_ErrorEstimator.addValue(predictedValue - instance.classValue(),
				instance.weight());

      updateNumericScores(makeDistribution(predictedValue),
			  makeDistribution(instance.classValue()),
			  instance.weight());

    } else
      // Instances with a missing class are only counted, not scored
      m_MissingClass += instance.weight();
  }
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234 protected void updateMargins(double [] predictedDistribution,
3235 int actualClass, double weight) {
3236
3237 double probActual = predictedDistribution[actualClass];
3238 double probNext = 0;
3239
3240 for(int i = 0; i < m_NumClasses; i++)
3241 if ((i != actualClass) &&
3242 (predictedDistribution[i] > probNext))
3243 probNext = predictedDistribution[i];
3244
3245 double margin = probActual - probNext;
3246 int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
3247 m_MarginCounts[bin] += weight;
3248 }
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260 protected void updateNumericScores(double [] predicted,
3261 double [] actual, double weight) {
3262
3263 double diff;
3264 double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
3265 double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
3266 for(int i = 0; i < m_NumClasses; i++) {
3267 diff = predicted[i] - actual[i];
3268 sumErr += diff;
3269 sumAbsErr += Math.abs(diff);
3270 sumSqrErr += diff * diff;
3271 diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
3272 sumPriorAbsErr += Math.abs(diff);
3273 sumPriorSqrErr += diff * diff;
3274 }
3275 m_SumErr += weight * sumErr / m_NumClasses;
3276 m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
3277 m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
3278 m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
3279 m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
3280 }
3281
3282
3283
3284
3285
3286
3287
3288
3289 protected void addNumericTrainClass(double classValue, double weight) {
3290
3291 if (m_TrainClassVals == null) {
3292 m_TrainClassVals = new double [100];
3293 m_TrainClassWeights = new double [100];
3294 }
3295 if (m_NumTrainClassVals == m_TrainClassVals.length) {
3296 double [] temp = new double [m_TrainClassVals.length * 2];
3297 System.arraycopy(m_TrainClassVals, 0,
3298 temp, 0, m_TrainClassVals.length);
3299 m_TrainClassVals = temp;
3300
3301 temp = new double [m_TrainClassWeights.length * 2];
3302 System.arraycopy(m_TrainClassWeights, 0,
3303 temp, 0, m_TrainClassWeights.length);
3304 m_TrainClassWeights = temp;
3305 }
3306 m_TrainClassVals[m_NumTrainClassVals] = classValue;
3307 m_TrainClassWeights[m_NumTrainClassVals] = weight;
3308 m_NumTrainClassVals++;
3309 }
3310
3311
3312
3313
3314
3315 protected void setNumericPriorsFromBuffer() {
3316
3317 double numPrecision = 0.01;
3318 if (m_NumTrainClassVals > 1) {
3319 double [] temp = new double [m_NumTrainClassVals];
3320 System.arraycopy(m_TrainClassVals, 0, temp, 0, m_NumTrainClassVals);
3321 int [] index = Utils.sort(temp);
3322 double lastVal = temp[index[0]];
3323 double deltaSum = 0;
3324 int distinct = 0;
3325 for (int i = 1; i < temp.length; i++) {
3326 double current = temp[index[i]];
3327 if (current != lastVal) {
3328 deltaSum += current - lastVal;
3329 lastVal = current;
3330 distinct++;
3331 }
3332 }
3333 if (distinct > 0) {
3334 numPrecision = deltaSum / distinct;
3335 }
3336 }
3337 m_PriorErrorEstimator = new KernelEstimator(numPrecision);
3338 m_ErrorEstimator = new KernelEstimator(numPrecision);
3339 m_ClassPriors[0] = m_ClassPriorsSum = 0;
3340 for (int i = 0; i < m_NumTrainClassVals; i++) {
3341 m_ClassPriors[0] += m_TrainClassVals[i] * m_TrainClassWeights[i];
3342 m_ClassPriorsSum += m_TrainClassWeights[i];
3343 m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
3344 m_TrainClassWeights[i]);
3345 }
3346 }
3347 }