tJavaで評価プログラムを作成する
手順
- tJavaをダブルクリックして、その[Component] (コンポーネント)ビューを開きます。
- tJavaがtClassifyの複製されたスキーマを確実に取得するために、[Sync columns] (カラムを同期)をクリックします。
-
[Advanced settings] (詳細設定)タブをクリックして、ビューを開きます。
-
[Classes] (クラス)フィールドにコードを入力して、予測されたクラスラベルが実際のクラスラベルと一致するかどうかを確認するために使うJavaクラスを定義します(ジャンクメッセージにはspam、通常のメッセージにはham)。このシナリオでは、row7はtClassifyとtReplicateの間の接続のIDであり、後続のコンポーネントに送信される分類結果を保持します。また、row7Structは分類結果のRDDのJavaクラスです。コードに含まれているrow7は、単独で使うかrow7Struct内で使うかに関係なく、ジョブに使われている対応する接続IDに置き換える必要があります。
reallabelやlabelなどのカラム名は、さまざまなコンポーネントを設定した前のステップで定義済したものです。これらに異なる名前を付けた場合は、コードで使うために一貫性のある状態に保つ必要があります。
public static class SpamFilterFunction implements org.apache.spark.api.java.function.Function<row7Struct, Boolean>{ private static final long serialVersionUID = 1L; @Override public Boolean call(row7Struct row7) throws Exception { return row7.reallabel.equals("spam"); } } // 'negative': ham // 'positive': spam // 'false' means the real label & predicted label are different // 'true' means the real label & predicted label are the same public static class TrueNegativeFunction implements org.apache.spark.api.java.function.Function<row7Struct, Boolean>{ private static final long serialVersionUID = 1L; @Override public Boolean call(row7Struct row7) throws Exception { return (row7.label.equals("ham") && row7.reallabel.equals("ham")); } } public static class TruePositiveFunction implements org.apache.spark.api.java.function.Function<row7Struct, Boolean>{ private static final long serialVersionUID = 1L; @Override public Boolean call(row7Struct row7) throws Exception { // true positive cases return (row7.label.equals("spam") && row7.reallabel.equals("spam")); } } public static class FalseNegativeFunction implements org.apache.spark.api.java.function.Function<row7Struct, Boolean>{ private static final long serialVersionUID = 1L; @Override public Boolean call(row7Struct row7) throws Exception { // false positive cases return (row7.label.equals("spam") && row7.reallabel.equals("ham")); } } public static class FalsePositiveFunction implements org.apache.spark.api.java.function.Function<row7Struct, Boolean>{ private static final long serialVersionUID = 1L; @Override public Boolean call(row7Struct row7) throws Exception { // false positive cases return (row7.label.equals("ham") && row7.reallabel.equals("spam")); } }
-
[Basic settings] (基本設定)タブをクリックしてビューを開き、[Code] (コード)フィールドに、分類モデルの精度スコアとMatthewsコリレーション係数(MCC)の計算に使うコードを入力します。
Mathewsコリレーション係数に関する一般的な説明は、Wikipediaのhttps://en.wikipedia.org/wiki/Matthews_correlation_coefficient (英語のみ)をご覧ください。
long nbTotal = rdd_tJava_1.count(); long nbSpam = rdd_tJava_1.filter(new SpamFilterFunction()).count(); long nbHam = nbTotal - nbSpam; // 'negative': ham // 'positive': spam // 'false' means the real label & predicted label are different // 'true' means the real label & predicted label are the same long tn = rdd_tJava_1.filter(new TrueNegativeFunction()).count(); long tp = rdd_tJava_1.filter(new TruePositiveFunction()).count(); long fn = rdd_tJava_1.filter(new FalseNegativeFunction()).count(); long fp = rdd_tJava_1.filter(new FalsePositiveFunction()).count(); double mmc = (double)(tp*tn -fp*fn) / java.lang.Math.sqrt((double)((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))); System.out.println("Accuracy:"+((double)(tp+tn)/(double)nbTotal)); System.out.println("Spams caught (SC):"+((double)tp/(double)nbSpam)); System.out.println("Blocked hams (BH):"+((double)fp/(double)nbHam)); System.out.println("Matthews correlation coefficient (MCC):" + mmc);