Source file: flink/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py
(226 lines, 8.8 KiB, Python)
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import abc
from abc import abstractmethod
from typing import Iterable, Any, Dict, List
from apache_beam.runners.worker.bundle_processor import TimerInfo, DataOutputOperation
from apache_beam.runners.worker.operations import Operation
from apache_beam.utils import windowed_value
from apache_beam.utils.windowed_value import WindowedValue
from pyflink.common.constants import DEFAULT_OUTPUT_TAG
from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction
from pyflink.fn_execution.table.operations import BundleOperation
from pyflink.fn_execution.profiler import Profiler
class OutputProcessor(abc.ABC):
    """Strategy interface for emitting the results produced for one input element."""

    @abstractmethod
    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        """Forward *results* (computed for *windowed_value*) downstream."""

    def close(self):
        """Release any held resources; the default implementation is a no-op."""
class NetworkOutputProcessor(OutputProcessor):
    """Writes results directly onto the network output stream of a DataOutputOperation."""

    def __init__(self, consumer):
        assert isinstance(consumer, DataOutputOperation)
        self._consumer = consumer
        # Unwrap down to the raw value coder so results are encoded without
        # the windowed-value wrapper.
        self._value_coder_impl = consumer.windowed_coder.wrapped_value_coder.get_impl()._value_coder

    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        stream = self._consumer.output_stream
        self._value_coder_impl.encode_to_stream(results, stream, True)
        self._value_coder_impl._output_stream.maybe_flush()

    def close(self):
        self._value_coder_impl._output_stream.close()
class IntermediateOutputProcessor(OutputProcessor):
    """Hands results to the next chained operation instead of the network."""

    def __init__(self, consumer):
        self._consumer = consumer

    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        # Re-wrap the results in the incoming element's window before forwarding.
        self._consumer.process(windowed_value.with_value(results))
class FunctionOperation(Operation):
    """
    Base class of function operation that will execute StatelessFunction or StatefulFunction for
    each input element.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 operator_state_backend):
        super(FunctionOperation, self).__init__(name, spec, counter_factory, sampler)
        # One list of OutputProcessors per output tag.
        self._output_processors = self._create_output_processors(
            consumers
        )  # type: Dict[str, List[OutputProcessor]]
        self.operation_cls = operation_cls
        self.operator_state_backend = operator_state_backend
        self.operation = self.generate_operation()
        # Cache the bound method so the per-element hot path avoids an attribute lookup.
        self.process_element = self.operation.process_element
        self.operation.open()
        if spec.serialized_fn.profile_enabled:
            self._profiler = Profiler()
        else:
            self._profiler = None
        if isinstance(spec.serialized_fn, UserDefinedDataStreamFunction):
            self._has_side_output = spec.serialized_fn.has_side_output
        else:
            # it doesn't support side output in Table API & SQL
            self._has_side_output = False
        if not self._has_side_output:
            # Single-output fast path: cache the first processor of the main output tag.
            self._main_output_processor = self._output_processors[DEFAULT_OUTPUT_TAG][0]

    def setup(self, data_sampler=None):
        super().setup(data_sampler)

    def start(self):
        """Start the operation and begin profiling when enabled."""
        with self.scoped_start_state:
            super(FunctionOperation, self).start()
            if self._profiler:
                self._profiler.start()

    def finish(self):
        """Finish the current bundle and close the profiler when enabled."""
        with self.scoped_finish_state:
            super(FunctionOperation, self).finish()
            self.operation.finish()
            if self._profiler:
                self._profiler.close()

    def needs_finalization(self):
        return False

    def reset(self):
        super(FunctionOperation, self).reset()

    def teardown(self):
        """Close the wrapped operation and every registered output processor."""
        with self.scoped_finish_state:
            self.operation.close()
            for processors in self._output_processors.values():
                for p in processors:
                    p.close()

    def progress_metrics(self):
        # Report only the first receiver's element count, keyed by "None".
        metrics = super(FunctionOperation, self).progress_metrics()
        metrics.processed_elements.measured.output_element_counts.clear()
        tag = None
        receiver = self.receivers[0]
        metrics.processed_elements.measured.output_element_counts[
            str(tag)
        ] = receiver.opcounter.element_counts[
            str(tag)
        ] if False else receiver.opcounter.element_counter.value()
        return metrics

    def process(self, o: WindowedValue):
        """Process one bundle; o.value is iterated as the sequence of input elements."""
        with self.scoped_process_state:
            if self._has_side_output:
                # process_element yields (tag, row) pairs; each row is routed to the
                # processors registered for its tag (rows with unknown tags are dropped).
                for value in o.value:
                    for tag, row in self.process_element(value):
                        for p in self._output_processors.get(tag, []):
                            p.process_outputs(o, [row])
            else:
                if isinstance(self.operation, BundleOperation):
                    # BundleOperation emits all results at once via finish_bundle().
                    for value in o.value:
                        self.process_element(value)
                    self._main_output_processor.process_outputs(o, self.operation.finish_bundle())
                else:
                    for value in o.value:
                        self._main_output_processor.process_outputs(
                            o, self.operation.process_element(value)
                        )

    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
        """
        Only pass user metric to Java
        :param tag_to_pcollection_id: useless for user metric
        """
        return super().user_monitoring_infos(transform_id)

    @staticmethod
    def _create_output_processors(consumers_map):
        # Network consumers encode straight to the wire; all others are chained in-process.
        def _create_processor(consumer):
            if isinstance(consumer, DataOutputOperation):
                return NetworkOutputProcessor(consumer)
            else:
                return IntermediateOutputProcessor(consumer)

        return {
            tag: [_create_processor(c) for c in consumers]
            for tag, consumers in consumers_map.items()
        }

    @abstractmethod
    def generate_operation(self):
        """Create the wrapped pyflink operation; provided by subclasses."""
        pass
class StatelessFunctionOperation(FunctionOperation):
    """Function operation that runs without any keyed state backend."""

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 operator_state_backend):
        super(StatelessFunctionOperation, self).__init__(
            name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend
        )

    def generate_operation(self):
        """Instantiate the wrapped operation, passing the operator state backend when present."""
        if self.operator_state_backend is None:
            return self.operation_cls(self.spec.serialized_fn)
        return self.operation_cls(self.spec.serialized_fn, self.operator_state_backend)
class StatefulFunctionOperation(FunctionOperation):
    """Function operation backed by a keyed state backend, with timer support."""

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 keyed_state_backend, operator_state_backend):
        self._keyed_state_backend = keyed_state_backend
        # Placeholder windowed value reused for timer output (timers have no input element).
        self._reusable_windowed_value = windowed_value.create(None, -1, None, None)
        super(StatefulFunctionOperation, self).__init__(
            name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend
        )

    def generate_operation(self):
        """Instantiate the wrapped operation with the keyed (and optional operator) state backend."""
        ctor_args = [self.spec.serialized_fn, self._keyed_state_backend]
        if self.operator_state_backend is not None:
            ctor_args.append(self.operator_state_backend)
        return self.operation_cls(*ctor_args)

    def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo):
        # timer_family_id is deliberately ignored here.
        self.operation.add_timer_info(timer_info)

    def process_timer(self, tag, timer_data):
        if self._has_side_output:
            # the field user_key holds the timer data
            for out_tag, row in self.operation.process_timer(timer_data.user_key):
                for processor in self._output_processors.get(out_tag, []):
                    processor.process_outputs(self._reusable_windowed_value, [row])
        else:
            self._main_output_processor.process_outputs(
                self._reusable_windowed_value,
                # the field user_key holds the timer data
                self.operation.process_timer(timer_data.user_key),
            )