#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["SentenceTransformerEmbeddings"]

from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from sentence_transformers import SentenceTransformer


# TODO: https://github.com/apache/beam/issues/29621
# Use HuggingFaceModelHandlerTensor once the import issue is fixed.
# Right now, the Hugging Face model handler imports torch and tensorflow
# at the same time, which adds too much weight to the container
# unnecessarily.
class _SentenceTransformerModelHandler(ModelHandler):
  """
  Note: Intended for internal use and guarantees no backwards compatibility.
  """
  def __init__(
      self,
      model_name: str,
      model_class: Callable,
      load_model_args: Optional[dict] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_seq_length: Optional[int] = None,
      large_model: bool = False,
      **kwargs):
    self._max_seq_length = max_seq_length
    self.model_name = model_name
    self._model_class = model_class
    # Default to an empty dict so load_model can always unpack these args.
    self._load_model_args = load_model_args or {}
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._large_model = large_model
    self._kwargs = kwargs

  def run_inference(
      self,
      batch: Sequence[str],
      model: SentenceTransformer,
      inference_args: Optional[Dict[str, Any]] = None,
  ):
    """Encodes a batch of text strings into embeddings with the model."""
    inference_args = inference_args or {}
    return model.encode(batch, **inference_args)

  def load_model(self):
    """Loads the sentence-transformers model and applies the optional
    max_seq_length override."""
    model = self._model_class(self.model_name, **self._load_model_args)
    if self._max_seq_length:
      model.max_seq_length = self._max_seq_length
    return model

  def share_model_across_processes(self) -> bool:
    return self._large_model

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """Returns min/max batch size overrides for RunInference batching."""
    batch_sizes = {}
    if self._min_batch_size:
      batch_sizes["min_batch_size"] = self._min_batch_size
    if self._max_batch_size:
      batch_sizes["max_batch_size"] = self._max_batch_size
    return batch_sizes
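

# A minimal usage sketch of this handler with RunInference, kept as a
# comment since the handler is internal. "all-MiniLM-L6-v2" is an arbitrary
# sentence-transformers checkpoint chosen for illustration, not a default
# defined by this module:
#
#   handler = _SentenceTransformerModelHandler(
#       model_name="all-MiniLM-L6-v2",
#       model_class=SentenceTransformer,
#       max_batch_size=64)
#   with beam.Pipeline() as pipeline:
#     _ = (
#         pipeline
#         | beam.Create(["hello world", "apache beam"])
#         | RunInference(model_handler=handler))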