Accéder au contenu principal Passer au contenu complémentaire

Évaluation du modèle de classification

Après avoir créé un modèle de classification, vous pouvez évaluer sa qualité.

Relier les composants

Procédure

  1. Dans la perspective Integration du Studio Talend, créez un autre Job vide Spark Batch, nommé classify_and_evaluation par exemple, depuis le nœud Job Designs dans la vue Repository.
  2. Dans l'espace de modélisation graphique, saisissez le nom des composants à utiliser et sélectionnez-les dans la liste qui s'affiche.
    Dans ce Job, les composants sont : un tHDFSConfiguration, un tFileInputDelimited, un tPredict, un tReplicate, un tJava, un tFilterColumns et un tLogRow.
  3. Reliez les composants à l'aide de liens Row > Main, sauf le tHDFSConfiguration, comme affiché dans l'image ci-dessus.
    Job de sept composants utilisant le composant tPredict.
  4. Double-cliquez sur le tHDFSConfiguration pour ouvrir sa vue Component et configurez-le comme expliqué précédemment dans le scénario.

Charger le jeu de données de test dans le Job

Procédure

  1. Double-cliquez sur le tFileInputDelimited pour ouvrir sa vue Component.
  2. Cochez la case Define a storage configuration component et sélectionnez le composant tHDFSConfiguration à utiliser.
    Le tFileInputDelimited utilise cette configuration pour accéder à l'ensemble d'apprentissage à utiliser.
  3. Cliquez sur le bouton [...] à côté du champ Edit schema pour ouvrir l'éditeur du schéma.
  4. Cliquez cinq fois sur le bouton [+] pour ajouter cinq lignes et, dans la colonne Column, renommez ces lignes reallabel, sms_contents, num_currency, num_numeric et num_exclamation, respectivement.
    Les colonnes reallabel et sms_contents contiennent les données brutes composées, respectivement, de SMS dans la colonne sms_contents et de libellés indiquant qu'un message est un spam dans la colonne reallabel.
    Les autres colonnes sont utilisées pour contenir les caractéristiques ajoutées aux jeux de données brutes, comme expliqué précédemment dans le scénario. Elles contiennent le nombre de symboles monétaires, le nombre de valeurs numériques et le nombre de points d'exclamation trouvés dans chaque SMS.
  5. Dans la colonne Type, sélectionnez Integer pour les colonnes num_currency, num_numeric et num_exclamation.
  6. Cliquez sur OK pour valider ces modifications.
  7. Dans le champ Folder/File, saisissez le répertoire où est stocké le test à utiliser.
  8. Dans le champ Field separator, saisissez \t, séparateur utilisé par les jeux de données que vous pouvez télécharger et utiliser dans ce scénario.

Appliquer le modèle de classification

Procédure

  1. Double-cliquez sur le tPredict pour ouvrir sa vue Basic settings.
  2. Dans la liste Model Type, sélectionnez Random Forest Model.
  3. Sélectionnez le bouton radio Model on filesystem et saisissez le répertoire dans lequel est stocké le modèle de classification à utiliser.
    Le composant tPredict contient une colonne en lecture seule nommée label, dans laquelle le modèle fournit les classes à utiliser dans le processus de classification. La colonne reallabel récupérée du schéma d'entrée contient les classes auxquelles chaque message appartient. Le modèle sera évalué en comparant le libellé de chaque message par rapport au libellé déterminé par le modèle.

Répliquer les résultats de classification

Procédure

  1. Double-cliquez sur le tReplicate pour ouvrir sa vue Component.
  2. Laissez les autres paramètres par défaut.

Filtrer les résultats de classification

Procédure

  1. Double-cliquez sur le tFilterColumns pour ouvrir sa vue Component.
  2. Cliquez sur le bouton [...] à côté du champ Edit schema pour ouvrir l'éditeur du schéma.
  3. Du côté de la sortie, cliquez trois sur le bouton [+] pour ajouter trois lignes et, dans la colonne Column, renommez-les respectivement reallabel, label et sms_contents. Elles reçoivent des données des colonnes d'entrée utilisant les mêmes noms.
  4. Cliquez sur OK pour valider ces modifications et acceptez la propagation proposée par la boîte de dialogue qui s'ouvre.

Écrire le programme d'évaluation dans un tJava

