################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
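"""
Output processors and function operations for executing Python user-defined functions
on top of Apache Beam's operation framework in PyFlink.
"""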
import abc
from abc import abstractmethod
from typing import Iterable, Any, Dict, List

from apache_beam.runners.worker.bundle_processor import TimerInfo, DataOutputOperation
from apache_beam.runners.worker.operations import Operation
from apache_beam.utils import windowed_value
from apache_beam.utils.windowed_value import WindowedValue

from pyflink.common.constants import DEFAULT_OUTPUT_TAG
from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction
from pyflink.fn_execution.table.operations import BundleOperation
from pyflink.fn_execution.profiler import Profiler


class OutputProcessor(abc.ABC):
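    """
    Base class for emitting the results produced by the Python user-defined function,
    either to the network output stream or to a downstream chained operation.
    """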

    @abstractmethod
    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        pass

    def close(self):
        pass


class NetworkOutputProcessor(OutputProcessor):
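    """
    Emits results by encoding them with the consumer's value coder and writing them
    to the output stream of the wrapped DataOutputOperation.
    """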

    def __init__(self, consumer):
        assert isinstance(consumer, DataOutputOperation)
        self._consumer = consumer
        self._value_coder_impl = consumer.windowed_coder.wrapped_value_coder.get_impl()._value_coder

    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        output_stream = self._consumer.output_stream
        self._value_coder_impl.encode_to_stream(results, output_stream, True)
        self._value_coder_impl._output_stream.maybe_flush()

    def close(self):
        self._value_coder_impl._output_stream.close()


class IntermediateOutputProcessor(OutputProcessor):
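    """
    Emits results by passing them directly to a downstream chained operation.
    """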

    def __init__(self, consumer):
        self._consumer = consumer

    def process_outputs(self, windowed_value: WindowedValue, results: Iterable[Any]):
        self._consumer.process(windowed_value.with_value(results))


class FunctionOperation(Operation):
    """
    Base class of function operations that execute a StatelessFunction or StatefulFunction
    for each input element.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 operator_state_backend):
        super(FunctionOperation, self).__init__(name, spec, counter_factory, sampler)
        self._output_processors = self._create_output_processors(
            consumers
        )  # type: Dict[str, List[OutputProcessor]]
        self.operation_cls = operation_cls
        self.operator_state_backend = operator_state_backend
        self.operation = self.generate_operation()
        self.process_element = self.operation.process_element
        self.operation.open()
        if spec.serialized_fn.profile_enabled:
            self._profiler = Profiler()
        else:
            self._profiler = None

        if isinstance(spec.serialized_fn, UserDefinedDataStreamFunction):
            self._has_side_output = spec.serialized_fn.has_side_output
        else:
            # side output is not supported in the Table API & SQL
            self._has_side_output = False
        if not self._has_side_output:
            self._main_output_processor = self._output_processors[DEFAULT_OUTPUT_TAG][0]

    def setup(self, data_sampler=None):
        super().setup(data_sampler)

    def start(self):
        with self.scoped_start_state:
            super(FunctionOperation, self).start()
            if self._profiler:
                self._profiler.start()

    def finish(self):
        with self.scoped_finish_state:
            super(FunctionOperation, self).finish()
            self.operation.finish()
            if self._profiler:
                self._profiler.close()

    def needs_finalization(self):
        return False

    def reset(self):
        super(FunctionOperation, self).reset()

    def teardown(self):
        with self.scoped_finish_state:
            self.operation.close()
            for processors in self._output_processors.values():
                for p in processors:
                    p.close()

    def progress_metrics(self):
        metrics = super(FunctionOperation, self).progress_metrics()
        metrics.processed_elements.measured.output_element_counts.clear()
        tag = None
        receiver = self.receivers[0]
        metrics.processed_elements.measured.output_element_counts[
            str(tag)
        ] = receiver.opcounter.element_counter.value()
        return metrics

    def process(self, o: WindowedValue):
        with self.scoped_process_state:
            if self._has_side_output:
                # the wrapped operation yields (tag, row) pairs; route each row to the
                # output processors registered for that output tag
                for value in o.value:
                    for tag, row in self.process_element(value):
                        for p in self._output_processors.get(tag, []):
                            p.process_outputs(o, [row])
            else:
                if isinstance(self.operation, BundleOperation):
                    # bundle operations buffer their results and emit them in one batch
                    # when the bundle is finished
                    for value in o.value:
                        self.process_element(value)
                    self._main_output_processor.process_outputs(o, self.operation.finish_bundle())
                else:
                    for value in o.value:
                        self._main_output_processor.process_outputs(
                            o, self.operation.process_element(value)
                        )

    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
        """
        Only pass the user metrics to Java.

        :param tag_to_pcollection_id: unused for user metrics
        """
        return super().user_monitoring_infos(transform_id)

    @staticmethod
    def _create_output_processors(consumers_map):
        def _create_processor(consumer):
            if isinstance(consumer, DataOutputOperation):
                return NetworkOutputProcessor(consumer)
            else:
                return IntermediateOutputProcessor(consumer)

        return {
            tag: [_create_processor(c) for c in consumers]
            for tag, consumers in consumers_map.items()
        }

    @abstractmethod
    def generate_operation(self):
        pass


class StatelessFunctionOperation(FunctionOperation):
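    """
    Function operation whose wrapped operation is created without a keyed state backend.
    """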
    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 operator_state_backend):
        super(StatelessFunctionOperation, self).__init__(
            name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend
        )

    def generate_operation(self):
        if self.operator_state_backend is not None:
            return self.operation_cls(self.spec.serialized_fn, self.operator_state_backend)
        else:
            return self.operation_cls(self.spec.serialized_fn)


class StatefulFunctionOperation(FunctionOperation):
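    """
    Function operation whose wrapped operation is created with a keyed state backend and
    which additionally dispatches timers to it.
    """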
    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 keyed_state_backend, operator_state_backend):
        self._keyed_state_backend = keyed_state_backend
        self._reusable_windowed_value = windowed_value.create(None, -1, None, None)
        super(StatefulFunctionOperation, self).__init__(
            name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend
        )

    def generate_operation(self):
        if self.operator_state_backend is not None:
            return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend,
                                      self.operator_state_backend)
        else:
            return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend)

    def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo):
        # ignore timer_family_id
        self.operation.add_timer_info(timer_info)

    def process_timer(self, tag, timer_data):
        if self._has_side_output:
            # the field user_key holds the timer data
            for tag, row in self.operation.process_timer(timer_data.user_key):
                for p in self._output_processors.get(tag, []):
                    p.process_outputs(self._reusable_windowed_value, [row])
        else:
            self._main_output_processor.process_outputs(
                self._reusable_windowed_value,
                # the field user_key holds the timer data
                self.operation.process_timer(timer_data.user_key),
            )