Production ML Pipeline with TFX: Data Validation, Training, Evaluation, and Serving

Build a production ML pipeline with TFX components — TFDV for data validation, Transform for feature preprocessing, Trainer with Keras, Evaluator with fairness metrics, and Pusher to TF Serving.

Your ML pipeline is a series of Python scripts with no data validation, no lineage tracking, and manual deployment. TFX addresses all three. It's the difference between a Jupyter notebook that works on your laptop and a system that reliably serves predictions to millions while automatically detecting when your training data starts to drift into nonsense. This isn't about building a model; it's about building a factory for models. We're going to build a production pipeline from raw data to a model served by TensorFlow Serving, with validation gates at every step, so you can stop babysitting your Python scripts and start trusting your ML system.

Why Your Ad-Hoc Pipeline is a Time Bomb

Before we dive into components, let's diagnose the patient. You have a train.py, maybe an evaluate.py, and a serve.py you copied from a blog post. Your data validation is print(df.describe()). Model deployment is a manual scp to a server, followed by frantic ssh sessions when predictions go weird. You have no idea if the dip in accuracy is due to data drift, a bad feature transformation, or a cosmic ray.

This is why TFX powers ML pipelines for 50%+ of Google's production ML models (TFX paper 2025 update). It's not a fancy library; it's a framework that forces discipline. It tracks every artifact (data, schema, model), validates every input, and makes your pipeline reproducible and automatable. The goal isn't to run it once on your machine. The goal is to run it daily on Kubeflow or Vertex AI, where it ingests new data, validates it, retrains if necessary, and pushes a new model without you lifting a finger.

TFX Architecture: The Orchestra and Its Conductor

Think of a TFX pipeline as an orchestra. Each musician is a Component (like ExampleGen, StatisticsGen, Trainer). The sheet music is the Pipeline Definition, which scores who plays when. The concert hall's archive is the Metadata Store, which keeps an immutable record of every performance: what data was used, what hyperparameters, what the resulting model accuracy was.

The magic is in the artifacts. A component doesn't just pass a Pandas DataFrame to the next. It produces and consumes typed artifacts (e.g., Examples, Schema, Model) that are stored and versioned. The Metadata Store is the spine of this system. It is backed by ML Metadata (MLMD), typically with SQLite for local runs and a MySQL database (such as Cloud SQL) in production. This lineage lets you answer critical questions: "Which training run produced the model currently in production?" and "What was the distribution of the age feature in that run?"
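
To make lineage concrete, here is a deliberately tiny, in-memory stand-in for what MLMD records: executions linked to the artifacts they read and wrote. It is not the ML Metadata API; ToyLineageStore, Execution, and the artifact URIs are hypothetical, chosen only to show how those links let you walk back from a deployed model to the data behind it.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Execution:
    """One component run and the artifact URIs it read and wrote."""
    component: str
    run_id: str
    inputs: list[str] = field(default_factory=list)
    outputs: list[str] = field(default_factory=list)


class ToyLineageStore:
    """Hypothetical in-memory stand-in for an ML Metadata store."""

    def __init__(self) -> None:
        self.executions: list[Execution] = []

    def record(self, execution: Execution) -> None:
        self.executions.append(execution)

    def producer_of(self, uri: str) -> Optional[Execution]:
        """Return the execution that wrote this artifact, or None for raw inputs."""
        for execution in self.executions:
            if uri in execution.outputs:
                return execution
        return None

    def upstream(self, uri: str) -> list[str]:
        """Follow producer links backwards to collect every ancestor artifact."""
        ancestors: list[str] = []
        frontier = [uri]
        while frontier:
            producer = self.producer_of(frontier.pop())
            if producer is None:
                continue
            for parent in producer.inputs:
                if parent not in ancestors:
                    ancestors.append(parent)
                    frontier.append(parent)
        return ancestors


store = ToyLineageStore()
store.record(Execution('CsvExampleGen', 'run-12', ['data/2025-01'], ['examples/12']))
store.record(Execution('StatisticsGen', 'run-12', ['examples/12'], ['statistics/12']))
store.record(Execution('SchemaGen', 'run-12', ['statistics/12'], ['schema/12']))
store.record(Execution('Trainer', 'run-12', ['examples/12', 'schema/12'], ['model/12']))

print(store.producer_of('model/12').run_id)  # run-12
print(store.upstream('model/12'))  # every artifact the model depends on, back to raw data
```

The real store adds typed artifacts, properties, and contexts, but the question "which run produced this model?" is answered the same way: by following recorded input and output events.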

Here’s a minimal pipeline definition to set the stage. We'll flesh out each component next.


import tfx.v1 as tfx
from tfx.proto import example_gen_pb2

def create_pipeline(pipeline_name: str,
                    pipeline_root: str,
                    data_root: str,
                    serving_model_dir: str,  # Used by the Pusher, added later.
                    metadata_path: str) -> tfx.dsl.Pipeline:
    # 1. Bring data into the pipeline.
    example_gen = tfx.components.CsvExampleGen(
        input_base=data_root,
        output_config=example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=8),
                example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2)
            ]))
    )

    # 2. Generate statistics about the data.
    statistics_gen = tfx.components.StatisticsGen(
        examples=example_gen.outputs['examples']
    )

    # 3. Infer a schema from the statistics.
    schema_gen = tfx.components.SchemaGen(
        statistics=statistics_gen.outputs['statistics']
    )

    # We'll add Trainer, Evaluator, and Pusher in later sections.
    components = [example_gen, statistics_gen, schema_gen]

    return tfx.dsl.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        # Local runs record lineage in an SQLite-backed MLMD store.
        metadata_connection_config=(
            tfx.orchestration.metadata.sqlite_metadata_connection_config(
                metadata_path)),
        enable_cache=True  # Game-changer for iteration speed.
    )
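
What enable_cache=True buys you: when a component's inputs and configuration match a previous successful execution, TFX skips it and reuses the earlier outputs. The sketch below imitates that idea with a hash over each component's name, config, and upstream fingerprints. ToyCachingRunner and fingerprint are hypothetical helpers, the component list is assumed to be in dependency order already, and real TFX resolves cache hits through MLMD rather than an in-memory set.

```python
from __future__ import annotations

import hashlib
import json
from typing import Callable


def fingerprint(name: str, config: dict, upstream_fps: list[str]) -> str:
    """Stable hash of a component's identity, its config, and its inputs' fingerprints."""
    payload = json.dumps({'name': name, 'config': config, 'inputs': upstream_fps},
                         sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


class ToyCachingRunner:
    """Hypothetical runner that skips components whose fingerprint it has seen."""

    def __init__(self) -> None:
        self.seen: set[str] = set()    # fingerprints of successful executions
        self.executed: list[str] = []  # components that actually ran, in order

    def run(self, components: list[tuple[str, dict, list[str], Callable[[], None]]]) -> None:
        # Assumes components are listed in dependency order.
        output_fps: dict[str, str] = {}
        for name, config, upstream, work in components:
            fp = fingerprint(name, config, [output_fps[u] for u in upstream])
            if fp not in self.seen:
                work()  # cache miss: do the (expensive) work
                self.executed.append(name)
                self.seen.add(fp)
            # The fingerprint doubles as the identity of this component's output.
            output_fps[name] = fp


def noop() -> None:
    """Stand-in for real component work."""


pipeline = [
    ('CsvExampleGen', {'input_base': 'data/v1'}, [], noop),
    ('StatisticsGen', {}, ['CsvExampleGen'], noop),
    ('SchemaGen', {}, ['StatisticsGen'], noop),
]
runner = ToyCachingRunner()
runner.run(pipeline)
print(runner.executed)  # ['CsvExampleGen', 'StatisticsGen', 'SchemaGen']

runner.executed.clear()
runner.run(pipeline)
print(runner.executed)  # [] (every component was a cache hit)

runner.executed.clear()
pipeline[0] = ('CsvExampleGen', {'input_base': 'data/v2'}, [], noop)
runner.run(pipeline)
print(runner.executed)  # ['CsvExampleGen', 'StatisticsGen', 'SchemaGen']
```

Changing only ExampleGen's input invalidates everything downstream, because each fingerprint folds in its parents' fingerprints.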

TFDV: Catching Data Anomalies Before They Become Model Anomalies

TensorFlow Data Validation (TFDV) is your pipeline's immune system. StatisticsGen and SchemaGen are the components that use it. StatisticsGen creates a snapshot of your data's health: min/max values, mean, std, missing values, etc. SchemaGen looks at those stats and proposes a schema—a contract that defines what "normal" data looks like for each feature (type, expected range, presence).

The real power is in subsequent runs. You can validate new data against the frozen schema. Did a new categorical value appear? Did a numerical feature suddenly have a 1000x spike outside the range your schema declares? TFDV will flag it as an anomaly. Inside a pipeline, the ExampleValidator component runs this check on StatisticsGen's output. This catches upstream data pipeline bugs before they poison your model.

# tfdv_validation_script.py
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

# Load statistics from a previous run (or generate new ones).
train_stats = tfdv.load_statistics('path/to/train_stats.tfrecord')

# 1. Infer a schema from the training data.
schema = tfdv.infer_schema(statistics=train_stats)
# Freeze it! This is your source of truth.
tfdv.write_schema_text(schema, 'path/to/schema.pbtxt')

# 2. Later, load new data and validate it.
new_stats = tfdv.generate_statistics_from_csv('new_data.csv')
anomalies = tfdv.validate_statistics(statistics=new_stats, schema=schema)

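if anomalies.anomaly_info:
    print("🚨 Data anomalies detected!")
    # You can set severity overrides or automatically adjust schema in CI.
    # In a notebook, tfdv.display_anomalies(anomalies) renders a table instead.
    for feature_name, info in anomalies.anomaly_info.items():
        print(f"  {feature_name}: {info.short_description}")
else:
    print("✅ Data passes validation.")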
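
The idea behind schema validation fits in a few lines. This stdlib-only sketch infers a toy schema (type, allowed string values, observed numeric range) from training rows and reports rows that violate it. It illustrates the concept only: TFDV works on aggregated statistics rather than individual rows, and its inferred schema does not pin numeric ranges for you, so in practice you add those constraints by hand. infer_schema and validate are hypothetical names.

```python
from __future__ import annotations

from typing import Any


def infer_schema(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Record each feature's type plus its allowed values (strings) or range (numbers)."""
    schema: dict[str, dict[str, Any]] = {}
    for row in rows:
        for name, value in row.items():
            spec = schema.setdefault(name, {'type': type(value).__name__})
            if isinstance(value, str):
                spec.setdefault('domain', set()).add(value)
            else:
                spec['min'] = min(spec.get('min', value), value)
                spec['max'] = max(spec.get('max', value), value)
    return schema


def validate(rows: list[dict[str, Any]], schema: dict[str, dict[str, Any]]) -> list[str]:
    """Return one human-readable anomaly per violated expectation."""
    anomalies: list[str] = []
    for i, row in enumerate(rows):
        for name, spec in schema.items():
            if name not in row:
                anomalies.append(f"row {i}: missing feature {name!r}")
                continue
            value = row[name]
            if type(value).__name__ != spec['type']:
                anomalies.append(f"row {i}: {name!r} expected {spec['type']}, got {type(value).__name__}")
            elif 'domain' in spec and value not in spec['domain']:
                anomalies.append(f"row {i}: unseen value {value!r} for {name!r}")
            elif 'min' in spec and not spec['min'] <= value <= spec['max']:
                anomalies.append(f"row {i}: {name!r}={value} outside [{spec['min']}, {spec['max']}]")
    return anomalies


train = [{'age': 18, 'country': 'US'}, {'age': 90, 'country': 'DE'}]
schema = infer_schema(train)
serving = [
    {'age': 29, 'country': 'US'},     # fine
    {'age': 29000, 'country': 'US'},  # a 1000x spike
    {'age': '41', 'country': 'FR'},   # string where an int is expected, plus a new category
]
for anomaly in validate(serving, schema):
    print(anomaly)
```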

Real Error & Fix: You deploy a new model and get ValueError: Input 0 of layer is incompatible with the layer. The fix isn't just checking input_shape. In TFX, this is often a schema mismatch. The model expects a feature based on the schema from training run #12, but the serving system is receiving data that violates it (e.g., a string where a float is expected). The fix: Run TFDV validation on your serving logs and compare the stats/schema to your training data. Update your schema or fix your data pipeline.
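
For categorical features, one concrete way to "compare the stats" is the L-infinity distance: the largest absolute difference between the training and serving frequency of any single value. TFDV's skew and drift comparators can use this infinity norm with a per-feature threshold. A stdlib-only sketch, where linf_distance is a hypothetical helper and the 0.01 threshold is an assumption you would tune per feature:

```python
from __future__ import annotations

from collections import Counter


def linf_distance(train_values: list[str], serving_values: list[str]) -> float:
    """Largest absolute gap between the two normalized value frequencies."""
    train_counts, serving_counts = Counter(train_values), Counter(serving_values)
    n_train, n_serving = len(train_values), len(serving_values)
    values = set(train_counts) | set(serving_counts)  # include values seen on only one side
    return max(abs(train_counts[v] / n_train - serving_counts[v] / n_serving)
               for v in values)


train = ['card'] * 60 + ['cash'] * 40
serving = ['card'] * 75 + ['cash'] * 25
distance = linf_distance(train, serving)
print(f"{distance:.2f}")  # 0.15
if distance > 0.01:  # hypothetical per-feature threshold
    print("skew detected: compare the training and serving pipelines")
```

A value that appears only in serving data (a brand-new category) contributes its full serving frequency to the distance, so it surfaces quickly.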

TF Transform: Consistent Preprocessing from Training to Serving

Here's the classic mistake: you normalize your training data in train.py using fit_transform, but you forget to apply the same normalization in your serving code. tf.Transform solves this by letting you define the preprocessing function (the preprocessing_fn) once, in TensorFlow ops. It runs that function over the full training set as an Apache Beam job, computing full-pass statistics such as means or vocabularies once, and exports a transform graph with those values frozen in. Your training code then attaches that graph to the SavedModel, so serving applies exactly the same transformation.

Your preprocessing moves from being a Python function that runs before model.fit() to a declared graph that's part of the model itself. This guarantees consistency.
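
The pattern underneath is "analyze once, apply everywhere": a full pass over the training data computes constants, and a row-wise function reuses them verbatim at serving time. In tf.Transform you would write this inside a preprocessing_fn, for example with tft.scale_to_z_score, whose mean and variance end up frozen in the transform graph. The stdlib-only sketch below mimics that split; analyze and transform are hypothetical names, not tf.Transform APIs.

```python
from __future__ import annotations

from statistics import fmean, pstdev


def analyze(train_values: list[float]) -> dict[str, float]:
    """Full-pass 'analyzer': compute the constants once, over the training set only."""
    return {'mean': fmean(train_values), 'std': pstdev(train_values)}


def transform(value: float, constants: dict[str, float]) -> float:
    """Row-wise z-score that reuses the frozen constants at train and serve time."""
    std = constants['std'] or 1.0  # guard against a constant feature
    return (value - constants['mean']) / std


train_values = [10.0, 20.0, 30.0]
constants = analyze(train_values)  # frozen alongside the model
training_inputs = [transform(v, constants) for v in train_values]
serving_input = transform(25.0, constants)  # same constants, nothing refit at serving time
print(round(serving_input, 2))  # 0.61

# The classic skew bug: "normalizing" a single serving request on its own.
refit = transform(25.0, analyze([25.0]))
print(refit)  # 0.0, not the 0.61 the training-time transform produces for 25.0
```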

The Trainer: Your Keras Code, Now in a Factory

The Trainer component is where your familiar Keras model-building code lives, but it's wrapped in a TFX component that supplies the training and eval data from ExampleGen, passes in the schema, and collects the resulting SavedModel. The key is the run_fn function you write in a module file: TFX calls it with an FnArgs object carrying the data file paths, the schema path, the step counts, and the output directories.

Let's build a trainer for a simple model, using Keras 3 with the TensorFlow backend. Keras 3 can also run on JAX and PyTorch backends (Keras.io 2025), but TF Serving needs a TensorFlow SavedModel, so the TensorFlow backend is the natural fit here.

# trainer_module.py
import keras
import tensorflow as tf
import tfx.v1 as tfx
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.components.trainer.fn_args_utils import DataAccessor, FnArgs
from tfx_bsl.public import tfxio

_LABEL_KEY = 'label'  # Assumes the label feature is named 'label'.

def _build_keras_model(feature_keys: list[str]) -> keras.Model:
    """A simple Keras 3 model with one named scalar input per feature."""
    # Dict inputs let Keras match the dataset's feature dict by name.
    inputs = {key: keras.Input(shape=(1,), name=key) for key in feature_keys}
    x = keras.layers.Concatenate()(list(inputs.values()))
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=['accuracy']
    )
    return model

def _input_fn(file_pattern: list[str], data_accessor: DataAccessor,
              schema: schema_pb2.Schema, batch_size: int = 64) -> tf.data.Dataset:
    """Creates a batched (features, label) dataset from ExampleGen's TFRecords."""
    # TFX parses the tf.Examples according to the schema and splits off the label.
    dataset = data_accessor.tf_dataset_factory(
        file_pattern,
        tfxio.TensorFlowDatasetOptions(batch_size=batch_size, label_key=_LABEL_KEY),
        schema=schema)
    # Assumes dense numeric features; cast so Concatenate sees a single dtype.
    dataset = dataset.map(
        lambda features, label: (
            {k: tf.cast(v, tf.float32) for k, v in features.items()}, label))
    # Repeat so steps_per_epoch controls training length. Prefetch overlaps input
    # loading with training (reported: 4.2x throughput vs. eager loading on ImageNet).
    return dataset.repeat().prefetch(tf.data.AUTOTUNE)

def run_fn(fn_args: FnArgs):
    """The function executed by the Trainer component."""
    # 1. Load the schema produced by SchemaGen and build the input pipelines.
    schema = tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema_pb2.Schema())
    train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor, schema)
    eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema)

    # 2. Build and train the model; every non-label feature becomes an input.
    feature_keys = [f.name for f in schema.feature if f.name != _LABEL_KEY]
    model = _build_keras_model(feature_keys)
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,  # From TrainArgs(num_steps=...).
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,  # From EvalArgs(num_steps=...).
        callbacks=[
            keras.callbacks.TensorBoard(log_dir=fn_args.model_run_dir),
            # Add early stopping, etc.
        ]
    )

    # 3. Export an inference SavedModel for TF Serving. Keras 3 dropped
    # save_format='tf'; model.export() writes the SavedModel format.
    model.export(fn_args.serving_model_dir)

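You then plug this module into the Trainer component: pass its path as module_file, connect the examples and schema outputs from ExampleGen and SchemaGen, and set train_args and eval_args (tfx.proto.TrainArgs and tfx.proto.EvalArgs) so that fn_args.train_steps and fn_args.eval_steps are populated. TFX then runs run_fn, feeding in the correct file paths and schema.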

The Evaluator: Not Just One Accuracy Number

The Evaluator component does more than compute loss on the eval set. It uses TensorFlow Model Analysis (TFMA) to compute metrics sliced across different segments of your data (e.g., accuracy for users in the US vs. EU, or for different age groups). This is where you catch fairness issues before deployment. It can also compare your newly trained model against a baseline (the currently served model) and decide if it's "good enough" to push.

You configure this with a slicing spec and thresholds in an eval_config. The Evaluator outputs an evaluation artifact holding the metrics, plus a blessing artifact recording whether the candidate passed your thresholds. The next component, the Pusher, can be configured to act only if the model is blessed.
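
Conceptually, the Evaluator's blessing combines two kinds of check per slice: an absolute floor the metric must clear, and a comparison against the baseline model. In TFMA these map to value and change thresholds (tfma.GenericValueThreshold and tfma.GenericChangeThreshold) inside the eval_config's metric specs. The plain-Python sketch below is not TFMA code; sliced_accuracy, is_blessed, the 0.7 floor, and the toy data are all hypothetical, chosen to show how a model can win overall while failing a slice.

```python
from __future__ import annotations

from collections import defaultdict


def sliced_accuracy(examples: list[dict], predictions: list[int],
                    slice_key: str) -> dict[str, float]:
    """Accuracy overall and for each value of slice_key."""
    hits: dict[str, int] = defaultdict(int)
    totals: dict[str, int] = defaultdict(int)
    for example, pred in zip(examples, predictions):
        for slice_name in ('overall', f"{slice_key}={example[slice_key]}"):
            totals[slice_name] += 1
            hits[slice_name] += int(pred == example['label'])
    return {s: hits[s] / totals[s] for s in totals}


def is_blessed(candidate: dict[str, float], baseline: dict[str, float],
               floor: float = 0.7) -> bool:
    """Bless only if every slice clears the floor and no slice regresses vs. baseline."""
    for slice_name, acc in candidate.items():
        if acc < floor:
            return False  # absolute (value) threshold
        if slice_name in baseline and acc < baseline[slice_name]:
            return False  # relative (change) threshold
    return True


examples = ([{'region': 'US', 'label': y} for y in (1, 0, 1, 0, 1, 0)]
            + [{'region': 'EU', 'label': y} for y in (1, 0)])
baseline = sliced_accuracy(examples, [1, 0, 1, 1, 0, 0, 1, 0], 'region')
candidate = sliced_accuracy(examples, [1, 0, 1, 0, 1, 0, 0, 0], 'region')
print(candidate)  # {'overall': 0.875, 'region=US': 1.0, 'region=EU': 0.5}
print(is_blessed(candidate, baseline))  # False: better overall, but EU fell below the floor
```

The candidate beats the baseline overall (0.875 vs. 0.75), which is exactly the case a single accuracy number would wave through.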

Pusher & Serving: Conditional Deployment to Production

The Pusher is the final gatekeeper. It checks the blessing output from the Evaluator. If the new model is blessed (e.g., its accuracy is higher than a baseline and passes all fairness thresholds), the Pusher deploys the SavedModel.

Where does it push? To a serving_model_dir, which can be a local path or a cloud storage bucket (like gs://your-bucket/models/) that your serving system watches. For production, you'll use TensorFlow Serving or a cloud endpoint.
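
TF Serving watches a model base path and, under its default version policy, serves the highest-numbered version subdirectory it finds. A push therefore amounts to copying the exported SavedModel into a new numbered directory, and only for a blessed model. A stdlib-only sketch of that gate, where push_if_blessed is a hypothetical helper and using a Unix timestamp as the default version number is an assumption (any increasing integer works):

```python
from __future__ import annotations

import os
import shutil
import tempfile
import time
from typing import Optional


def push_if_blessed(model_dir: str, serving_base: str, blessed: bool,
                    version: Optional[int] = None) -> Optional[str]:
    """Copy model_dir to serving_base/<version>/ only when the model is blessed."""
    if not blessed:
        return None  # unblessed models never reach the serving path
    version = int(time.time()) if version is None else version
    destination = os.path.join(serving_base, str(version))
    # Stage under a non-numeric name (not a valid version), then rename, so a
    # watcher never sees a half-copied version on a local filesystem.
    staging = destination + '.tmp'
    shutil.copytree(model_dir, staging)
    os.rename(staging, destination)
    return destination


with tempfile.TemporaryDirectory() as root:
    model_dir = os.path.join(root, 'trainer_output')
    os.makedirs(model_dir)
    open(os.path.join(model_dir, 'saved_model.pb'), 'wb').close()  # placeholder file
    serving_base = os.path.join(root, 'serving')
    print(push_if_blessed(model_dir, serving_base, blessed=False))  # None
    push_if_blessed(model_dir, serving_base, blessed=True, version=1)
    print(sorted(os.listdir(serving_base)))  # ['1']
```

The stage-then-rename step is a local-filesystem technique; object stores such as GCS have no atomic directory rename, so it does not carry over there unchanged.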

A commonly quoted TensorFlow Serving benchmark is 15,000 req/s on a 4-core CPU instance for ResNet-50 inference (batch=1); treat it as a rough reference point, since throughput depends heavily on model size, batching configuration, and hardware. For edge deployment, you'd extend the pipeline with a conversion step, such as a custom component, that turns the SavedModel into TensorFlow Lite, which is deployed on 6B+ devices as of 2025 (Google I/O 2025).

| Deployment Target | Format | Optimization | Typical Use Case | Reported Performance |
|---|---|---|---|---|
| Cloud/Server (CPU) | SavedModel | None | TensorFlow Serving | Baseline (15k req/s) |
| Cloud/Server (NVIDIA GPU) | SavedModel | TF-TRT | High-throughput servers | 2–5x faster inference vs. native TF |
| Mobile (Android/iOS) | TensorFlow Lite | FP16/INT8 quantization | On-device inference | 3x smaller, 2x faster vs. FP32 on Pixel 8 |
| Web Browser | TensorFlow.js | WebGL/WebGPU | Client-side ML | 30M+ inference calls/month |

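Real Error & Fix: You try to load your SavedModel in TF Serving and get an error such as SavedModel load error: No such attribute 'call'. This often means you saved the Keras model in the native .keras format instead of as a SavedModel directory; in Keras 3, model.save() only writes the .keras (or legacy .h5) formats. The fix: export explicitly, with model.export(path) in Keras 3, model.save(path, save_format='tf') in Keras 2 (tf.keras), or tf.saved_model.save() for full control. TFX's Trainer won't do this for you: it only runs your run_fn, so the export call in that function determines what lands in serving_model_dir.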

From Local Runs to Kubeflow Pipelines

Running the pipeline locally with tfx.orchestration.LocalDagRunner().run(create_pipeline(...)) is great for development. Production runs on Kubeflow Pipelines (KFP) or Vertex AI Pipelines. You compile your pipeline definition into a pipeline spec file (for example with tfx.orchestration.experimental.KubeflowV2DagRunner), then submit it to KFP on your Kubernetes cluster or to Vertex AI Pipelines. Each component runs in its own container, scaling with your workload. The Metadata Store is now a managed or cluster-hosted database rather than a local SQLite file, and artifacts live in cloud storage. This is the true production environment.
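
However the pipeline is run, the orchestrator derives execution order from the artifact wiring: a component starts once every component whose outputs it consumes has finished, and components with no dependency between them can run at the same time, each in its own container on KFP. Python's standard graphlib module shows the idea. The dependency map below is a hypothetical wiring of a typical TFX pipeline, and this is a conceptual illustration, not KFP's implementation.

```python
from graphlib import TopologicalSorter

# Each component maps to the components whose output artifacts it consumes.
dependencies = {
    'StatisticsGen': {'CsvExampleGen'},
    'SchemaGen': {'StatisticsGen'},
    'ExampleValidator': {'StatisticsGen', 'SchemaGen'},
    'Transform': {'CsvExampleGen', 'SchemaGen'},
    'Trainer': {'Transform', 'SchemaGen'},
    'Evaluator': {'CsvExampleGen', 'Trainer'},
    'Pusher': {'Trainer', 'Evaluator'},
}

sorter = TopologicalSorter(dependencies)
sorter.prepare()
stages = []
while sorter.is_active():
    ready = sorted(sorter.get_ready())  # these could run as parallel containers
    stages.append(ready)
    sorter.done(*ready)

for i, stage in enumerate(stages):
    print(i, stage)
```

ExampleValidator and Transform land in the same stage: both only need ExampleGen, StatisticsGen, and SchemaGen outputs, so neither has to wait for the other.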

Next Steps: From Pipeline to MLOps Flywheel

You now have a pipeline that ingests, validates, transforms, trains, evaluates, and conditionally deploys. But this is just the core loop. The next steps are what turn this into an MLOps system:

  1. Triggering: Set up a cron job or event-driven trigger (e.g., when new data lands in a bucket) to run your pipeline automatically on Kubeflow.
  2. Monitoring: Use the Metadata Store lineage to track model performance over time. Feed your production serving logs back into TFDV to continuously monitor for data drift and concept drift.
  3. Hyperparameter Tuning: Integrate a component like Tuner (using KerasTuner or Katib) to search for better hyperparameters on each run.
  4. Explainability: Add a component that uses tools like SHAP or integrated gradients to generate feature attributions for your model's predictions, stored as artifacts.
  5. Multi-Platform Deployment: Extend your Pusher to not only push a SavedModel for TF Serving but also automatically generate and version a TensorFlow Lite model for mobile and a TensorFlow.js model for web, as shown in the table above.

The goal is a self-improving system. New data flows in, the pipeline validates it, retrains if beneficial, and deploys a better model, all while maintaining a complete audit trail. Your scripts are now components. Your manual steps are now automated gates. Your model is no longer a fragile artifact, but the continuously refined output of a reliable factory. Stop babysitting, and start scaling.