################################################################################ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ import abc from abc import abstractmethod from typing import Iterable, Any, Dict, List from apache_beam.runners.worker.bundle_processor import TimerInfo, DataOutputOperation from apache_beam.runners.worker.operations import Operation from apache_beam.utils import windowed_value from apache_beam.utils.windowed_value import WindowedValue from pyflink.common.constants import DEFAULT_OUTPUT_TAG from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction from pyflink.fn_execution.table.operations import BundleOperation from pyflink.fn_execution.profiler import Profiler class OutputProcessor(abc.ABC): @abstractmethod def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]): pass def close(self): pass class NetworkOutputProcessor(OutputProcessor): def __init__(self, consumer): assert isinstance(consumer, DataOutputOperation) self._consumer = consumer self._value_coder_impl = consumer.windowed_coder.wrapped_value_coder.get_impl()._value_coder def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]): output_stream = self._consumer.output_stream self._value_coder_impl.encode_to_stream(results, output_stream, True) self._value_coder_impl._output_stream.maybe_flush() def close(self): self._value_coder_impl._output_stream.close() class IntermediateOutputProcessor(OutputProcessor): def __init__(self, consumer): self._consumer = consumer def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]): self._consumer.process(windowed_value.with_value(results)) class FunctionOperation(Operation): """ Base class of function operation that will execute StatelessFunction or StatefulFunction for each input element. """ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend): super(FunctionOperation, self).__init__(name, spec, counter_factory, sampler) self._output_processors = self._create_output_processors( consumers ) # type: Dict[str, List[OutputProcessor]] self.operation_cls = operation_cls self.operator_state_backend = operator_state_backend self.operation = self.generate_operation() self.process_element = self.operation.process_element self.operation.open() if spec.serialized_fn.profile_enabled: self._profiler = Profiler() else: self._profiler = None if isinstance(spec.serialized_fn, UserDefinedDataStreamFunction): self._has_side_output = spec.serialized_fn.has_side_output else: # it doesn't support side output in Table API & SQL self._has_side_output = False if not self._has_side_output: self._main_output_processor = self._output_processors[DEFAULT_OUTPUT_TAG][0] def setup(self, data_sampler=None): super().setup(data_sampler) def start(self): with self.scoped_start_state: super(FunctionOperation, self).start() if self._profiler: self._profiler.start() def finish(self): with self.scoped_finish_state: super(FunctionOperation, self).finish() self.operation.finish() if self._profiler: self._profiler.close() def needs_finalization(self): return False def reset(self): super(FunctionOperation, self).reset() def teardown(self): with self.scoped_finish_state: self.operation.close() for processors in self._output_processors.values(): for p in processors: p.close() def progress_metrics(self): metrics = super(FunctionOperation, self).progress_metrics() metrics.processed_elements.measured.output_element_counts.clear() tag = None receiver = self.receivers[0] metrics.processed_elements.measured.output_element_counts[ str(tag) ] = receiver.opcounter.element_counter.value() return metrics def process(self, o: WindowedValue): with self.scoped_process_state: if self._has_side_output: for value in o.value: for tag, row in self.process_element(value): for p in self._output_processors.get(tag, []): p.process_outputs(o, [row]) else: if isinstance(self.operation, BundleOperation): for value in o.value: self.process_element(value) self._main_output_processor.process_outputs(o, self.operation.finish_bundle()) else: for value in o.value: self._main_output_processor.process_outputs( o, self.operation.process_element(value) ) def monitoring_infos(self, transform_id, tag_to_pcollection_id): """ Only pass user metric to Java :param tag_to_pcollection_id: useless for user metric """ return super().user_monitoring_infos(transform_id) @staticmethod def _create_output_processors(consumers_map): def _create_processor(consumer): if isinstance(consumer, DataOutputOperation): return NetworkOutputProcessor(consumer) else: return IntermediateOutputProcessor(consumer) return { tag: [_create_processor(c) for c in consumers] for tag, consumers in consumers_map.items() } @abstractmethod def generate_operation(self): pass class StatelessFunctionOperation(FunctionOperation): def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend): super(StatelessFunctionOperation, self).__init__( name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend ) def generate_operation(self): if self.operator_state_backend is not None: return self.operation_cls(self.spec.serialized_fn, self.operator_state_backend) else: return self.operation_cls(self.spec.serialized_fn) class StatefulFunctionOperation(FunctionOperation): def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls, keyed_state_backend, operator_state_backend): self._keyed_state_backend = keyed_state_backend self._reusable_windowed_value = windowed_value.create(None, -1, None, None) super(StatefulFunctionOperation, self).__init__( name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend ) def generate_operation(self): if self.operator_state_backend is not None: return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend, self.operator_state_backend) else: return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend) def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo): # ignore timer_family_id self.operation.add_timer_info(timer_info) def process_timer(self, tag, timer_data): if self._has_side_output: # the field user_key holds the timer data for tag, row in self.operation.process_timer(timer_data.user_key): for p in self._output_processors.get(tag, []): p.process_outputs(self._reusable_windowed_value, [row]) else: self._main_output_processor.process_outputs( self._reusable_windowed_value, # the field user_key holds the timer data self.operation.process_timer(timer_data.user_key), )