#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import asyncio
import logging
import os
import subprocess
import sys
import threading
import time
import uuid
from collections.abc import Iterable
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from typing import Optional
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.utils import subprocess_server
from openai import AsyncOpenAI
from openai import OpenAI
try:
  # VLLM logging config breaks beam logging.
  os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
  import vllm  # pylint: disable=unused-import
  logging.info('vllm module successfully imported.')
  os.environ["VLLM_CONFIGURE_LOGGING"] = "1"
except ModuleNotFoundError:
  msg = 'vllm module was not found. This is ok as long as the specified ' \
    'runner has vllm dependencies installed.'
  logging.warning(msg)
__all__ = [
    'OpenAIChatMessage',
    'VLLMCompletionsModelHandler',
    'VLLMChatModelHandler',
]
@dataclass(frozen=True)
class OpenAIChatMessage():
  """"
  Dataclass containing previous chat messages in conversation.
  Role is the entity that sent the message (either 'user' or 'system').
  Content is the contents of the message.
  """
  role: str
  content: str 
def start_process(cmd) -> tuple[subprocess.Popen, int]:
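  """Launches cmd as a subprocess listening on a free local port.

  Picks an open port, substitutes it for the '{{PORT}}' placeholder in cmd,
  starts the process, and forwards its combined stdout/stderr to this
  process's logs on a daemon thread. Returns (process, port).
  """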
  port, = subprocess_server.pick_port(None)
  cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd]  # pylint: disable=not-an-iterable
  logging.info("Starting service with %s", str(cmd).replace("',", "'"))
  process = subprocess.Popen(
      cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  # Emit the output of this command as info level logging.
  def log_stdout():
    line = process.stdout.readline()
    while line:
      # The log obtained from stdout is bytes, decode it into string.
      # Remove newline via rstrip() to not print an empty line.
      logging.info(line.decode(errors='backslashreplace').rstrip())
      line = process.stdout.readline()
  t = threading.Thread(target=log_stdout)
  t.daemon = True
  t.start()
  return process, port
def getVLLMClient(port) -> OpenAI:
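  """Returns a synchronous OpenAI client pointed at the local vLLM server."""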
  openai_api_key = "EMPTY"
  openai_api_base = f"http://localhost:{port}/v1"
  return OpenAI(
      api_key=openai_api_key,
      base_url=openai_api_base,
  )
def getAsyncVLLMClient(port) -> AsyncOpenAI:
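  """Returns an asynchronous OpenAI client pointed at the local vLLM server."""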
  openai_api_key = "EMPTY"
  openai_api_base = f"http://localhost:{port}/v1"
  return AsyncOpenAI(
      api_key=openai_api_key,
      base_url=openai_api_base,
  )
class _VLLMModelServer():
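  """Manages a vLLM OpenAI-compatible server running in a local subprocess.

  The server is launched on construction. get_server_port() returns the port
  it serves on, and check_connectivity() blocks until the server responds,
  restarting it (up to the given number of retries) if the process exits
  before becoming reachable.
  """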
  def __init__(self, model_name: str, vllm_server_kwargs: dict[str, str]):
    self._model_name = model_name
    self._vllm_server_kwargs = vllm_server_kwargs
    self._server_started = False
    self._server_process = None
    self._server_port: int = -1
    self._server_process_lock = threading.RLock()
    self.start_server()
  def start_server(self, retries=3):
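    """Starts the vLLM server for the configured model if it is not already
    running, then waits for it to become reachable."""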
    with self._server_process_lock:
      if not self._server_started:
        server_cmd = [
            sys.executable,
            '-m',
            'vllm.entrypoints.openai.api_server',
            '--model',
            self._model_name,
            '--port',
            '{{PORT}}',
        ]
        for k, v in self._vllm_server_kwargs.items():
          server_cmd.append(f'--{k}')
          # Only add values for commands with value part.
          if v is not None:
            server_cmd.append(v)
        self._server_process, self._server_port = start_process(server_cmd)
      self.check_connectivity(retries)
  def get_server_port(self) -> int:
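    """Returns the port the vLLM server is listening on, starting the server
    first if needed."""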
    if not self._server_started:
      self.start_server()
    return self._server_port
  def check_connectivity(self, retries=3):
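    """Polls the running server until it reports at least one served model.

    If the server process exits before responding, restarts it while retries
    remain; otherwise raises an exception so the next request can retry.
    """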
    with getVLLMClient(self._server_port) as client:
      while self._server_process.poll() is None:
        try:
          models = client.models.list().data
          logging.info('models: %s' % models)
          if len(models) > 0:
            self._server_started = True
            return
        except:  # pylint: disable=bare-except
          pass
        # Sleep while bringing up the process
        time.sleep(5)
      if retries == 0:
        self._server_started = False
        raise Exception(
            "Failed to start vLLM server, polling process exited with code " +
            "%s.  Next time a request is tried, the server will be restarted" %
            self._server_process.poll())
      else:
        self.start_server(retries - 1)
class VLLMCompletionsModelHandler(ModelHandler[str,
                                               PredictionResult,
                                               _VLLMModelServer]):
  def __init__(
      self,
      model_name: str,
      vllm_server_kwargs: Optional[dict[str, str]] = None):
    """Implementation of the ModelHandler interface for vLLM using text as
    input.
    Example Usage::
      pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m'))
    Args:
      model_name: The vLLM model. See
        https://docs.vllm.ai/en/latest/models/supported_models.html for
        supported models.
      vllm_server_kwargs: Any additional kwargs to be passed into your vllm
        server when it is being created. Will be invoked using
        `python -m vllm.entrypoints.openai.api_serverv <beam provided args>
        <vllm_server_kwargs>`. For example, you could pass
        `{'echo': 'true'}` to prepend new messages with the previous message.
        For a list of possible kwargs, see
        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api
    """
    self._model_name = model_name
    self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
    self._env_vars = {}
  def load_model(self) -> _VLLMModelServer:
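    """Starts a local vLLM server for the configured model and returns a
    handle for routing requests to it."""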
    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) 
  async def _async_run_inference(
      self,
      batch: Sequence[str],
      model: _VLLMModelServer,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
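    # Issue one completion request per prompt and await them concurrently.
    # On failure, re-check server connectivity (which may restart the server)
    # before re-raising the error.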
    inference_args = inference_args or {}
    async with getAsyncVLLMClient(model.get_server_port()) as client:
      try:
        async_predictions = [
            client.completions.create(
                model=self._model_name, prompt=prompt, **inference_args)
            for prompt in batch
        ]
        responses = await asyncio.gather(*async_predictions)
      except Exception as e:
        model.check_connectivity()
        raise e
    return [PredictionResult(x, y) for x, y in zip(batch, responses)]
  def run_inference(
      self,
      batch: Sequence[str],
      model: _VLLMModelServer,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of text strings.
    Args:
      batch: A sequence of examples as text strings.
      model: A _VLLMModelServer containing info for connecting to the server.
      inference_args: Any additional arguments for an inference.
    Returns:
      An Iterable of type PredictionResult.
    """
    return asyncio.run(self._async_run_inference(batch, model, inference_args)) 
  def share_model_across_processes(self) -> bool:
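    """Returns True so that a single vLLM server is shared across worker
    processes on this machine rather than starting one per process."""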
    return True 
 
class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage],
                                        PredictionResult,
                                        _VLLMModelServer]):
  def __init__(
      self,
      model_name: str,
      chat_template_path: Optional[str] = None,
      vllm_server_kwargs: Optional[dict[str, str]] = None):
    """ Implementation of the ModelHandler interface for vLLM using previous
    messages as input.
    Example Usage::
      pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m'))
    Args:
      model_name: The vLLM model. See
        https://docs.vllm.ai/en/latest/models/supported_models.html for
        supported models.
      chat_template_path: Path to a chat template. This file must be accessible
        from your runner's execution environment, so it is recommended to use
        a cloud based file storage system (e.g. Google Cloud Storage).
        For info on chat templates, see:
        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template
      vllm_server_kwargs: Any additional kwargs to be passed into your vllm
        server when it is being created. Will be invoked using
        `python -m vllm.entrypoints.openai.api_serverv <beam provided args>
        <vllm_server_kwargs>`. For example, you could pass
        `{'echo': 'true'}` to prepend new messages with the previous message.
        For a list of possible kwargs, see
        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api
    """
    self._model_name = model_name
    self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
    self._env_vars = {}
    self._chat_template_path = chat_template_path
    self._chat_file = f'template-{uuid.uuid4().hex}.jinja'
  def load_model(self) -> _VLLMModelServer:
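    # If a chat template was provided, stage a local copy of it (the template
    # may live in remote storage such as GCS) and point the server at that
    # copy before startup.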
    chat_template_contents = ''
    if self._chat_template_path is not None:
      local_chat_template_path = os.path.join(os.getcwd(), self._chat_file)
      if not os.path.exists(local_chat_template_path):
        with FileSystems.open(self._chat_template_path) as fin:
          chat_template_contents = fin.read().decode()
        with open(local_chat_template_path, 'a') as f:
          f.write(chat_template_contents)
      self._vllm_server_kwargs['chat_template'] = local_chat_template_path
    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) 
  async def _async_run_inference(
      self,
      batch: Sequence[Sequence[OpenAIChatMessage]],
      model: _VLLMModelServer,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    inference_args = inference_args or {}
    async with getAsyncVLLMClient(model.get_server_port()) as client:
      try:
        async_predictions = [
            client.chat.completions.create(
                model=self._model_name,
                messages=[{
                    "role": message.role, "content": message.content
                } for message in messages],
                **inference_args) for messages in batch
        ]
        predictions = await asyncio.gather(*async_predictions)
      except Exception as e:
        model.check_connectivity()
        raise e
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
  def run_inference(
      self,
      batch: Sequence[Sequence[OpenAIChatMessage]],
      model: _VLLMModelServer,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of text strings.
    Args:
      batch: A sequence of examples as OpenAI messages.
      model: A _VLLMModelServer for connecting to the spun up server.
      inference_args: Any additional arguments for an inference.
    Returns:
      An Iterable of type PredictionResult.
    """
    return asyncio.run(self._async_run_inference(batch, model, inference_args)) 
  def share_model_across_processes(self) -> bool:
    return True