Per Entity Training

The aim of this pipeline example is to demonstrate per entity training in Beam. Per entity training refers to the process of training a machine learning model for each individual entity, rather than training a single model for all entities. In this approach, a separate model is trained for each entity based on the data specific to that entity. Per entity training can be beneficial in the following scenarios:


This example uses the Adult Census Income dataset. The dataset contains information about individuals, including their demographic characteristics, employment status, and income level. It includes both categorical and numerical features, such as age, education, occupation, and hours worked per week, as well as a binary label indicating whether an individual’s income is above or below 50,000 USD. The dataset is primarily intended for classification tasks, in which a model uses these features to predict whether an individual’s income is above or below that threshold. The pipeline expects the CSV file as input. You can download this file from here.
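To make the row format concrete, here is a minimal sketch that turns one comma-separated row into a feature dictionary and a binary label. The column names and order, and the ">50K" / "<=50K" label strings, are assumptions based on the public UCI copy of the dataset. parse_adult_row is a hypothetical helper, not part of the example's code, and the sample row is illustrative only.

```python
# Hypothetical helper; the column order is assumed from the UCI "adult.data" file.
ADULT_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country",
    "income",
]

NUMERIC_COLUMNS = {"age", "fnlwgt", "education-num", "capital-gain",
                   "capital-loss", "hours-per-week"}


def parse_adult_row(line):
    """Split one CSV line into (features, label), where label is 1 if income > 50K."""
    values = [v.strip() for v in line.split(",")]
    if len(values) != len(ADULT_COLUMNS):
        raise ValueError(f"expected {len(ADULT_COLUMNS)} fields, got {len(values)}")
    row = dict(zip(ADULT_COLUMNS, values))
    # Some copies of the data end labels with a period (">50K."), so strip it.
    income = row.pop("income").rstrip(".")
    features = {k: int(v) if k in NUMERIC_COLUMNS else v for k, v in row.items()}
    label = 1 if income == ">50K" else 0
    return features, label


features, label = parse_adult_row(
    "39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, "
    "Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K"
)
print(features["education"], features["hours-per-week"], label)  # Bachelors 40 0
```

Keeping the label separate from the features mirrors the supervised setup described above: the features are model inputs, and the label is the target.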

Run the Pipeline

First, install the required packages: apache-beam==2.44.0, scikit-learn==1.0.2, and pandas==1.3.5. You can view the code on GitHub. Then run the pipeline with python --input path/to/
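The pipeline code writes its models to known_args.output. A common Beam pattern for producing such a known_args object is to parse the script's own flags with argparse's parse_known_args and pass every remaining flag to Beam's PipelineOptions. The sketch below shows only the argument-parsing half. The flag names --input and --output and the function name parse_args are assumptions; the example script's actual flags may differ.

```python
import argparse


def parse_args(argv=None):
    """Split command-line flags into this script's own flags and Beam pipeline flags."""
    parser = argparse.ArgumentParser(description="Per entity training (sketch)")
    # Assumed flag names; the real example script may define them differently.
    parser.add_argument("--input", required=True,
                        help="Path to the Adult Census Income CSV file.")
    parser.add_argument("--output", required=True,
                        help="Location where the trained models are written.")
    # Unrecognized flags (for example --runner) are returned untouched so that
    # they can be forwarded to apache_beam's PipelineOptions.
    known_args, pipeline_args = parser.parse_known_args(argv)
    return known_args, pipeline_args


known_args, pipeline_args = parse_args(
    ["--input", "adult.data", "--output", "models/", "--runner", "DirectRunner"])
print(known_args.input, known_args.output, pipeline_args)
# adult.data models/ ['--runner', 'DirectRunner']
```

In a full script, pipeline_args would typically be passed as PipelineOptions(pipeline_args) to build the pipeline_options object that beam.Pipeline receives.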

Train the models

The pipeline can be broken down into the following main steps:

  1. Read the data from the provided input path.
  2. Filter the data based on some criteria.
  3. Create a key based on education level.
  4. Group the dataset by the generated key.
  5. Preprocess the dataset.
  6. Train one model per education level.
  7. Save the trained models.
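Steps 2 to 4 can be illustrated without Beam. The plain-Python sketch below mimics what the "Filter rows", "Create Key", and "Group by education" stages do to a few already-split rows: drop incomplete rows, pair each row with its education level, and collect the rows for each key, which is the shape of output that beam.GroupByKey produces. The filter criterion (dropping rows with "?" placeholders or the wrong number of fields), the education column index, and the helper names are all assumptions; the example's own custom_filter and CreateKey may behave differently. The rows are illustrative only.

```python
from collections import defaultdict

EDUCATION_INDEX = 3   # assumed position of the "education" column
NUM_FIELDS = 15       # assumed number of columns per row


def keep_row(fields):
    """Hypothetical filter: keep complete rows without '?' placeholders."""
    return len(fields) == NUM_FIELDS and all(f.strip() != "?" for f in fields)


def create_key(fields):
    """Pair each row with its education level, like a (key, value) element in Beam."""
    return fields[EDUCATION_INDEX].strip(), fields


def group_by_key(pairs):
    """Local stand-in for beam.GroupByKey: maps each key to the list of its values."""
    groups = defaultdict(list)
    for key, value in pairs:
        groups[key].append(value)
    return dict(groups)


lines = [
    "39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K",
    "50, Self-emp-not-inc, 83311, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 13, United-States, <=50K",
    "38, Private, 215646, HS-grad, 9, Divorced, Handlers-cleaners, Not-in-family, White, Male, 0, 0, 40, United-States, <=50K",
    "54, ?, 180211, Some-college, 10, Married-civ-spouse, ?, Husband, Asian-Pac-Islander, Male, 0, 0, 60, South, >50K",
]
rows = [line.split(",") for line in lines]              # "Split data to make List"
filtered = [r for r in rows if keep_row(r)]             # "Filter rows"
grouped = group_by_key(create_key(r) for r in filtered)  # "Create Key" + "Group by education"
print({k: len(v) for k, v in grouped.items()})  # {'Bachelors': 2, 'HS-grad': 1}
```

The key choice determines what counts as an "entity": keying on a different column would produce a different set of per-entity groups, and therefore a different set of models.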

The following code snippet contains the detailed steps:

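    import apache_beam as beam
    from apache_beam.io import fileio

    # custom_filter, CreateKey, PrepareDataforTraining, TrainModel, ModelSink,
    # known_args, and pipeline_options are defined elsewhere in the example.
    with beam.Pipeline(options=pipeline_options) as pipeline:
        _ = (
            # Read the CSV lines; known_args.input is assumed to hold the --input path.
            pipeline | "Read Data" >> beam.io.ReadFromText(known_args.input)
            | "Split data to make List" >> beam.Map(lambda x: x.split(','))
            | "Filter rows" >> beam.Filter(custom_filter)
            # Emit (education, row) pairs so rows can be grouped per entity.
            | "Create Key" >> beam.ParDo(CreateKey())
            | "Group by education" >> beam.GroupByKey()
            | "Prepare Data" >> beam.ParDo(PrepareDataforTraining())
            # Train one model per education level.
            | "Train Model" >> beam.ParDo(TrainModel())
            # Write each trained model under the output path.
            | "Save" >> fileio.WriteToFiles(path=known_args.output, sink=ModelSink()))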
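After grouping, each key gets its own independent training run. In Beam, each grouped (key, rows) element is processed separately by the ParDo stages, so the per-key training can run in parallel. The sketch below approximates the "Prepare Data", "Train Model", and "Save" stages in plain Python: it converts each group into numeric features and labels, fits one small model per education level, and pickles each model to its own file. To stay dependency-free, it uses a hand-written nearest-centroid classifier in place of the scikit-learn model that the example trains. The chosen features (age and hours per week), the helper names, the model representation (a dict of per-label centroids), and the "<key>.pkl" file naming are all hypothetical. The grouped data is illustrative only.

```python
import os
import pickle
import tempfile


def fit_centroids(X, y):
    """Hypothetical stand-in model: the mean feature vector for each label."""
    model = {}
    for label in set(y):
        points = [x for x, t in zip(X, y) if t == label]
        model[label] = [sum(col) / len(points) for col in zip(*points)]
    return model


def predict(model, x):
    """Return the label whose centroid is closest to x (squared Euclidean distance)."""
    return min(model, key=lambda label: sum((a - b) ** 2 for a, b in zip(x, model[label])))


def prepare(rows):
    """Turn (age, hours_per_week, label) rows of one group into X and y (assumed features)."""
    X = [[age, hours] for age, hours, _ in rows]
    y = [label for _, _, label in rows]
    return X, y


def train_per_entity(grouped):
    """Fit one independent model per key; no training data is shared between keys."""
    return {key: fit_centroids(*prepare(rows)) for key, rows in grouped.items()}


def save_models(models, out_dir):
    """Write each model to its own <key>.pkl file, one output file per entity."""
    paths = []
    for key, model in models.items():
        path = os.path.join(out_dir, f"{key}.pkl")
        with open(path, "wb") as fh:
            pickle.dump(model, fh)
        paths.append(path)
    return paths


grouped = {
    "Bachelors": [(39, 40, 0), (52, 60, 1), (45, 55, 1), (23, 20, 0)],
    "HS-grad": [(38, 40, 0), (50, 45, 1), (28, 35, 0)],
}
models = train_per_entity(grouped)
with tempfile.TemporaryDirectory() as out_dir:
    paths = save_models(models, out_dir)
    print(sorted(os.path.basename(p) for p in paths))  # ['Bachelors.pkl', 'HS-grad.pkl']
    with open(os.path.join(out_dir, "Bachelors.pkl"), "rb") as fh:
        restored = pickle.load(fh)
print(predict(restored, [50, 58]))  # 1
```

Because each model only sees the rows for its own key, the "Bachelors" model and the "HS-grad" model can learn different decision boundaries. This is the main point of per entity training.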