apache_beam.ml.inference.tensorrt_inference module

class apache_beam.ml.inference.tensorrt_inference.TensorRTEngine(engine: trt.ICudaEngine)

Bases: object

Implementation of the TensorRTEngine class, which handles the memory allocations associated with a TensorRT engine.

Example Usage:

TensorRTEngine(engine)
Parameters: engine – a trt.ICudaEngine object that contains the TensorRT engine.
cpu_allocations = None

Set up I/O bindings.
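Setting up I/O bindings involves sizing one buffer per engine input and output from its shape and element type. The sketch below shows only that size arithmetic, assuming a binding's byte size is its element count times its element size. `Binding` and `allocation_size` are hypothetical names, not part of the Beam or TensorRT API, and no CUDA memory is allocated.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Binding:
    """Hypothetical stand-in for one engine input or output binding."""
    name: str
    shape: Tuple[int, ...]
    itemsize: int  # bytes per element, e.g. 4 for float32
    is_input: bool


def allocation_size(binding: Binding) -> int:
    """Bytes needed to hold one full tensor for this binding."""
    # Total bytes = number of elements * bytes per element.
    return math.prod(binding.shape) * binding.itemsize


bindings = [
    Binding("input", (1, 3, 224, 224), 4, True),
    Binding("logits", (1, 1000), 4, False),
]

for b in bindings:
    kind = "input" if b.is_input else "output"
    print(f"{kind} {b.name}: {allocation_size(b)} bytes")
```

In the sketch, output buffers on the host side would be sized the same way, so the device copy of each result fits exactly.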

get_engine_attrs()

Returns TensorRT engine attributes.

class apache_beam.ml.inference.tensorrt_inference.TensorRTEngineHandlerNumPy(min_batch_size: int, max_batch_size: int, *, inference_fn: Callable[[Sequence[numpy.ndarray], apache_beam.ml.inference.tensorrt_inference.TensorRTEngine, Optional[Dict[str, Any]]], Iterable[apache_beam.ml.inference.base.PredictionResult]] = _default_tensorRT_inference_fn, **kwargs)

Bases: apache_beam.ml.inference.base.ModelHandler

Implementation of the ModelHandler interface for TensorRT.

Example Usage:

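from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy

# pcoll is a PCollection of numpy.ndarray inputs.
predictions = pcoll | RunInference(
    TensorRTEngineHandlerNumPy(
        min_batch_size=1,
        max_batch_size=1,
        engine_path="my_uri"))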

NOTE: This API and its implementation are under development and do not provide backward compatibility guarantees.

Parameters:
  • min_batch_size – minimum accepted batch size.
  • max_batch_size – maximum accepted batch size.
  • inference_fn – the inference function to use during RunInference calls. Defaults to _default_tensorRT_inference_fn.
  • kwargs – additional keyword arguments. Currently, ‘engine_path’ and ‘onnx_path’ are supported.

See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/ for details on the TensorRT Python API.
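The kwargs parameter decides where the model comes from: a serialized engine (engine_path) or an ONNX model that still has to be parsed and built (onnx_path). As a rough illustration, assuming exactly one of the two is expected, a handler could resolve the source as below. `resolve_model_source` is a hypothetical helper, and the "not both" rule is an assumption; the real handler's validation may differ.

```python
from typing import Any, Tuple


def resolve_model_source(**kwargs: Any) -> Tuple[str, str]:
    """Hypothetical helper: pick the model source from handler kwargs.

    Returns ("engine", path) for a serialized TensorRT engine or
    ("onnx", path) for an ONNX model that must be built first.
    """
    engine_path = kwargs.get("engine_path")
    onnx_path = kwargs.get("onnx_path")
    if engine_path and onnx_path:
        # Assumption for this sketch: the two sources are mutually exclusive.
        raise ValueError("Pass either engine_path or onnx_path, not both.")
    if engine_path:
        return "engine", engine_path
    if onnx_path:
        return "onnx", onnx_path
    raise ValueError("One of engine_path or onnx_path is required.")


print(resolve_model_source(engine_path="my_uri"))  # ('engine', 'my_uri')
```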

batch_elements_kwargs()

Returns the min_batch_size and max_batch_size keyword arguments that RunInference passes to beam.BatchElements when batching inputs for the TensorRT engine.
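The contract is simply "return a dict of BatchElements keyword arguments." The sketch below mimics it with a hypothetical `FakeHandler` class and a naive fixed-size batcher. Note that beam.BatchElements adapts the batch size at runtime between the two bounds, so `naive_batches` is a simplification for illustration, not Beam's batching algorithm.

```python
from typing import Any, Dict, Iterable, Iterator, List


class FakeHandler:
    """Hypothetical stand-in that mirrors batch_elements_kwargs()."""

    def __init__(self, min_batch_size: int, max_batch_size: int):
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size

    def batch_elements_kwargs(self) -> Dict[str, Any]:
        # Keyword arguments in the shape beam.BatchElements accepts.
        return {
            "min_batch_size": self.min_batch_size,
            "max_batch_size": self.max_batch_size,
        }


def naive_batches(elements: Iterable[Any],
                  max_batch_size: int) -> Iterator[List[Any]]:
    """Simplified batcher: fixed-size chunks of at most max_batch_size."""
    batch: List[Any] = []
    for element in elements:
        batch.append(element)
        if len(batch) == max_batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # trailing partial batch


handler = FakeHandler(min_batch_size=1, max_batch_size=4)
kwargs = handler.batch_elements_kwargs()
print(list(naive_batches(range(10), kwargs["max_batch_size"])))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```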

load_model() → apache_beam.ml.inference.tensorrt_inference.TensorRTEngine

Loads and initializes a TensorRT engine for processing.

load_onnx() → Tuple[trt.INetworkDefinition, trt.Builder]

Loads and parses an ONNX model for processing.

build_engine(network: trt.INetworkDefinition, builder: trt.Builder) → apache_beam.ml.inference.tensorrt_inference.TensorRTEngine

Builds a TensorRT engine from the parsed or created network.

run_inference(batch: Sequence[numpy.ndarray], engine: apache_beam.ml.inference.tensorrt_inference.TensorRTEngine, inference_args: Optional[Dict[str, Any]] = None) → Iterable[apache_beam.ml.inference.base.PredictionResult]

Runs inference on a batch of NumPy arrays and returns an Iterable of TensorRT predictions.

Parameters:
  • batch – a sequence of np.ndarrays, or an np.ndarray that represents multiple arrays concatenated into a batch.
  • engine – a TensorRT engine.
  • inference_args – any additional arguments for inference. These are not applicable to TensorRT.
Returns:

An Iterable of type PredictionResult.
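After the engine runs, each output binding holds results for the whole batch, so the results have to be split back into one entry per input example before being paired with that example. The sketch assumes every output's first dimension is the batch dimension. It uses a local `PredictionResult` NamedTuple as a stand-in for apache_beam.ml.inference.base.PredictionResult, and `to_prediction_results` is a hypothetical helper.

```python
from typing import Any, Iterable, List, NamedTuple, Optional, Sequence


class PredictionResult(NamedTuple):
    """Local stand-in for apache_beam.ml.inference.base.PredictionResult."""
    example: Any
    inference: Any
    model_id: Optional[str] = None


def to_prediction_results(
        batch: Sequence[Any],
        outputs: Sequence[Sequence[Any]]) -> Iterable[PredictionResult]:
    """Pair each example with its slice of every engine output.

    `outputs` holds one batch-major buffer per output binding, so
    outputs[k][i] is output k for example i.
    """
    per_example: List[List[Any]] = [
        [output[idx] for output in outputs] for idx in range(len(batch))
    ]
    return [PredictionResult(x, y) for x, y in zip(batch, per_example)]


batch = ["img0", "img1"]
outputs = [[0.9, 0.2], [[1, 2], [3, 4]]]  # two output bindings
for result in to_prediction_results(batch, outputs):
    print(result)
```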

get_num_bytes(batch: Sequence[numpy.ndarray]) → int

Returns: The number of bytes of data for a batch of Tensors.
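One way to measure a batch's size is to sum the buffer sizes of its arrays. The sketch below does this through the buffer protocol with memoryview(...).nbytes, which works for NumPy arrays as well as for the standard-library array.array used here to keep the example dependency-free. `batch_num_bytes` is a hypothetical helper and is not necessarily how the handler computes its value.

```python
from array import array
from typing import Any, Sequence


def batch_num_bytes(batch: Sequence[Any]) -> int:
    """Total bytes held by a batch of buffer-protocol arrays."""
    # memoryview(...).nbytes is element count * element size for each array.
    return sum(memoryview(a).nbytes for a in batch)


batch = [array("d", [0.0] * 3), array("d", [0.0] * 2)]  # 'd' is an 8-byte double
print(batch_num_bytes(batch))  # 40
```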
get_metrics_namespace() → str

Returns a namespace for metrics collected by the RunInference transform.

apache_beam.ml.inference.tensorrt_inference.experimental(*, label='experimental', since=None, current=None, extra_message=None, custom_message=None)

Decorates an API with a deprecated or experimental annotation.

Parameters:
  • label – the kind of annotation (‘deprecated’ or ‘experimental’).
  • since – the version since which the annotation applies.
  • current – the suggested replacement function.
  • extra_message – an optional additional message.
  • custom_message – if the default message does not suffice, the message can be changed using this argument. It is a string with replacement tokens that mark where the previous arguments are placed in the custom message. The following replacement tokens can be used:
      %name% -> API.__name__
      %since% -> since (mandatory for the deprecated annotation)
      %current% -> current
      %extra% -> extra_message
Returns:

The decorator for the API.
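The replacement tokens in custom_message behave like plain text substitutions. A minimal sketch, assuming simple str.replace semantics: `render_custom_message` and the `experimental_sketch` decorator are hypothetical and are not Beam's implementation, and the version string "2.40.0" is a made-up value for illustration.

```python
import functools
import warnings
from typing import Callable, Optional


def render_custom_message(template: str,
                          name: str,
                          since: Optional[str] = None,
                          current: Optional[str] = None,
                          extra: Optional[str] = None) -> str:
    """Substitute the %name%, %since%, %current% and %extra% tokens."""
    message = template.replace("%name%", name)
    if since is not None:
        message = message.replace("%since%", since)
    if current is not None:
        message = message.replace("%current%", current)
    if extra is not None:
        message = message.replace("%extra%", extra)
    return message


def experimental_sketch(custom_message: str,
                        **tokens: Optional[str]) -> Callable:
    """Hypothetical decorator that warns with the rendered message per call."""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            text = render_custom_message(custom_message, fn.__name__, **tokens)
            warnings.warn(text, FutureWarning, stacklevel=2)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@experimental_sketch("%name% is experimental since %since%.", since="2.40.0")
def load_engine():
    return "engine"
```

Calling load_engine() in this sketch emits "load_engine is experimental since 2.40.0." as a FutureWarning and then runs the wrapped function.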