メイン コンテンツをスキップする 補完的コンテンツへスキップ

分類モデルを評価する

分類モデルを作成したら、それがどの程度優れているかを評価できるようになります。

コンポーネントをリンク

手順

  1. Talend StudioIntegrationパースペクティブで、[Repository] (リポジトリー)ツリービューの[Job Designs] (ジョブ設計)ノードから、たとえばrf_model_creationという名前の空のSpark Batchジョブを作成します。
  2. 使用するコンポーネントの名前をワークスペースに入力し、表示されるリストからそのコンポーネントを選択します。
    このジョブでは、コンポーネントはtHDFSConfigurationtFileInputDelimitedtPredicttReplicatetJavatFilterColumnstLogRowです。
  3. 上の画像に示すように、tHDFSConfigurationを除き、[Row] (行) > [Main] (メイン)リンクを使って接続します。
    tPredictコンポーネントを使った7コンポーネントのジョブ。
  4. tHDFSConfigurationをダブルクリックして[Component] (コンポーネント)ビューを開き、このシナリオで前述したように設定します。

テストセットをジョブにロードする

手順

  1. tFileInputDelimitedをダブルクリックして、[Component] (コンポーネント)ビューを開きます。
  2. [Define a storage configuration component] (ストレージ設定コンポーネントを定義)チェックボックスをオンにし、使うtHDFSConfigurationコンポーネントを選択します。
    tFileInputDelimitedはこの設定を使い、使うトレーニングセットにアクセスします。
  3. [Edit schema] (スキーマを編集)の横にある[...]ボタンをクリックし、スキーマエディターを開きます。
  4. [+]ボタンを5回クリックして5つの行を追加し、[Column] (カラム)カラムで名前をそれぞれreallabelsms_contentsnum_currencynum_numericnum_exclamationに変更します。
    reallabelカラムとsms_contentsカラムは、sms_contentsカラム内のSMSテキストメッセージで構成された生データを保持し、メッセージがスパムかどうかをreallabelカラムでラベル表示します。
    他のカラムは、このシナリオで前に説明したように、生データセットに追加された機能を保持するために使われます。これらのカラムには、各SMSメッセージにある通貨記号の数、数値の数、感嘆符の数が含まれています。
  5. [Type] (タイプ)カラムで、カラムnum_currencynum_numericnum_exclamation[Integer] (整数)を選択します。
  6. [OK]をクリックして、これらの変更を検証します。
  7. [Folder/File] (フォルダー/ファイル)フィールドに、使うテストセットが保管されているディレクトリーを入力します。
  8. [Field separator] (フィールド区切り)フィールドに\tを入力します。これはデータセットが使う区切りで、このシナリオ用にダウンロードできます。

分類モデルを適用する

手順

  1. tPredictをダブルクリックして[Basic settings] (基本設定)を開きます。
  2. [Model Type] (モデルタイプ)で、[Random Forest Model] (ランダムフォレストモデル)を選択します。
  3. [Model on filesystem] (ファイルシステムのモデル)ラジオボタンを選択し、使う分類モデルが保管されているディレクトリーを入力します。
    tPredictコンポーネントには、labelと呼ばれる読み取り専用のカラムが含まれています。このモデルでは分類プロセスで使われるクラスが提供されますが、入力スキーマから取得されたreallabelカラムには、各メッセージが実際に属するクラスが含まれています。各メッセージの実際のラベルをモデルが決定するラベルと比較することにより、モデルが評価されます。

分類結果を複製する

手順

  1. tReplicateをダブルクリックして、[Component] (コンポーネント)ビューを開きます。
  2. デフォルト設定のままにします。

分類結果をフィルタリングする

手順

  1. tFilterColumnsをダブルクリックして、[Component] (コンポーネント)ビューを開きます。
  2. [Edit schema] (スキーマを編集)の横にある[...]ボタンをクリックし、スキーマエディターを開きます。
  3. 出力側で[+]ボタンを3回クリックして3つの行を追加し、[Column] (カラム)カラムで3つの行の名前をそれぞれ、reallabellabelsms_contentsに変更します。同じ名前を使っている入力カラムからデータを受け取ります。
  4. OKをクリックしてこれらの変更を確定し、ポップアップ表示されるダイアログボックスで求められるプロパゲーションを承認します。

tJavaで評価プログラムを作成する

