
How to choose

Here is a guide to choosing the right type of ML model for your categorization task.
First, the features and typical use cases of each model type are summarized; then suggested model types are listed based on your data and goals.

Model types summary

Support Vector Machines (SVM)

  • SVM highlights:

    • Works well with small and large datasets
    • Low level of explainability, but generally has a higher level of accuracy
    • Slower to train and more difficult to implement (due to parameter tuning requirements)
  • Use SVM when:

    • You are unsure of how many classes the data set is divided into (no clear criteria)
    • You have a relatively small labeled training set (SVM can perform well with limited annotated data)
  • Which type of SVM?

    • Linear SVM: for binary classification with linearly-separable data
    • Probabilistic SVM: for multi-label classification
    • Custom Kernel SVM: for classifying data that is not linearly-separable
    • SGD-trained SVM: for very large datasets, where stochastic gradient descent minimizes the expected loss incrementally
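As a rough sketch, the SVM variants above can be instantiated with scikit-learn (assumed available); the synthetic dataset and all parameter choices here are illustrative, not prescriptive:

```python
# Illustrative sketch of the SVM variants, using scikit-learn.
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC, LinearSVC

# Synthetic binary-classification data, purely for demonstration.
X, y = make_classification(n_samples=200, n_features=20, random_state=0)

linear_svm = LinearSVC().fit(X, y)               # linear SVM: linearly-separable, binary data
prob_svm = SVC(probability=True).fit(X, y)       # probabilistic SVM: exposes class probabilities
kernel_svm = SVC(kernel="rbf").fit(X, y)         # kernel SVM: data that is not linearly separable
sgd_svm = SGDClassifier(loss="hinge").fit(X, y)  # SVM objective trained via SGD, scales to large data

probs = prob_svm.predict_proba(X[:1])            # each row of probabilities sums to 1
```

Note that `SGDClassifier(loss="hinge")` optimizes the same hinge loss as a linear SVM, but incrementally, which is what makes it suitable for very large datasets.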

Decision trees ensemble

  • Decision trees ensemble highlights:

    • Works well with large datasets
    • Generally average levels of explainability and accuracy
    • Slow to train and generally more difficult to implement (due to high convergence time and repetition)
  • Use decision trees ensemble when:

    • You have large datasets
    • You are concerned about your model being influenced by outliers in the data
    • You have collinear features
    • Data is not linear
  • Which decision trees ensemble:

    • Random Forest: the simplest type
    • GBoost: if tuned correctly, can offer better results than Random Forest
    • XGBoost: for very large training sets with a small number of features
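A minimal sketch of the first two ensemble options, assuming scikit-learn; "GBoost" is read here as gradient boosting, and XGBoost lives in the separate `xgboost` package (not shown). The dataset and parameters are illustrative assumptions:

```python
# Illustrative sketch of decision-tree ensembles with scikit-learn.
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier

# Synthetic dataset standing in for a "large" training set.
X, y = make_classification(n_samples=500, n_features=20, random_state=0)

rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)  # the simplest ensemble type
gb = GradientBoostingClassifier(random_state=0).fit(X, y)                # often stronger than RF if tuned
```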

Linear (Logistic Regression)

  • Logistic Regression highlights:

    • Works well with small and large datasets
    • High level of explainability, but generally lower level of accuracy
    • Quick to train and relatively easy to implement
  • Use Logistic Regression when:

    • It's a binary classification problem
    • Your dependent variable (target) is categorical
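A short sketch of the Logistic Regression case, assuming scikit-learn; the data is synthetic and the point is the explainability: the model exposes one learned weight per feature.

```python
# Illustrative Logistic Regression for a binary categorical target.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic binary-classification data for demonstration.
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

clf = LogisticRegression(max_iter=1000).fit(X, y)

weights = clf.coef_[0]            # one weight per feature -> easy to explain
proba = clf.predict_proba(X[:1])  # probabilities over the two classes
```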

Naive Bayes

  • Naive Bayes highlights:

    • Works well with small datasets
    • Higher level of explainability than other algorithms, but lower level of accuracy than most
    • Quick to train and relatively easy to implement
  • Use Naive Bayes when:

    • The training set is relatively small
    • Training set and test set are well balanced (classes equally represented in the data)
    • Documents are of roughly equal size
  • Which Naive Bayes?

    • Multinomial Naive Bayes: for cases in which all the conditions above are satisfied
    • Complement Naive Bayes: for imbalanced datasets

Selection based on data and goals

Based on the size and availability of training data

  • For small training sets with a higher number of text features than training inputs, a model with high bias/low variance may be best:

    • Logistic Regression
    • Naive Bayes
    • Linear SVM
  • For large training sets with a higher number of training inputs than text features, a model with low bias/high variance may be best:

    • Decision tree-based models
    • Kernel SVM

Based on the tradeoff between accuracy and explainability

  • Top explainability:

    • Logistic Regression
  • Explainability over accuracy:

    • Decision Trees ensembles
    • Naive Bayes
  • Balance of accuracy and explainability:

    • Random Forest
  • Accuracy over explainability:

    • GBoost
    • XGBoost
  • Top accuracy:

    • SVM

Based on training speed and time

  • Quicker to train, easier to implement models:

    • Logistic Regression
    • Naive Bayes
  • Slower to train, more difficult to implement models:

    • SVM Models (due to parameter tuning)
    • Random Forest (due to high convergence time and repetition)

Based on the linearity of the data (is your data linear?)

  • If your data is linear (classes can be separated by a linear boundary):

    • Logistic Regression
    • Linear SVM
  • If your data is not linear (classes cannot be separated by a linear boundary):

    • Custom Kernel SVM
    • Random Forest
    • GBoost
    • XGBoost
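The linearity distinction can be made concrete with a small scikit-learn sketch: on the deliberately non-linear "two moons" dataset (an illustrative choice, not from this guide), an RBF-kernel SVM should fit the curved class boundary better than a linear model such as Logistic Regression.

```python
# Illustrative linear vs. non-linear model comparison on non-linear data.
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Two interleaving half-circles: classes cannot be separated by a straight line.
X, y = make_moons(n_samples=300, noise=0.1, random_state=0)

linear_model = LogisticRegression().fit(X, y)  # linear boundary
kernel_model = SVC(kernel="rbf").fit(X, y)     # non-linear (RBF kernel) boundary
```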