Procédure

  1. Double-cliquez sur le tJava pour ouvrir sa vue Component.
  2. Cliquez sur le bouton Sync columns pour vous assurer que le tJava récupère le schéma répliqué du tPredict.
  3. Cliquez sur l'onglet Advanced settings pour ouvrir sa vue.
  4. Dans le champ Classes, saisissez le code pour définir les classes Java à utiliser afin de vérifier si les libellés de classe prédits correspondent aux libellés réels des classes,
    spam pour les messages indésirables et ham pour les messages normaux.
    Dans ce scénario, row7 est l'ID de la connexion entre le tPredict et le tReplicate et contient les résultats de classification à envoyer aux composants suivants. row7Struct est la classe Java du RDD pour les résultats de classification. Dans votre code, vous devez remplacer row7, utilisé seul ou au sein de row7Struct, par l'ID de la connexion utilisée dans votre Job.
    Les noms des colonnes, comme reallabel ou label ont été définis dans l'étape précédente lors de la configuration des différents composants. Si vous les avez nommées différemment, gardez la cohérence pour utilisation dans votre code.
    public static class SpamFilterFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return row7.reallabel.equals("spam");
    	}
    	
    }
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    public static class TrueNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return (row7.label.equals("ham") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class TruePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// true positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("spam"));
    	}
    	
    }
    
    public static class FalseNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class FalsePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("ham") && row7.reallabel.equals("spam"));
    	}
    	
    }
  5. Cliquez sur l'onglet Basic settings et, dans le champ Code, saisissez le code à utiliser pour calculer le score de précision et le coefficient de corrélation de Matthews (Matthews Correlation Coefficient, MCC) du modèle de classification.
    Pour une explication générale relative à ce coefficient, consultez l'article Wikipédia https://en.wikipedia.org/wiki/Matthews_correlation_coefficient (en anglais).
    long nbTotal = rdd_tJava_1.count();
    
    long nbSpam = rdd_tJava_1.filter(new SpamFilterFunction()).count();
    
    long nbHam = nbTotal - nbSpam;
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    long tn = rdd_tJava_1.filter(new TrueNegativeFunction()).count();
    
    long tp = rdd_tJava_1.filter(new TruePositiveFunction()).count();
    
    long fn = rdd_tJava_1.filter(new FalseNegativeFunction()).count();
    
    long fp = rdd_tJava_1.filter(new FalsePositiveFunction()).count();
    
    double mmc = (double)(tp*tn -fp*fn) / java.lang.Math.sqrt((double)((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)));
    
    System.out.println("Accuracy:"+((double)(tp+tn)/(double)nbTotal));
    System.out.println("Spams caught (SC):"+((double)tp/(double)nbSpam));
    System.out.println("Blocked hams (BH):"+((double)fp/(double)nbHam));
    System.out.println("Matthews correlation coefficient (MCC):" + mmc); 

Configurer la connexion à Spark

Pourquoi et quand exécuter cette tâche

Répétez les opérations décrites ci-dessus. Consultez Sélectionner le mode Spark.

Exécuter le Job

Procédure

  1. Le composant tLogRow est utilisé pour afficher le résultat de l'exécution de ce Job.
    Si vous souhaitez configurer le mode d'affichage dans sa vue Component, double-cliquez sur le composant tLogRow pour ouvrir la vue Component et dans la zone Mode, sélectionnez l'option Table (print values in cells of a table).
  2. Si vous souhaitez afficher uniquement le niveau d'information des erreurs pour les logs de Log4j dans la console de la vue Run, cliquez sur l'onglet Run pour ouvrir cette vue, puis cliquez sur l'onglet Advanced settings.
  3. Cochez la case log4jLevel et sélectionnez Error dans la liste.
  4. Appuyez sur F6 pour exécuter le Job.

Résultats

Dans la console de la vue Run, vous pouvez lire les résultats de classification et les libellés utilisés :

Vous pouvez également voir les scores calculés dans la même console :

Les scores montrent la bonne qualité du modèle, mais vous pouvez toujours l'améliorer en continuant la personnalisation des paramètres utilisés dans le tRandomForestModel et en exécutant le Job de création de modèle avec de nouveaux paramètres pour obtenir et évaluer de nouvelles versions du modèle.

Cette page vous a-t-elle aidé ?

Si vous rencontrez des problèmes sur cette page ou dans son contenu – une faute de frappe, une étape manquante ou une erreur technique – faites-le-nous savoir.