A boxer who only punches a bag will fail in the ring, and an ML model that only learns from clean data will fail in production. If we want our model to perform well when distributions drift, we need to let it get punched in the face during training.
The Problem
ML systems behave poorly when the production data distribution differs from the training data distribution.
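Suppose our model consumes features served by a feature store that is usually fully operational. If this feature store suddenly fails under heavy load, most models that depend on it will fail as well, because they were never trained on the default values that get served in place of the missing features.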
Next, suppose an upstream model extracts entities from images or text, and our model consumes those entities as categorical features. If this upstream model is kept constant while we train our model, but is retrained and relaunched after our model is in production, our model’s performance may degrade.
Data Distribution Drift
Given the modeling task “predict Y given X” there are two distinct types of data distribution drift to be aware of.
The first is covariate drift, or changes in P(X). This can cause samples to move from a region of X where our model performs well to a region where our model performs poorly. We may see covariate drift when a large new customer is onboarded to an enterprise software product. The best way to handle covariate drift is to retrain and relaunch models automatically.
The second is concept drift, or changes in P(Y|X). Both the outage and the upstream model retraining scenarios described above are instances of concept drift: in each case, the same observed feature values now carry a different relationship to Y.
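The distinction can be made concrete with a toy sketch in Python. Everything in it is hypothetical: a single feature x, a "true" labeling rule, and a fixed threshold model that learned "x > 0.5 means positive" from data on [0, 1]. Covariate drift widens the range of x while keeping the labeling rule; concept drift keeps x on [0, 1] but moves the rule.

```python
from typing import Callable, List

def label_before(x: float) -> int:
    # Hypothetical "true" P(Y|X): positive only for 0.5 < x <= 1.0.
    return 1 if 0.5 < x <= 1.0 else 0

def label_after(x: float) -> int:
    # Hypothetical concept drift: the positive region shrinks to 0.7 < x <= 1.0.
    return 1 if 0.7 < x <= 1.0 else 0

def model(x: float) -> int:
    # A model trained only on x in [0, 1]; it learned "x > 0.5 means positive".
    return 1 if x > 0.5 else 0

def grid(lo: float, hi: float, n: int) -> List[float]:
    # n evenly spaced midpoints in [lo, hi], so the results are deterministic.
    step = (hi - lo) / n
    return [lo + step * (i + 0.5) for i in range(n)]

def accuracy(xs: List[float], label_fn: Callable[[float], int]) -> float:
    return sum(model(x) == label_fn(x) for x in xs) / len(xs)

train_xs = grid(0.0, 1.0, 100)
drifted_xs = grid(0.0, 2.0, 100)  # covariate drift: P(X) changes, P(Y|X) does not

print(accuracy(train_xs, label_before))    # no drift: 1.0
print(accuracy(drifted_xs, label_before))  # covariate drift: 0.5
print(accuracy(train_xs, label_after))     # concept drift: 0.8
```

Under covariate drift the model is still right wherever it saw training data, but half the samples now land in a region (x > 1) it never learned. Under concept drift the inputs look exactly like training data, yet the model is wrong on the slice where the rule moved.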
Managing Concept Drift In Production
An ML model is a reflection of the task we train it to solve. By cleverly introducing noise into the training process, we can build models that perform well even during software incidents.
Suppose one of the features that our model expects tracks whether an attribute matches any of the categories in a configuration file (e.g., phrase matches or a list of reserved usernames). As users update this file, we expect the joint distribution of this feature and the model label Y to change as well. This kind of gradual concept drift is extremely common.
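A small Python sketch shows why this counts as concept drift. The reserved-username lists, the users, and their labels are all made up for illustration: the users never change, but after the configuration file grows, the same feature value is associated with a different label rate.

```python
from typing import FrozenSet, List, Tuple

# Two hypothetical versions of the configuration file.
RESERVED_V1: FrozenSet[str] = frozenset({"admin", "root"})
RESERVED_V2: FrozenSet[str] = frozenset({"admin", "root", "support", "billing"})

# Hypothetical (username, label) pairs; the users themselves never change.
USERS: List[Tuple[str, int]] = [
    ("admin", 1), ("root", 1), ("support", 1),
    ("billing", 0), ("alice", 0), ("bob", 0),
]

def is_reserved(username: str, reserved: FrozenSet[str]) -> int:
    # The feature: does the username match any entry in the config file?
    return int(username.lower() in reserved)

def positive_rate(feature_value: int, reserved: FrozenSet[str]) -> float:
    # Empirical P(Y = 1 | is_reserved = feature_value) over USERS.
    labels = [y for name, y in USERS if is_reserved(name, reserved) == feature_value]
    return sum(labels) / len(labels)

for version, reserved in [("v1", RESERVED_V1), ("v2", RESERVED_V2)]:
    print(version, positive_rate(1, reserved), positive_rate(0, reserved))
# v1 1.0 0.25
# v2 0.75 0.0
```

A model trained only on v1 feature values learned that `is_reserved = 1` always means a positive label; after the update, that is no longer true even though no user changed.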
We want our ML model to “understand” that this distribution might change. We can accomplish this with a simple recipe:
- Train our model on logged features over a long period of time
- Automatically retrain and redeploy our model as frequently as possible
By forcing our model to perform well on samples hydrated by both the current logic of this feature and its previous iterations, we can build a model that is more resilient to data drift. We can evaluate the effectiveness of this technique by monitoring the model’s performance on logged predictions over long periods of time.
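The first step of the recipe can be sketched in Python, assuming a hypothetical log where each row stores the features exactly as they were served, tagged with the version of the feature logic that produced them, alongside the eventual label. The `LoggedRow`, `build_training_set`, and `versions_covered` names, the example rows, and the window lengths are illustrative, not part of any particular system.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Set

@dataclass
class LoggedRow:
    # Features exactly as served at prediction time (NOT recomputed with
    # today's feature logic), plus the label observed later.
    logged_at: datetime
    feature_version: str
    features: Dict[str, float]
    label: int

def build_training_set(rows: List[LoggedRow], now: datetime, window_days: int) -> List[LoggedRow]:
    """Keep every logged row from the last `window_days` days."""
    cutoff = now - timedelta(days=window_days)
    return [r for r in rows if r.logged_at >= cutoff]

def versions_covered(rows: List[LoggedRow]) -> Set[str]:
    # Which iterations of the feature logic the training set was hydrated from.
    return {r.feature_version for r in rows}

now = datetime(2024, 6, 1)
rows = [
    LoggedRow(datetime(2023, 10, 1), "v1", {"is_reserved": 0.0}, 1),
    LoggedRow(datetime(2024, 2, 1), "v2", {"is_reserved": 1.0}, 1),
    LoggedRow(datetime(2024, 5, 20), "v3", {"is_reserved": 1.0}, 0),
]
print(versions_covered(build_training_set(rows, now, 30)))   # only v3
print(versions_covered(build_training_set(rows, now, 365)))  # v1, v2 and v3
```

A short window only exposes the model to the newest feature logic, while a long window mixes old and new iterations, which is the point of the recipe. Retraining frequently then amounts to rerunning this selection and the training job on a schedule.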
However, this strategy may not be enough to handle rare and sudden drift events like feature outages. We can force our model to handle these cases by applying the following transformation to the training dataset:
- Identify the K features that are at risk of a production outage
- Randomly copy N samples from the training dataset
- In each copy, replace one or more of the K features with their default values
- Add the copies to the training dataset
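The transformation above can be sketched in Python with only the standard library. This is a minimal sketch under stated assumptions: rows are plain dicts of feature values plus a `label` key, the caller supplies the at-risk feature names and their default values, and the N copies are drawn with replacement. The `augment_with_outages` helper and the feature names in the usage lines are hypothetical.

```python
import random
from typing import Dict, List, Sequence

Row = Dict[str, float]

def augment_with_outages(
    train: Sequence[Row],
    at_risk: Sequence[str],
    defaults: Dict[str, float],
    n_copies: int,
    rng: random.Random,
) -> List[Row]:
    """Return the original rows plus n_copies rows with simulated feature outages."""
    copies: List[Row] = []
    # Draw N samples at random (with replacement) to serve as the copies.
    for original in rng.choices(train, k=n_copies):
        corrupted = dict(original)  # copy so the original row is untouched
        # Knock out between 1 and K of the at-risk features, as in an outage.
        k = rng.randint(1, len(at_risk))
        for name in rng.sample(list(at_risk), k):
            corrupted[name] = defaults[name]
        copies.append(corrupted)
    # Labels are left as-is: each copy describes the same example, just served
    # while some features were unavailable.
    return list(train) + copies

# Hypothetical usage: "clicks" and "spend" come from a fragile feature store.
train = [
    {"clicks": 12.0, "spend": 40.0, "age_days": 300.0, "label": 1},
    {"clicks": 3.0, "spend": 5.0, "age_days": 20.0, "label": 0},
]
augmented = augment_with_outages(
    train, ["clicks", "spend"], {"clicks": 0.0, "spend": 0.0},
    n_copies=4, rng=random.Random(0),
)
```

Passing an explicit `random.Random` keeps the augmentation reproducible across retraining runs. In practice N would typically be a fraction of the dataset size, so corrupted rows do not dominate training.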
We can evaluate the effectiveness of this technique by applying the same treatment to the testing dataset and evaluating the model’s performance on the corrupted testing samples.
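One way to run this evaluation is sketched below in Python. As a variation on applying the random treatment, it knocks out one at-risk feature at a time across the whole test set and reports accuracy for each simulated outage next to the clean baseline, which makes it easy to see which outage hurts the most. The `outage_report` helper, the accuracy metric, and the toy model are assumptions for illustration.

```python
from typing import Callable, Dict, Sequence

Row = Dict[str, float]
Predictor = Callable[[Row], int]

def accuracy(predict: Predictor, rows: Sequence[Row]) -> float:
    return sum(predict(r) == r["label"] for r in rows) / len(rows)

def outage_report(
    predict: Predictor,
    test_rows: Sequence[Row],
    at_risk: Sequence[str],
    defaults: Dict[str, float],
) -> Dict[str, float]:
    report = {"clean": accuracy(predict, test_rows)}
    for name in at_risk:
        # Simulate an outage of a single feature on every test row.
        corrupted = [{**row, name: defaults[name]} for row in test_rows]
        report[f"{name}_outage"] = accuracy(predict, corrupted)
    return report

# Hypothetical toy model that relies entirely on feature "a".
def toy_predict(row: Row) -> int:
    return int(row["a"] > 0.5)

test_rows = [
    {"a": 1.0, "b": 0.0, "label": 1},
    {"a": 0.0, "b": 1.0, "label": 0},
]
print(outage_report(toy_predict, test_rows, ["a", "b"], {"a": 0.0, "b": 0.0}))
# {'clean': 1.0, 'a_outage': 0.5, 'b_outage': 1.0}
```

We would expect a model trained with the augmented dataset to show a smaller gap between the clean and outage scores than one trained on clean data alone.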
Closing Thoughts
Every software system experiences incidents. Service outages, data pipeline delays, sudden load, and hundreds of other risks threaten uptime and damage the user experience. Mature development teams plan for these incidents and build software that adapts to unexpected changes in system behavior and availability.
Unfortunately, this kind of risk mitigation is notoriously difficult in an ML system. Small changes to model inputs can cause large and unexpected changes to model outputs. Incidents that touch ML systems therefore have a larger blast radius and longer recovery times. This critical vulnerability has slowed the adoption of machine learning technologies in safety-critical applications.
For machine learning to continue driving impact in new applications, we will need to address this problem directly. I’m looking forward to seeing more research on the development of resilient machine learning techniques.