Machine learning has emerged as a powerful tool for extracting patterns, making predictions, and solving complex problems across domains, from structured data analysis, natural language processing, and computer vision to generative AI. At its core, every machine learning algorithm can be described as an optimization process: the model takes input data (and, in supervised settings, corresponding output data) and learns a complex function that minimizes or maximizes a cost or objective function, possibly subject to additional constraints on the model parameters. The choice of objective function plays a crucial role in shaping the behavior and performance of machine learning models.
This article presents the objective functions used in supervised regression on structured data. By understanding the nuances and trade-offs of different objective functions, practitioners and researchers can make informed decisions when designing and training their machine learning models.
Regression is a supervised learning technique for predicting continuous numerical values from input data. The objective functions commonly used in regression are as follows:
The plot below shows the magnitude of each objective function as the difference between the predicted and actual values varies. It illustrates how each objective function behaves near zero error and how each one treats outliers.
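The curves in such a plot can be reproduced numerically. The sketch below (NumPy only; the Huber delta of 1.0 is an illustrative choice, not a value from the original plot) evaluates each loss on a grid of residuals. MSLE is omitted because it depends on the raw target and prediction values, not just their difference:

```python
import numpy as np

def loss_curves(residuals, delta=1.0):
    """Per-sample loss magnitude for each objective, as a dict of arrays."""
    r = np.asarray(residuals, dtype=float)
    # Huber: quadratic inside [-delta, delta], linear outside
    huber = np.where(np.abs(r) <= delta,
                     0.5 * r**2,
                     delta * (np.abs(r) - 0.5 * delta))
    return {
        "mse": r**2,
        "mae": np.abs(r),
        "huber": huber,
        "logcosh": np.log(np.cosh(r)),
    }

residuals = np.linspace(-5, 5, 11)
curves = loss_curves(residuals)
# Near zero, MSE, Huber, and log-cosh flatten out (quadratic), while MAE
# stays linear; far from zero, MSE grows quadratically, but MAE, Huber,
# and log-cosh grow only linearly, which is what makes them outlier-robust.
```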
from sklearn.linear_model import LinearRegression, SGDRegressor

# Mean Squared Error (MSE): LinearRegression minimizes MSE by default
mse_model = LinearRegression()

# Mean Absolute Error (MAE): epsilon-insensitive loss with epsilon=0 reduces to MAE
mae_model = SGDRegressor(loss='epsilon_insensitive', epsilon=0.0, random_state=42)

# Huber Loss
huber_model = SGDRegressor(loss='huber', epsilon=1.35, random_state=42)

# MSLE and log-cosh are not exposed as estimator options in scikit-learn
# (MLPRegressor only optimizes squared error); in Keras they are available
# as compile-time losses:
# model.compile(optimizer='adam', loss='mean_squared_logarithmic_error')
# model.compile(optimizer='adam', loss='log_cosh')
library(MASS)      # rlm() for Huber loss
library(quantreg)  # rq() for MAE (median regression)

# Mean Squared Error (MSE): ordinary least squares
mse_model <- glm(y ~ X, family = gaussian(link = "identity"))

# Mean Absolute Error (MAE): median regression minimizes MAE
# (glm() has no 'loss' argument, so a quantile-regression fit is used instead)
mae_model <- rq(y ~ X, tau = 0.5)

# Mean Squared Logarithmic Error (MSLE): least squares on the log-transformed target
msle_model <- glm(log1p(y) ~ X, family = gaussian(link = "identity"))

# Huber Loss: robust regression with Huber's M-estimator
huber_model <- rlm(y ~ X, psi = psi.huber)

# Log-cosh has no built-in estimator in base R; it can be minimized
# directly with optim() on a hand-written objective.
Selecting the right loss function for your data and predictive needs is a multi-faceted decision. First and foremost, determine the objective of your prediction and model evaluation: you may be aiming for the most accurate point prediction, or your model may feed a downstream application with its own requirements. Cross-validation is an essential tool for determining which objective function performs best for your dataset and needs.
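For example, candidate objective functions can be compared on held-out folds with scikit-learn's cross-validation utilities (a sketch; the synthetic data and the choice of MAE as the common evaluation metric are illustrative assumptions):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, HuberRegressor
from sklearn.model_selection import cross_val_score

# Synthetic regression data with known coefficients
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.5, size=200)

# Compare an MSE-trained model against a Huber-trained model,
# scoring both with cross-validated mean absolute error.
for name, model in [("MSE (OLS)", LinearRegression()),
                    ("Huber", HuberRegressor())]:
    scores = cross_val_score(model, X, y, cv=5,
                             scoring="neg_mean_absolute_error")
    print(f"{name}: MAE = {-scores.mean():.3f}")
```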
Another crucial factor to consider is near-zero behavior: you must decide how heavily to penalize the model when error magnitudes are small. Additionally, if your dataset includes a large number of outliers, a more robust objective function is necessary to handle them effectively.
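To see why robustness matters, the toy comparison below (an illustrative sketch using scikit-learn's HuberRegressor; the data are made up) fits the same linear relationship with and without regard to a single gross outlier:

```python
import numpy as np
from sklearn.linear_model import LinearRegression, HuberRegressor

X = np.arange(20, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel()          # true slope is 2
y[-1] += 100.0               # one gross outlier

ols = LinearRegression().fit(X, y)    # minimizes MSE
huber = HuberRegressor().fit(X, y)    # Huber loss down-weights the outlier

# The MSE fit is dragged toward the outlier; the Huber fit stays near slope 2.
print(f"OLS slope:   {ols.coef_[0]:.2f}")
print(f"Huber slope: {huber.coef_[0]:.2f}")
```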
Finally, it’s important to keep in mind the assumptions made by different loss functions. For instance, RMSE is optimal for normal (Gaussian) errors, while MAE is best suited for Laplacian errors [1]. By considering all of these factors, you can confidently choose the right loss function to ensure accurate predictions and an effective model.
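A quick way to internalize this difference: for a constant prediction, MSE is minimized by the mean and MAE by the median, which is why MAE is far less sensitive to heavy-tailed errors. A NumPy sketch using made-up numbers:

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0, 4.0, 100.0])  # one heavy-tailed value

# Evaluate both losses for every candidate constant prediction on a fine grid.
candidates = np.linspace(0, 100, 100001)
mse = ((y[:, None] - candidates) ** 2).mean(axis=0)
mae = np.abs(y[:, None] - candidates).mean(axis=0)

best_mse = candidates[mse.argmin()]   # ~ mean(y) = 22.0, pulled up by the tail
best_mae = candidates[mae.argmin()]   # ~ median(y) = 3.0, unaffected by it
print(best_mse, best_mae)
```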
References:
Hodson, T.O., 2022. Root-mean-square error (RMSE) or mean absolute error (MAE): when to use them or not. Geoscientific Model Development, 15(14), pp.5481–5487.
Scikit Learn MLPRegressor — https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html
Keras Regression Losses — https://keras.io/api/losses/regression_losses/
R Generalized Linear Models — https://www.rdocumentation.org/packages/stats/versions/3.6.2/topics/glm