Online Clustering Example

The online clustering example demonstrates how to set up a real-time clustering pipeline that can read text from Pub/Sub, convert the text into an embedding using a language model, and cluster the text using BIRCH.

Dataset for Clustering

This example uses a dataset called emotion that contains 20,000 English Twitter messages labeled with six basic emotions: anger, fear, joy, love, sadness, and surprise. The dataset has three splits: train, validation, and test. Because it contains both the text and its category (class), it is a supervised dataset. You can access this dataset on the Hugging Face datasets page.
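
For example, a minimal sketch of loading the dataset with the Hugging Face datasets library (the dataset identifier used here is an assumption based on the Hugging Face Hub naming):

    from datasets import load_dataset

    # Load the emotion dataset; it comes with train, validation, and test splits.
    dataset = load_dataset("emotion")
    print(dataset["train"][0])  # e.g. {'text': '...', 'label': 0}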

The following text shows examples from the train split of the dataset:

    Text | Type of emotion
    im grabbing a minute to post i feel greedy wrong | Anger
    i am ever feeling nostalgic about the fireplace i will know that it is still on the property | Love
    ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny | Fear
    on a boat trip to denmark | Joy
    i feel you know basically like a fake in the realm of science fiction | Sadness
    i began having them several times a week feeling tortured by the hallucinations moving people and figures sounds and vibrations | Fear

Clustering Algorithm

For the clustering of tweets, we use an incremental clustering algorithm called BIRCH. It stands for balanced iterative reducing and clustering using hierarchies, and it is an unsupervised data mining algorithm used to perform hierarchical clustering over particularly large datasets. An advantage of BIRCH is its ability to incrementally and dynamically cluster incoming, multi-dimensional metric data points in an attempt to produce the best quality clustering for a given set of resources (memory and time constraints).
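
As a rough sketch of the incremental usage pattern, here is how scikit-learn's Birch can be updated batch by batch (the threshold value and random data are illustrative only):

    import numpy as np
    from sklearn.cluster import Birch

    # BIRCH can be updated incrementally instead of seeing all data at once.
    birch = Birch(n_clusters=None, threshold=0.7)
    for batch in (np.random.rand(16, 768) for _ in range(3)):  # stand-in for embedding batches
        birch.partial_fit(batch)

    # Assign cluster labels to new points with the current model.
    labels = birch.predict(np.random.rand(4, 768))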

Ingestion to Pub/Sub

The example starts by ingesting the data into Pub/Sub so that we can read the tweets from Pub/Sub while clustering. Pub/Sub is a messaging service for exchanging event data among applications and services. Streaming analytics and data integration pipelines use Pub/Sub to ingest and distribute data.

You can find the full example code for ingesting data into Pub/Sub in GitHub.

The file structure for the ingestion pipeline is shown in the following diagram:

write_data_to_pubsub_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/utils.py contains the code for loading the emotion dataset and two beam.DoFn classes that are used for data transformation.

pipeline/options.py contains the pipeline options to configure the Dataflow pipeline.

config.py defines some variables that are used multiple times, like the Google Cloud PROJECT_ID and NUM_WORKERS.

setup.py defines the packages and requirements for the pipeline to run.

main.py contains the pipeline code and some additional functions used for running the pipeline.
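
As an illustration, config.py might hold nothing more than a handful of constants; apart from PROJECT_ID and NUM_WORKERS, the names and values below are assumptions, not the example's actual contents:

    # Hypothetical sketch of config.py; adjust the values for your own project.
    PROJECT_ID = "my-gcp-project"
    NUM_WORKERS = 4
    TOPIC_ID = f"projects/{PROJECT_ID}/topics/emotion-tweets"             # assumed topic
    SUBSCRIPTION_ID = f"projects/{PROJECT_ID}/subscriptions/emotion-sub"  # assumed subscription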

Run the Pipeline

First, install the required packages. Then run the pipeline in one of two ways (see the sketch after this list):

  1. Locally on your machine: python main.py
  2. On GCP for Dataflow: python main.py --mode cloud
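
The exact flag handling lives in main.py and pipeline/options.py; the following is a hedged sketch of how a --mode flag might switch between the DirectRunner and the DataflowRunner, not the example's exact code:

    # Hypothetical sketch: pick a runner based on a --mode flag.
    import argparse

    from apache_beam.options.pipeline_options import PipelineOptions

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="local", choices=["local", "cloud"])
    args, beam_args = parser.parse_known_args()

    runner = "DataflowRunner" if args.mode == "cloud" else "DirectRunner"
    options = PipelineOptions(beam_args, runner=runner, save_main_session=True)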

The write_data_to_pubsub_pipeline contains four different transforms (see the sketch after this list):

  1. Load the emotion dataset using Hugging Face datasets (for simplicity, we take samples from three classes instead of six).
  2. Associate each piece of text with a unique identifier (UID).
  3. Convert the text into the format that Pub/Sub expects.
  4. Write the formatted message to Pub/Sub.
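
A minimal sketch of how these four transforms might be wired together follows; the wiring, UID handling, and topic path are illustrative assumptions, not the example's exact code:

    # Hypothetical sketch of write_data_to_pubsub_pipeline's four transforms.
    import uuid

    import apache_beam as beam
    from apache_beam.io.gcp.pubsub import PubsubMessage, WriteToPubSub
    from datasets import load_dataset

    dataset = load_dataset("emotion", split="train")

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "Load emotion texts" >> beam.Create(dataset["text"])             # 1. load the dataset
            | "Assign UID" >> beam.Map(lambda text: (str(uuid.uuid4()), text))  # 2. attach a UID
            | "To PubsubMessage" >> beam.Map(                                   # 3. Pub/Sub format
                lambda kv: PubsubMessage(data=kv[1].encode("utf-8"),
                                         attributes={"id": kv[0]}))
            | "Write to PubSub" >> WriteToPubSub(                               # 4. publish
                topic="projects/<PROJECT_ID>/topics/<TOPIC_ID>",
                with_attributes=True)
        )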

Clustering on Streaming Data

After ingesting the data into Pub/Sub, examine the second pipeline, which reads the streaming messages from Pub/Sub, converts the text to an embedding using a language model, and clusters the embeddings using BIRCH.

You can find the full example code for all the steps mentioned previously in GitHub.

The file structure for clustering_pipeline is:

clustering_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/transformations.py contains the code for the different beam.DoFn classes that are used in the pipeline.

pipeline/options.py contains the pipeline options to configure the Dataflow pipeline.

config.py defines variables that are used multiple times, like Google Cloud PROJECT_ID and NUM_WORKERS.

setup.py defines the packages and requirements for the pipeline to run.

main.py contains the pipeline code and some additional functions used for running the pipeline.

Run the Pipeline

Install the required packages and push data to Pub/Sub. Then run the pipeline in one of two ways:

  1. Locally on your machine: python main.py
  2. On GCP for Dataflow: python main.py --mode cloud

The pipeline can be broken down into the following steps:

  1. Read the message from Pub/Sub.
  2. Convert the Pub/Sub message into a PCollection of dictionaries where the key is the UID and the value is the Twitter text.
  3. Encode the text into transformer-readable token ID integers using a tokenizer.
  4. Use RunInference to get the vector embedding from a transformer-based language model.
  5. Normalize the embedding for clustering.
  6. Perform BIRCH clustering using stateful processing.
  7. Print the texts assigned to clusters.

The following code shows the first two steps of the pipeline, where a message from Pub/Sub is read and converted into a dictionary.

    docs = (
        pipeline
        | "Read from PubSub"
        >> ReadFromPubSub(subscription=cfg.SUBSCRIPTION_ID, with_attributes=True)
        | "Decode PubSubMessage" >> beam.ParDo(Decode())
    )
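
Decode is defined in pipeline/transformations.py and is not shown above; the following is a minimal sketch of one way to implement it, assuming the UID travels in the message attributes:

    # Hypothetical sketch of the Decode DoFn: turn a PubsubMessage into a dictionary.
    import apache_beam as beam

    class Decode(beam.DoFn):
        def process(self, element):
            # `element` is a PubsubMessage because the read uses with_attributes=True.
            yield {
                "id": element.attributes.get("id"),
                "text": element.data.decode("utf-8"),
            }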

The next sections examine three important pipeline steps:

  1. Tokenize the text.
  2. Feed the tokenized text to a transformer-based language model to get the embedding.
  3. Perform clustering using stateful processing.

Get Embedding from a Language Model

In order to cluster text data, you need to map the text into vectors of numerical values suitable for statistical analysis. This example uses a transformer-based language model called sentence-transformers/stsb-distilbert-base. It maps sentences and paragraphs to a 768-dimensional dense vector space, and you can use it for tasks like clustering or semantic search.

Because the language model is expecting a tokenized input instead of raw text, start by tokenizing the text. Tokenization is a preprocessing task that transforms text so that it can be fed into the model for getting predictions.

    normalized_embedding = (
        docs
        | "Tokenize Text" >> beam.Map(tokenize_sentence)

Here, tokenize_sentence is a function that takes a dictionary with a text and an ID, tokenizes the text, and returns a tuple (text, id) and the tokenized output.
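
A hedged sketch of what tokenize_sentence might look like, assuming a Hugging Face tokenizer for the same model (the exact arguments in the example may differ):

    # Hypothetical sketch of tokenize_sentence using a Hugging Face tokenizer.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/stsb-distilbert-base")

    def tokenize_sentence(input_dict):
        text, uid = input_dict["text"], input_dict["id"]
        tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        # Key the tensors by (text, id) so the prediction can be matched back to the document.
        return (text, uid), {key: value.squeeze(0) for key, value in tokens.items()}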

The tokenized output is then passed to the language model to get the embeddings. To do that, we use RunInference() from Apache Beam.

    | "Get Embedding" >> RunInference(KeyedModelHandler(model_handler))

After getting the embedding for each piece of Twitter text, the embeddings are normalized, which produces better clusters.

    | "Normalize Embedding" >> beam.ParDo(NormalizeEmbedding())

StatefulOnlineClustering

Because the data is streaming, you need to use an iterative clustering algorithm, like BIRCH. And because the algorithm is iterative, you need a mechanism to store the previous state so that it can be updated as new Twitter text arrives. Stateful processing enables a DoFn to have persistent state, which can be read and written during the processing of each element. For more information about stateful processing, see Stateful processing with Apache Beam.

In this example, every time a new message is read from Pub/Sub, you retrieve the existing state of the clustering model, update it, and write it back to the state.

    clustering = (
        normalized_embedding
        | "Map doc to key" >> beam.Map(lambda x: (1, x))
        | "StatefulClustering using Birch" >> beam.ParDo(StatefulOnlineClustering())
    )

Because BIRCH doesn’t support parallelization, you need to make sure that only one worker is doing all of the stateful processing. To do that, use beam.Map to associate each text with the same key, 1.

StatefulOnlineClustering is a DoFn that takes an embedding of a text and updates the clustering model. To store the state, it uses the ReadModifyWriteStateSpec state object, which acts as a container for storage.

class StatefulOnlineClustering(beam.DoFn):

    BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
    DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
    EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
    UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())

This example declares four different ReadModifyWriteStateSpec objects: BIRCH_MODEL_SPEC holds the clustering model, DATA_ITEMS_SPEC holds the documents processed so far, EMBEDDINGS_SPEC holds their embeddings, and UPDATE_COUNTER_SPEC holds the number of documents processed.

These ReadModifyWriteStateSpec objects are passed as additional arguments to the process function. When a new message comes in, we retrieve the existing state of the different objects, update them, and then write them back as persistent shared state.

def process(
    self,
    element,
    model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
    collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
    collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
    update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
    *args,
    **kwargs,
):
  """
      Takes the embedding of a document and updates the clustering model

      Args:
        element: The input element to be processed.
        model_state: This is the state of the clustering model. It is a stateful parameter,
        which means that it will be updated after each call to the process function.
        collected_docs_state: This is a stateful dictionary that stores the documents that
        have been processed so far.
        collected_embeddings_state: This is a dictionary of document IDs and their embeddings.
        update_counter_state: This is a counter that keeps track of how many documents have been
      processed.
      """
  # 1. Initialise or load states
  clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
  collected_documents = collected_docs_state.read() or {}
  collected_embeddings = collected_embeddings_state.read() or {}
  update_counter = update_counter_state.read() or Counter()

  # 2. Extract document, add to state, and add to clustering model
  _, doc = element
  doc_id = doc["id"]
  embedding_vector = doc["embedding"]
  collected_embeddings[doc_id] = embedding_vector
  collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
  update_counter = len(collected_documents)

  clustering.partial_fit(np.atleast_2d(embedding_vector))

  # 3. Predict cluster labels of collected documents
  cluster_labels = clustering.predict(
      np.array(list(collected_embeddings.values())))

  # 4. Write states
  model_state.write(clustering)
  collected_docs_state.write(collected_documents)
  collected_embeddings_state.write(collected_embeddings)
  update_counter_state.write(update_counter)
  yield {
      "labels": cluster_labels,
      "docs": collected_documents,
      "id": list(collected_embeddings.keys()),
      "counter": update_counter,
  }

GetUpdates is a DoFn that prints the cluster assigned to each Twitter message every time a new message arrives.

updated_clusters = clustering | "Format Update" >> beam.ParDo(GetUpdates())
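
GetUpdates is also defined in pipeline/transformations.py; the following is a minimal sketch, assuming it simply prints each document together with its current cluster label from the dictionary yielded above:

    # Hypothetical sketch of GetUpdates: print the cluster assigned to each text.
    import apache_beam as beam

    class GetUpdates(beam.DoFn):
        def process(self, element):
            for doc_id, label in zip(element["id"], element["labels"]):
                print(f"Cluster {label}: {element['docs'][doc_id]['text']}")
            yield element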