Source code for apache_beam.transforms.external_java

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for the Java external transforms."""

import argparse
import logging
import subprocess
import sys

import grpc
from mock import patch

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder

# Protect against environments where apitools library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from apache_beam.runners.dataflow.internal import apiclient as _apiclient
except ImportError:
  apiclient = None
else:
  apiclient = _apiclient
# pylint: enable=wrong-import-order, wrong-import-position


[docs]class JavaExternalTransformTest(object): # This will be overwritten if set via a flag. expansion_service_jar = None # type: str expansion_service_port = None # type: int class _RunWithExpansion(object): def __init__(self): self._server = None def __enter__(self): if not (JavaExternalTransformTest.expansion_service_jar or JavaExternalTransformTest.expansion_service_port): raise RuntimeError('No expansion service jar or port provided.') JavaExternalTransformTest.expansion_service_port = ( JavaExternalTransformTest.expansion_service_port or 8091) jar = JavaExternalTransformTest.expansion_service_jar port = JavaExternalTransformTest.expansion_service_port # Start the java server and wait for it to be ready. if jar: self._server = subprocess.Popen(['java', '-jar', jar, str(port)]) address = 'localhost:%s' % str(port) with grpc.insecure_channel(address) as channel: grpc.channel_ready_future(channel).result() def __exit__(self, type, value, traceback): if self._server: self._server.kill() self._server = None
[docs] @staticmethod def test_java_expansion_dataflow(): if apiclient is None: return # This test does not actually running the pipeline in Dataflow. It just # tests the translation to a Dataflow job request. with patch.object(apiclient.DataflowApplicationClient, 'create_job') as mock_create_job: with JavaExternalTransformTest._RunWithExpansion(): pipeline_options = PipelineOptions([ '--runner=DataflowRunner', '--project=dummyproject', '--region=some-region1', '--experiments=beam_fn_api', '--temp_location=gs://dummybucket/' ]) # Run a simple count-filtered-letters pipeline. JavaExternalTransformTest.run_pipeline( pipeline_options, JavaExternalTransformTest.expansion_service_port, False) mock_args = mock_create_job.call_args_list assert mock_args args, kwargs = mock_args[0] job = args[0] job_str = '%s' % job assert 'beam:transforms:xlang:filter_less_than_eq' in job_str
[docs] @staticmethod def run_pipeline_with_expansion_service(pipeline_options): with JavaExternalTransformTest._RunWithExpansion(): # Run a simple count-filtered-letters pipeline. JavaExternalTransformTest.run_pipeline( pipeline_options, JavaExternalTransformTest.expansion_service_port, True)
[docs] @staticmethod def run_pipeline(pipeline_options, expansion_service, wait_until_finish=True): # The actual definitions of these transforms is in # org.apache.beam.runners.core.construction.TestExpansionService. TEST_COUNT_URN = "beam:transforms:xlang:count" TEST_FILTER_URN = "beam:transforms:xlang:filter_less_than_eq" # Run a simple count-filtered-letters pipeline. p = TestPipeline(options=pipeline_options) if isinstance(expansion_service, int): # Only the port was specified. expansion_service = 'localhost:%s' % str(expansion_service) res = ( p | beam.Create(list('aaabccxyyzzz')) | beam.Map(str) | beam.ExternalTransform( TEST_FILTER_URN, ImplicitSchemaPayloadBuilder({'data': 'middle'}), expansion_service) | beam.ExternalTransform(TEST_COUNT_URN, None, expansion_service) | beam.Map(lambda kv: '%s: %s' % kv)) assert_that(res, equal_to(['a: 3', 'b: 1', 'c: 2'])) result = p.run() if wait_until_finish: result.wait_until_finish()
if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) parser = argparse.ArgumentParser() parser.add_argument('--expansion_service_jar') parser.add_argument('--expansion_service_port') parser.add_argument('--expansion_service_target') parser.add_argument('--expansion_service_target_appendix') known_args, pipeline_args = parser.parse_known_args(sys.argv) if known_args.expansion_service_jar: JavaExternalTransformTest.expansion_service_jar = ( known_args.expansion_service_jar) JavaExternalTransformTest.expansion_service_port = int( known_args.expansion_service_port) pipeline_options = PipelineOptions(pipeline_args) JavaExternalTransformTest.run_pipeline_with_expansion_service( pipeline_options) elif known_args.expansion_service_target: pipeline_options = PipelineOptions(pipeline_args) JavaExternalTransformTest.run_pipeline( pipeline_options, beam.transforms.external.BeamJarExpansionService( known_args.expansion_service_target, gradle_appendix=known_args.expansion_service_target_appendix)) else: raise RuntimeError( "--expansion_service_jar or --expansion_service_target " "should be provided.")