
Naive Bayes Classifiers (multi-class)

Train test

The accuracy of a classification model is the number of correct predictions divided by the total number of predictions. This statistic describes the model's performance. To estimate it fairly, split the dataset into training and test data. Training and testing on the same data produces a well-trained classifier, but the accuracy estimate is optimistically biased because the model has already seen every example it is scored on. A test set that is too small gives an unbiased but unreliable accuracy estimate for a well-trained classifier. A test set that is too large gives an unbiased and reliable accuracy estimate, but for a poorly trained classifier, because too little data is left for training. This is why the sizes of the training and test sets deserve careful thought. Common train-to-test ratios include 8:2, 7:3, and 6:4.
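As a minimal sketch of the definition above, accuracy can be computed by hand from two lists of labels. The labels here are invented for illustration and are not taken from the wine data.

```python
# Hypothetical true and predicted labels, invented purely for illustration
y_true = [0, 1, 2, 1, 0, 2, 1, 0]
y_hat = [0, 1, 1, 1, 0, 2, 0, 0]

# Count the positions where the prediction matches the true label
correct = sum(t == p for t, p in zip(y_true, y_hat))

# Accuracy = correct predictions / total predictions
accuracy = correct / len(y_true)
print("Correct:", correct, "of", len(y_true))
print("Accuracy:", accuracy)
```

With only eight predictions, a single extra mistake moves the accuracy by 0.125, which is one way to see why a very small test set gives an unreliable estimate.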


Below, split the data into training and test sets with a 7:3 ratio, train the model on the training dataset, and then check the accuracy on the test dataset.


# Import train_test_split function
from sklearn.model_selection import train_test_split

# Split dataset into training set and test set (70% training, 30% test)
# `wine` is assumed to have been loaded earlier, e.g. with sklearn.datasets.load_wine()
X_train, X_test, y_train, y_test = train_test_split(wine.data, wine.target, test_size=0.3, random_state=109)

# Import Gaussian Naive Bayes model
from sklearn.naive_bayes import GaussianNB

# Create a Gaussian Classifier
gnb = GaussianNB()

# Train the model using the training sets
gnb.fit(X_train, y_train)

# Predict the response for test dataset
y_pred = gnb.predict(X_test)
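To make the 7:3 split concrete, the sketch below works out how many samples land in each set. It assumes that the wine dataset has 178 samples and that train_test_split rounds the test-set size up when test_size is a fraction. Treat both as assumptions to check against your own data and scikit-learn version.

```python
import math

n_samples = 178   # assumed size of the wine dataset
test_size = 0.3   # fraction used in the split above

# Assumption: the test count is rounded up, the rest goes to training
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

print("Training samples:", n_train)
print("Test samples:", n_test)
```

Under these assumptions the test set holds 54 samples, so each misclassified test sample lowers the accuracy by about 0.0185.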

Predict a single output with the model.


new_datapoint = wine.data[13, :]

# Predict Output
predicted = gnb.predict([new_datapoint])
predict_probability = gnb.predict_proba([new_datapoint])
print("Predicted Value:", predicted, " with ", predict_probability)

Predicted Value: [0] with [[1.00000000e+00 3.95539110e-12 6.25476571e-44]]
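The predicted class is the one with the highest probability in the predict_proba output. The sketch below checks this by hand using the probabilities printed above. It assumes the columns are ordered by class label (0, 1, 2), so the column index can be read directly as the label.

```python
# Probabilities copied from the predict_proba output shown above
probabilities = [1.00000000e+00, 3.95539110e-12, 6.25476571e-44]

# Index of the largest probability; assumed to equal the class label here
predicted_class = max(range(len(probabilities)), key=lambda i: probabilities[i])
print("Most probable class:", predicted_class)

# The class probabilities should sum to 1 up to rounding error
total_probability = sum(probabilities)
print("Sum of probabilities:", total_probability)
```

A probability this close to 1 reflects the model's own confidence rather than a guarantee of correctness, and Naive Bayes is known to produce rather extreme probability estimates.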

Print the accuracy of the classification model using the accuracy_score function.


# Import scikit-learn metrics module for accuracy calculation
from sklearn import metrics

# Model Accuracy, how often is the classifier correct?
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))

Accuracy: 0.9074074074074074

Print the confusion matrix to see which classes are commonly misclassified. In scikit-learn's convention, the rows represent the actual classes and the columns represent the predicted classes. The diagonal entries therefore count the correctly classified samples, and the off-diagonal entries count the misclassifications.


# Import scikit-learn metrics module for confusion matrix
from sklearn import metrics

# Which classes are commonly misclassified?
print(metrics.confusion_matrix(y_test, y_pred, labels=[0, 1, 2]))

[[20  1  0]
 [ 2 15  2]
 [ 0  0 14]]
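The accuracy reported earlier can be recovered from the confusion matrix, which also shows how well each class is recognised. The sketch below uses the matrix printed above and assumes scikit-learn's convention that rows are actual classes and columns are predicted classes. The per-class recall is an extra illustration, not part of the original walkthrough.

```python
# Confusion matrix copied from the output above
# (rows = actual class, columns = predicted class)
cm = [
    [20, 1, 0],
    [2, 15, 2],
    [0, 0, 14],
]

# Correct predictions sit on the diagonal
correct = sum(cm[i][i] for i in range(len(cm)))
total = sum(sum(row) for row in cm)
accuracy = correct / total
print("Accuracy:", accuracy)  # 49 correct out of 54

# Recall per class: correct predictions divided by the row total
recalls = [cm[i][i] / sum(cm[i]) for i in range(len(cm))]
for label, recall in enumerate(recalls):
    print("Class", label, "recall:", round(recall, 3))
```

Read this way, class 1 is the hardest class for this model: 4 of its 19 test samples are assigned to a neighbouring class, while every class 2 sample is classified correctly.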