手順

  1. tJavaをダブルクリックして、その[Component] (コンポーネント)ビューを開きます。
  2. tJavatPredictの複製されたスキーマを確実に取得するよう、[Sync columns] (カラムを同期)をクリックします。
  3. [Advanced settings] (詳細設定)タブをクリックして、ビューを開きます。
  4. [Classes] (クラス)フィールドにコードを入力して、予測されたクラスラベルが実際のクラスラベルと一致するかどうかを確認するために使うJavaクラスを定義します。
    ジャンクメッセージではspam、通常のメッセージではhamとなります。
    このシナリオでは、row7tPredicttReplicateの間の接続のIDであり、後続のコンポーネントに送信される分類結果を保持します。row7Structは分類結果のRDDのJavaクラスです。コードに含まれているrow7は、単独で使うかrow7Structで使うかに関係なく、ジョブに使われている対応する接続IDに置き換える必要があります。
    reallabellabelなどのカラム名は、さまざまなコンポーネントを設定した前のステップで定義済したものです。これらに異なる名前を付けた場合は、コードで使うために一貫性のある状態に保つ必要があります。
    public static class SpamFilterFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return row7.reallabel.equals("spam");
    	}
    	
    }
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    public static class TrueNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return (row7.label.equals("ham") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class TruePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// true positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("spam"));
    	}
    	
    }
    
    public static class FalseNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class FalsePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("ham") && row7.reallabel.equals("spam"));
    	}
    	
    }
  5. [Basic settings] (基本設定)タブをクリックし、[Code] (コード)フィールドに、分類モデルの精度スコアとMatthewsコリレーション係数(MCC)の計算に使うコードを入力します。
    Mathewsコリレーション係数に関する一般的な説明は、Wikipediaのhttps://en.wikipedia.org/wiki/Matthews_correlation_coefficientをご覧ください。
    long nbTotal = rdd_tJava_1.count();
    
    long nbSpam = rdd_tJava_1.filter(new SpamFilterFunction()).count();
    
    long nbHam = nbTotal - nbSpam;
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    long tn = rdd_tJava_1.filter(new TrueNegativeFunction()).count();
    
    long tp = rdd_tJava_1.filter(new TruePositiveFunction()).count();
    
    long fn = rdd_tJava_1.filter(new FalseNegativeFunction()).count();
    
    long fp = rdd_tJava_1.filter(new FalsePositiveFunction()).count();
    
    double mmc = (double)(tp*tn -fp*fn) / java.lang.Math.sqrt((double)((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)));
    
    System.out.println("Accuracy:"+((double)(tp+tn)/(double)nbTotal));
    System.out.println("Spams caught (SC):"+((double)tp/(double)nbSpam));
    System.out.println("Blocked hams (BH):"+((double)fp/(double)nbHam));
    System.out.println("Matthews correlation coefficient (MCC):" + mmc); 

Spark接続を設定する

このタスクについて

上記の操作を繰り返します。Sparkモードを選択をご覧ください。

ジョブを実行

手順

  1. tLogRowコンポーネントを使って、ジョブの実行結果を表示します。
    [Component] (コンポーネント)ビューでプレゼンテーションモードを設定する場合は、tLogRowコンポーネントをダブルクリックして[Component] (コンポーネント)ビューを開き、次に[Mode] (モード)エリアで[Table (print values in cells of a table)] (テーブル(テーブルのセルの出力値))ラジオボタンを選択します。
  2. [Run] (実行)ビューのコンソールでLog4j ロギングの唯一のエラーレベル情報を表示する必要がある場合は、[Run] (実行)をクリックしてビューを開き、[Advanced settings] (詳細設定)タブをクリックします。
  3. ビューでlog4jLevelチェックボックスをオンにして、リストから[Error] (エラー)を選択します。
  4. F6を押してこのジョブを実行します。

タスクの結果

[Run] (実行)ビューのコンソールで、実際のラベルと共に分類結果を確認できます。

同じコンソールで、計算されたスコアを読み取ることもできます。

スコアは、モデルの品質が良いことを示しています。ただし、tRandomForestModelで使われるパラメーターの調整を続行し、新しいパラメーターを使ってモデル作成ジョブを実行し、モデルの新しいバージョンを取得して評価することで、モデルを拡張できます。

このページは役に立ちましたか?

このページまたはコンテンツにタイポ、ステップの省略、技術的エラーなどの問題が見つかった場合はお知らせください。