Online Clustering Example
The online clustering example demonstrates how to set up a real-time clustering pipeline that can read text from Pub/Sub, convert the text into an embedding using a language model, and cluster the text using BIRCH.
Dataset for Clustering
This example uses a dataset called emotion that contains 20,000 English Twitter messages labeled with six basic emotions: anger, fear, joy, love, sadness, and surprise. The dataset has three splits: train, validation, and test. Because each example includes both the text and its category (class), it's a labeled (supervised) dataset. To access this dataset, use the Hugging Face datasets page.
The following text shows examples from the train split of the dataset:
Text | Type of emotion |
---|---|
im grabbing a minute to post i feel greedy wrong | Anger |
i am ever feeling nostalgic about the fireplace i will know that it is still on the property | Love |
ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny | Fear |
on a boat trip to denmark | Joy |
i feel you know basically like a fake in the realm of science fiction | Sadness |
i began having them several times a week feeling tortured by the hallucinations moving people and figures sounds and vibrations | Fear |
Clustering Algorithm
To cluster the tweets, this example uses an incremental clustering algorithm called BIRCH, which stands for balanced iterative reducing and clustering using hierarchies. BIRCH is an unsupervised data mining algorithm that performs hierarchical clustering over particularly large datasets. An advantage of BIRCH is its ability to incrementally and dynamically cluster incoming, multi-dimensional metric data points in an attempt to produce the best quality clustering for a given set of resources (memory and time constraints).
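The following sketch illustrates this incremental behavior with scikit-learn's `Birch` estimator, which the clustering code later in this example also uses. It feeds points to `partial_fit` one at a time, as a streaming pipeline would. The toy 2-D points and the threshold of 0.5 are arbitrary choices for illustration, not the tweet embeddings or the example's settings.

```python
import numpy as np
from sklearn.cluster import Birch

# Two tight, well-separated groups of 2-D points (toy data, not real embeddings).
points = np.array([
    [0.0, 0.0], [0.1, 0.0], [0.0, 0.1],  # group A
    [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],  # group B
])

# n_clusters=None skips the final global clustering step, so every
# CF subcluster found by BIRCH is reported as its own cluster label.
birch = Birch(n_clusters=None, threshold=0.5)

# Update the model one point at a time instead of fitting on the whole dataset.
for point in points:
    birch.partial_fit(np.atleast_2d(point))

labels = birch.predict(points)
print("number of subclusters:", len(birch.subcluster_centers_))
print("labels:", labels)
```

A new point joins the nearest subcluster only if the merged subcluster's radius stays within `threshold`; otherwise BIRCH opens a new subcluster, which is why the two distant groups end up with different labels.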
Ingestion to Pub/Sub
The example starts by ingesting the data into Pub/Sub so that we can read the tweets from Pub/Sub while clustering. Pub/Sub is a messaging service for exchanging event data among applications and services. Streaming analytics and data integration pipelines use Pub/Sub to ingest and distribute data.
You can find the full example code for ingesting data into Pub/Sub in GitHub.
The file structure for the ingestion pipeline is shown in the following diagram:
write_data_to_pubsub_pipeline/
├── pipeline/
│ ├── __init__.py
│ ├── options.py
│ └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py
- pipeline/utils.py: contains the code for loading the emotion dataset and two beam.DoFn classes that are used for data transformation.
- pipeline/options.py: contains the pipeline options to configure the Dataflow pipeline.
- config.py: defines variables that are used multiple times, like the Google Cloud PROJECT_ID and NUM_WORKERS.
- setup.py: defines the packages and requirements for the pipeline to run.
- main.py: contains the pipeline code and some additional functions used for running the pipeline.
Run the Pipeline
First, install the required packages.
- Locally on your machine:
python main.py
- On GCP for Dataflow:
python main.py --mode cloud
The write_data_to_pubsub_pipeline contains four different transforms:
- Load the emotion dataset using Hugging Face datasets (for simplicity, we take samples from three classes instead of six).
- Associate each piece of text with a unique identifier (UID).
- Convert the text into the format that Pub/Sub expects.
- Write the formatted message to Pub/Sub.
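The second and third transforms in this list are simple per-element functions. The sketch below shows one way to write them in plain Python. The function names `add_uid` and `to_pubsub_message`, the use of `uuid4` for the UID, and the JSON message layout with `id` and `text` fields are assumptions for illustration, not necessarily what the example's code does.

```python
import json
import uuid


def add_uid(text):
    """Attach a random unique identifier to a piece of text (hypothetical helper)."""
    return {"id": str(uuid.uuid4()), "text": text}


def to_pubsub_message(record):
    """Serialize a record to the bytes payload that a Pub/Sub message carries."""
    return json.dumps(record).encode("utf-8")


record = add_uid("on a boat trip to denmark")
payload = to_pubsub_message(record)
print(type(payload).__name__, json.loads(payload)["text"])
```

In a Beam pipeline, functions like these would typically be applied with `beam.Map(add_uid)` and `beam.Map(to_pubsub_message)` ahead of `beam.io.WriteToPubSub`, which expects `bytes` elements when `with_attributes` is left at its default of `False`.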
Clustering on Streaming Data
After ingesting the data to Pub/Sub, examine the second pipeline, which reads the streaming messages from Pub/Sub, converts the text to an embedding using a language model, and clusters the embeddings using BIRCH.
You can find the full example code for all the steps mentioned previously in GitHub.
The file structure for clustering_pipeline is:
clustering_pipeline/
├── pipeline/
│ ├── __init__.py
│ ├── options.py
│ └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py
- pipeline/transformations.py: contains the code for the different beam.DoFn classes that are used in the pipeline.
- pipeline/options.py: contains the pipeline options to configure the Dataflow pipeline.
- config.py: defines variables that are used multiple times, like the Google Cloud PROJECT_ID and NUM_WORKERS.
- setup.py: defines the packages and requirements for the pipeline to run.
- main.py: contains the pipeline code and some additional functions used for running the pipeline.
Run the Pipeline
Install the required packages and push data to Pub/Sub.
- Locally on your machine:
python main.py
- On GCP for Dataflow:
python main.py --mode cloud
The pipeline can be broken down into the following steps:
- Read the message from Pub/Sub.
- Convert the Pub/Sub message into a PCollection of dictionaries that hold each message's UID and Twitter text.
- Encode the text into transformer-readable token ID integers using a tokenizer.
- Use RunInference to get the vector embedding from a transformer-based language model.
- Normalize the embedding for clustering.
- Perform BIRCH clustering using stateful processing.
- Print the texts assigned to clusters.
The first two steps of the pipeline read a message from Pub/Sub and convert it into a dictionary.
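A minimal sketch of the conversion step, assuming each Pub/Sub payload is UTF-8 encoded JSON with `id` and `text` fields (an assumed layout). `beam.io.ReadFromPubSub` yields the raw payload as `bytes` by default, so a plain function applied with `beam.Map` can do the conversion. The name `decode_message` is hypothetical.

```python
import json


def decode_message(payload):
    """Convert a raw Pub/Sub payload (bytes) into an {"id": ..., "text": ...} dict."""
    message = json.loads(payload.decode("utf-8"))
    return {"id": message["id"], "text": message["text"]}


doc = decode_message(b'{"id": "42", "text": "on a boat trip to denmark"}')
print(doc)
```

Keeping the UID next to the text lets later steps, such as tokenization and the stateful clustering step, refer to each tweet by its ID.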
The next sections examine three important pipeline steps:
- Tokenize the text.
- Feed the tokenized text to a transformer-based language model to get embeddings.
- Perform clustering using stateful processing.
Get Embedding from a Language Model
To cluster text data, you need to map the text into vectors of numerical values suitable for statistical analysis. This example uses a transformer-based language model called sentence-transformers/stsb-distilbert-base. It maps sentences and paragraphs to a 768-dimensional dense vector space, and you can use it for tasks like clustering or semantic search.
Because the language model expects tokenized input instead of raw text, start by tokenizing the text. Tokenization is a preprocessing task that transforms text so that it can be fed into the model to get predictions.
Here, tokenize_sentence is a function that takes a dictionary with a text and an ID, tokenizes the text, and returns a tuple of (text, id) together with the tokenized output.
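The sketch below shows the shape of such a function. It assumes a Hugging Face tokenizer is passed in (for example, one loaded with `transformers.AutoTokenizer.from_pretrained("sentence-transformers/stsb-distilbert-base")`). The `max_length` value and the padding settings are illustrative assumptions, and the example's own implementation may differ, for instance by returning PyTorch tensors for the model.

```python
def tokenize_sentence(input_dict, tokenizer, max_length=128):
    """Tokenize the text of one record and key the result by (text, id).

    `tokenizer` is any callable with the Hugging Face tokenizer call signature.
    """
    text, uid = input_dict["text"], input_dict["id"]
    tokens = tokenizer(
        text,
        padding="max_length",  # pad every input to the same length
        truncation=True,       # cut off texts longer than max_length
        max_length=max_length,
    )
    # The (text, id) key travels with the tokens so the embedding can be
    # matched back to its tweet after inference.
    return (text, uid), tokens
```

In a pipeline, extra arguments can be bound with `beam.Map(tokenize_sentence, tokenizer=tokenizer)`, so the tokenizer is loaded once rather than per element.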
The tokenized output is then passed to the language model to get the embeddings. To run the model, this example uses the RunInference() transform from Apache Beam.
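A transformer returns one vector per token, so a single sentence embedding needs a pooling step. Models in the sentence-transformers family commonly use mean pooling over the non-padding tokens. This page doesn't say where or how the example pools, so the NumPy sketch below is an assumption. It takes token embeddings of shape (sequence_length, hidden_size) and the tokenizer's attention mask.

```python
import numpy as np


def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors, ignoring padding positions (mask == 0)."""
    mask = np.asarray(attention_mask, dtype=float)[:, None]  # shape (seq_len, 1)
    summed = (np.asarray(token_embeddings, dtype=float) * mask).sum(axis=0)
    count = max(mask.sum(), 1e-9)  # avoid dividing by zero for an all-padding input
    return summed / count


# Three token vectors of size 2; the last position is padding.
tokens = np.array([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]])
print(mean_pool(tokens, [1, 1, 0]))
```

With the real model, each token vector has 768 values, so the pooled result is the 768-dimensional sentence embedding described earlier.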
To produce better clusters, the embedding for each piece of Twitter text is normalized after it is computed.
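A sketch of the normalization step, assuming L2 (unit-length) normalization; the text doesn't state which norm the example uses. For unit vectors a and b, the squared Euclidean distance equals 2 - 2·cos(θ), so the Euclidean distances BIRCH works with order pairs of texts the same way cosine similarity does. The helper name `l2_normalize` is hypothetical.

```python
import numpy as np


def l2_normalize(embedding):
    """Scale an embedding vector to unit Euclidean length."""
    vector = np.asarray(embedding, dtype=float)
    norm = np.linalg.norm(vector)
    # Leave an all-zero vector unchanged instead of dividing by zero.
    return vector / norm if norm > 0 else vector


a = l2_normalize([3.0, 4.0])
b = l2_normalize([4.0, 3.0])
print(a)
# For unit vectors, both quantities below are equal.
print(np.sum((a - b) ** 2), 2 - 2 * np.dot(a, b))
```

For a batch of embeddings, `sklearn.preprocessing.normalize(X, norm="l2")` performs the same row-wise scaling.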
StatefulOnlineClustering
Because the data is streaming, you need an incremental clustering algorithm like BIRCH. And because the algorithm updates its model incrementally, you need a mechanism to store the model between elements so that it can be updated when new Twitter text arrives. Stateful processing enables a DoFn to have persistent state, which can be read and written during the processing of each element. For more information about stateful processing, see Stateful processing with Apache Beam.
In this example, every time a new message is read from Pub/Sub, you retrieve the existing state of the clustering model, update it, and write it back to the state.
Because BIRCH doesn't support parallelization, you need to make sure that only one worker does all of the stateful processing. Beam keeps state separately for each key (and window), so use beam.Map to assign the same key, 1, to every text. All of the texts then share a single state.
StatefulOnlineClustering is a DoFn that takes an embedding of a text and updates the clustering model. To store the state, it uses ReadModifyWriteStateSpec state objects, which act as containers for storage.
```python
import apache_beam as beam
from apache_beam.coders import PickleCoder
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec

class StatefulOnlineClustering(beam.DoFn):
    BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
    DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
    EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
    UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())
```
This example declares four different ReadModifyWriteStateSpec objects:
- BIRCH_MODEL_SPEC holds the state of the clustering model.
- DATA_ITEMS_SPEC holds the Twitter texts seen so far.
- EMBEDDINGS_SPEC holds the normalized embeddings.
- UPDATE_COUNTER_SPEC holds the number of texts processed.
These ReadModifyWriteStateSpec objects are passed to the process function as additional state parameters through beam.DoFn.StateParam. When a new message comes in, we retrieve the existing state of the different objects, update them, and then write them back as persistent shared state.
```python
    # Method of StatefulOnlineClustering. The module also needs:
    #   import numpy as np
    #   from sklearn.cluster import Birch
    def process(
        self,
        element,
        model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
        collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
        collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
        update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
        *args,
        **kwargs,
    ):
        """
        Takes the embedding of a document and updates the clustering model.

        Args:
            element: A (key, document) pair, where the document holds "id",
                "text", and "embedding".
            model_state: State holding the BIRCH clustering model. It is read
                and written back on every call to process.
            collected_docs_state: State holding a dictionary of the documents
                processed so far.
            collected_embeddings_state: State holding a dictionary that maps
                document IDs to their embeddings.
            update_counter_state: State holding the number of documents
                processed so far.
        """
        # 1. Initialize or load states
        clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
        collected_documents = collected_docs_state.read() or {}
        collected_embeddings = collected_embeddings_state.read() or {}

        # 2. Extract document, add to state, and add to clustering model
        _, doc = element
        doc_id = doc["id"]
        embedding_vector = doc["embedding"]
        collected_embeddings[doc_id] = embedding_vector
        collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
        # The counter is derived from the number of documents collected so far.
        update_counter = len(collected_documents)
        clustering.partial_fit(np.atleast_2d(embedding_vector))

        # 3. Predict cluster labels of all collected documents
        cluster_labels = clustering.predict(
            np.array(list(collected_embeddings.values())))

        # 4. Write states
        model_state.write(clustering)
        collected_docs_state.write(collected_documents)
        collected_embeddings_state.write(collected_embeddings)
        update_counter_state.write(update_counter)

        yield {
            "labels": cluster_labels,
            "docs": collected_documents,
            "id": list(collected_embeddings.keys()),
            "counter": update_counter,
        }
```
GetUpdates is a DoFn that prints the cluster assigned to each Twitter message every time a new message arrives.
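A rough sketch of what such a step might compute from each dictionary yielded by StatefulOnlineClustering.process: the collected texts grouped by their predicted cluster label. The helper name `group_by_cluster` is hypothetical and the real GetUpdates may format its output differently. The input keys (`labels`, `docs`, `id`) follow the dictionary yielded by `process`, and the labels in the usage example are toy values.

```python
from collections import defaultdict


def group_by_cluster(update):
    """Map each cluster label to the texts currently assigned to it."""
    clusters = defaultdict(list)
    # update["id"] and update["labels"] are in the same order, because both
    # follow the insertion order of the collected embeddings dictionary.
    for doc_id, label in zip(update["id"], update["labels"]):
        clusters[int(label)].append(update["docs"][doc_id]["text"])
    return dict(clusters)


update = {  # toy input shaped like one output of StatefulOnlineClustering
    "labels": [0, 1, 0],
    "docs": {
        "a": {"id": "a", "text": "on a boat trip to denmark"},
        "b": {"id": "b", "text": "im grabbing a minute to post i feel greedy wrong"},
        "c": {"id": "c", "text": "i feel you know basically like a fake in the realm of science fiction"},
    },
    "id": ["a", "b", "c"],
    "counter": 3,
}
for label, texts in group_by_cluster(update).items():
    print(f"cluster {label}: {texts}")
```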
Last updated on 2024/09/17