#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
from typing import Set
from typing import Tuple
import apache_beam as beam
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import combiners
from apache_beam.transforms import trigger
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
@with_input_types(int)
@with_output_types(int)
class CallSequenceEnforcingCombineFn(beam.CombineFn):
  """Summing CombineFn that asserts its lifecycle methods run in order.

  Every accumulator method asserts that ``setup`` has already been called on
  this instance and that ``teardown`` has not. ``setup`` and ``teardown``
  additionally assert that they are each called at most once.

  All methods accept and ignore extra positional/keyword arguments so the
  CombineFn can be used with additional (curried) arguments.
  """
  # Instances on which setup() has run. This is a class attribute, so the set
  # is shared by every instance of the class.
  instances = set()  # type: Set[CallSequenceEnforcingCombineFn]

  def __init__(self):
    super().__init__()
    self._setup_called = False
    self._teardown_called = False

  def setup(self, *args, **kwargs):
    """Marks this instance as set up and registers it in ``instances``."""
    assert not self._setup_called, 'setup should not be called twice'
    assert not self._teardown_called, 'setup should be called before teardown'
    # Keep track of instances so that we can check if teardown is called
    # properly after pipeline execution.
    self.instances.add(self)
    self._setup_called = True

  def create_accumulator(self, *args, **kwargs):
    """Returns the empty accumulator, the integer 0."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not have been called'
    return 0

  def add_input(self, mutable_accumulator, element, *args, **kwargs):
    """Returns the accumulator with ``element`` added to it."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not have been called'
    return mutable_accumulator + element

  def merge_accumulators(self, accumulators, *args, **kwargs):
    """Returns the sum of all the given accumulators."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not have been called'
    return sum(accumulators)

  def extract_output(self, accumulator, *args, **kwargs):
    """Returns the accumulated sum unchanged as the final output."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not have been called'
    return accumulator

  def teardown(self, *args, **kwargs):
    """Marks this instance as torn down; must follow exactly one setup."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not be called twice'
    self._teardown_called = True
@with_input_types(Tuple[None, str])
@with_output_types(Tuple[int, str])
class IndexAssigningDoFn(beam.DoFn):
  """Stateful DoFn that pairs each value with a running index.

  The index is kept in a combining value state named 'index' whose combiner is
  a CallSequenceEnforcingCombineFn, so reading and updating the state
  exercises that CombineFn.
  """
  state_param = beam.DoFn.StateParam(
      userstate.CombiningValueStateSpec(
          'index', beam.coders.VarIntCoder(), CallSequenceEnforcingCombineFn()))

  def process(self, element, state=state_param):
    """Yields ``(current_index, value)`` and then adds 1 to the index state.

    Args:
      element: A ``(key, value)`` pair; the key is discarded.
      state: The combining state holding the index; injected through the
        ``state_param`` default.
    """
    _, value = element
    current_index = state.read()
    yield current_index, value
    # Runs when the generator is resumed after the yield above.
    state.add(1)
def run_combine(pipeline, input_elements=5, lift_combiners=True):
  """Runs a global combine over ``range(input_elements)`` and checks the sums.

  Two CallSequenceEnforcingCombineFn instances are wrapped in a
  SingleInputTupleCombineFn, so the expected output is a single pair holding
  the same sum twice.

  Args:
    pipeline: The pipeline to build on; it is entered as a context manager and
      its options are modified in place.
    input_elements: Number of consecutive integers, starting at 0, to combine.
    lift_combiners: When False, the input is re-windowed with an AfterCount
      trigger to prevent combiner lifting.
  """
  # Calculate the expected result, which is the sum of an arithmetic sequence.
  # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10
  # n * (n - 1) is always even, so floor division is exact and keeps the
  # expected value an int like the combiner's output.
  expected_result = input_elements * (input_elements - 1) // 2
  # Enable runtime type checking in order to cover TypeCheckCombineFn by
  # the test.
  pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True
  pipeline.get_pipeline_options().view_as(
      TypeOptions).allow_unsafe_triggers = True
  with pipeline as p:
    pcoll = p | 'Start' >> beam.Create(range(input_elements))
    # Certain triggers, such as AfterCount, are incompatible with combiner
    # lifting. We can use that fact to prevent combiners from being lifted.
    if not lift_combiners:
      pcoll |= beam.WindowInto(
          window.GlobalWindows(),
          trigger=trigger.AfterCount(input_elements),
          accumulation_mode=trigger.AccumulationMode.DISCARDING)
    # Pass an additional 'None' in order to cover _CurriedFn by the test.
    pcoll |= 'Do' >> beam.CombineGlobally(
        combiners.SingleInputTupleCombineFn(
            CallSequenceEnforcingCombineFn(), CallSequenceEnforcingCombineFn()),
        None).with_fanout(fanout=1)
    assert_that(pcoll, equal_to([(expected_result, expected_result)]))
def run_pardo(pipeline, input_elements=10):
  """Runs IndexAssigningDoFn over ``input_elements`` copies of 'Hello'.

  Every element is keyed with None before the ParDo, so all elements share a
  single key.

  Args:
    pipeline: The pipeline to build on; it is entered as a context manager.
    input_elements: Number of 'Hello' elements to create.
  """
  with pipeline as p:
    _ = (
        p
        | 'Start' >> beam.Create(('Hello' for _ in range(input_elements)))
        | 'KeyWithNone' >> beam.Map(lambda elem: (None, elem))
        | 'Do' >> beam.ParDo(IndexAssigningDoFn()))