TabPFN: A revolution in Machine Learning with tabular data
Only three times in my career did I feel “aha” moments that forever shifted the understanding I had of my job as a Data Scientist. The first was when I used a Genetic Algorithm to optimise a maintenance staff roster; the second was when ChatGPT generated all the code I needed to train and evaluate a Machine Learning model. The last one was today, when I used Prior Labs’ TabPFN to generate predictions on a tabular dataset without any retraining.
How Machine Learning Traditionally Works
Before diving into this technological breakthrough, let’s slow down and remind ourselves how traditional Machine Learning with tabular data works. For a quick non-technical introduction, you can refer to my post on the topic.
The general idea is simple: we create a dataset of input/output pairs, and train a model to predict the output from an input. Here are some examples of input/output pairs:
- Characteristics of a property (square metres, number of rooms, neighbourhood) and property price
- Dimensions of cell nuclei and tumour diagnosis (malignant/benign)
In the first case, the model would be trained to predict a price from the features of a property. In the second, a model would be trained to predict the diagnosis of a tumour using dimensions of its cell nuclei.
From this collection of training data, these models learn the relationship between the input and the output. For example, a model could learn that larger properties are more expensive, or that some neighbourhoods are cheaper than others. How this learning works is a fascinating topic, and a whole other blog post.
The Data Scientist’s job
Data Scientists would consolidate a collection of input/output pairs, also called a dataset, process the data to convert each row into a list of numbers, and finally train many different types of Machine Learning models to pick the best one.
When choosing between different prediction models, how would you pick the best one? The most accurate model is the model whose predictions are the closest to reality. To evaluate how each model performs, we set aside a fraction of our dataset as a test set. The model will never “see” these data points.
Once the model is trained on the remaining dataset, we use it to generate predictions on the test set. For these observations (e.g., the properties or cell nuclei dimensions in our examples), we have both the model's predictions and the ground truth (e.g., the actual price of the property or the actual diagnosis of the tumour).
We can then select the model that generates the predictions closest to the ground truth.
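To make this concrete, here is a minimal sketch of the traditional workflow in Python. The library choice (scikit-learn) and its built-in breast cancer dataset (cell nuclei measurements, malignant/benign labels) are my own illustrative picks, not something prescribed by any particular tool:

```python
# A minimal sketch of the traditional workflow: split off a test set,
# train candidate models, and keep the one closest to the ground truth.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# Set aside a fraction of the data as a test set the models never "see".
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Train several candidate models on the remaining data...
models = {
    "logistic_regression": LogisticRegression(max_iter=5000),
    "gradient_boosting": GradientBoostingClassifier(),
}

# ...then compare their test-set predictions against the ground truth.
for name, model in models.items():
    model.fit(X_train, y_train)
    accuracy = accuracy_score(y_test, model.predict(X_test))
    print(f"{name}: test accuracy = {accuracy:.3f}")
```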
AutoML: Searching to Learn
This process of testing different modelling approaches can be time- and compute-consuming. For that reason, several research teams started developing algorithms that would efficiently search the space of possible models.
The key insight is that model selection is itself an optimisation problem. The goal is to maximise predictive accuracy, and the decision variables (i.e., the parameters that can be changed) are the Machine Learning model type and its hyperparameters. Hyperparameters define the structure or training algorithm of a model; the details are not directly relevant to this article.
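Here is a deliberately simplified sketch of that idea: enumerate configurations (model type plus hyperparameters) and keep the one that scores best under cross-validation. The tiny hand-written search space is my own illustration; real AutoML systems explore far larger spaces with smarter strategies such as Bayesian optimisation:

```python
# AutoML in miniature: the decision variables are the model type and its
# hyperparameters; the objective is cross-validated predictive accuracy.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

# An illustrative (and tiny) search space of candidate configurations.
search_space = [
    (LogisticRegression, {"C": 0.1, "max_iter": 5000}),
    (LogisticRegression, {"C": 10.0, "max_iter": 5000}),
    (RandomForestClassifier, {"n_estimators": 100, "max_depth": 5}),
    (GradientBoostingClassifier, {"n_estimators": 200, "learning_rate": 0.05}),
]

best_score, best_config = -1.0, None
for model_class, params in search_space:
    score = cross_val_score(model_class(**params), X, y, cv=5).mean()
    if score > best_score:
        best_score, best_config = score, (model_class.__name__, params)

print(f"Best configuration: {best_config} (accuracy {best_score:.3f})")
```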
There were several challenges with AutoML:
- The search process is compute-intensive
- The optimisation methods are difficult to understand and interpret
- AutoML's performance was often only comparable with a single well-tuned model type (Gradient Boosting)
- A single model is much simpler to maintain and deploy than an AutoML pipeline
In summary, AutoML is an optimisation process that searches for the best Machine Learning model for a single dataset.
TabPFN: Learning to Learn
Instead of training a model for every new problem, can you train a model over a (very) large number of datasets, so that it learns to predict anything, without needing to be retrained?
This is the crazy idea that got the TabPFN research started. TabPFN stands for "Tabular Prior-data Fitted Network". This is a mouthful; the following paragraphs should make it a bit clearer.
This is a major change in perspective: instead of treating each row (an individual property or tumour) as a training sample, TabPFN treats each dataset as a training sample. This Transformer-based model was trained on millions of synthetic datasets, learning to generate predictions on any new dataset without retraining. Just like GPT-4 can write code or poems, TabPFN can output tabular data predictions out of the box!
From a user perspective, this means that you can now send a dataset to the TabPFN model and receive predictions within seconds, without the need to train anything. As shown in the original open-access paper, this approach outperforms existing approaches in both accuracy and computational cost: it can predict better and cheaper than models trained on that specific dataset. I tested it myself on a few datasets and was impressed by the results.
Going back to our examples, this means that the same TabPFN model can predict both property prices and tumour diagnoses, without any retraining. It does so very fast, and with a high degree of accuracy.
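In practice, the open-source tabpfn package exposes a scikit-learn-style interface, so the workflow looks almost identical to the traditional sketch above, minus the training. A minimal sketch, assuming the API at the time of writing (it may change, so check the project README):

```python
# A minimal TabPFN sketch. Note that fit() does not retrain any weights:
# the pre-trained model simply takes the training set as context.
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)  # stores the context; no gradient-based training
predictions = clf.predict(X_test)
print(f"Test accuracy: {accuracy_score(y_test, predictions):.3f}")
```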
The following table compares the three approaches mentioned above:
| Approach | What it Learns From | Training Needed | Speed | Flexibility |
|---|---|---|---|---|
| Traditional ML | One dataset | Yes | Slow | One task at a time |
| AutoML | One dataset (many models) | Yes (many times) | Very slow | One task at a time |
| TabPFN | Millions of datasets | No | Seconds | Any tabular task |
Trying it out yourself
It took me only 15 minutes to get TabPFN to work, both locally and with the Prior Labs API. If you have a spreadsheet or CSV lying around, be it sales data or lab experiment results, try running it through TabPFN. I was impressed by the model's predictive accuracy.
As this is a rapidly changing tool, the sketch above may go out of date. For current instructions, please refer to the TabPFN and TabPFN-Client GitHub repositories; the READMEs are self-explanatory.
Final Thoughts
The pace of progress has been incredible. Three years ago, the first version of TabPFN was a proof of concept that could only handle datasets of up to 1,000 rows. The latest version can now generate predictions for datasets of 10,000 rows, and deal with missing values, uninformative columns and outliers. In the next iterations, it will most likely scale to millions of rows, at which point it could be used for large industry Data Science applications.
If it delivers on that promise, TabPFN has the potential to revolutionise the field of Data Science, and may well make tabular Machine Learning predictions a largely solved problem. I am very excited to see where this technology goes.