[FLINK-14023][python] Support accessing job parameters in Python user-defined functions

This closes #21644.
Dian Fu 2 years ago
parent cc7b5329a4
commit 5cb434c36d

@@ -56,6 +56,36 @@ class Predict(ScalarFunction):
predict = udf(Predict(), result_type=DataTypes.DOUBLE(), func_type="pandas")
```
## Accessing job parameters
The `open()` method provides a `FunctionContext` that contains information about the context in which
user-defined functions are executed, such as the metric group, the global job parameters, etc.
The following information can be obtained by calling the corresponding methods of `FunctionContext`:
| Method | Description |
| :--------------------------------------- | :---------------------------------------------------------------------- |
| `get_metric_group()` | Metric group for this parallel subtask. |
| `get_job_parameter(name, default_value)` | Global job parameter value associated with given key. |
```python
from pyflink.table import DataTypes, TableEnvironment
from pyflink.table.udf import FunctionContext, ScalarFunction, udf

class HashCode(ScalarFunction):

    def open(self, function_context: FunctionContext):
        # access the global "hashcode_factor" parameter
        # "12" would be the default value if the parameter does not exist
        self.factor = int(function_context.get_job_parameter("hashcode_factor", "12"))

    def eval(self, s: str):
        return hash(s) * self.factor

hash_code = udf(HashCode(), result_type=DataTypes.INT())

t_env = TableEnvironment.create(...)
t_env.get_config().set('pipeline.global-job-parameters', 'hashcode_factor:31')
t_env.create_temporary_system_function("hashCode", hash_code)
t_env.sql_query("SELECT myField, hashCode(myField) FROM MyTable")
```
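Note that `get_job_parameter` always returns the value as a string, so numeric parameters have to be converted explicitly, as the `int(...)` call in `open()` above does.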
## Testing User-Defined Functions
Suppose you have defined a Python user-defined function as follows:

@@ -62,6 +62,36 @@ class Predict(ScalarFunction):
predict = udf(Predict(), result_type=DataTypes.DOUBLE(), func_type="pandas")
```
## Accessing job parameters
The `open()` method provides a `FunctionContext` that contains information about the context in which
user-defined functions are executed, such as the metric group, the global job parameters, etc.
The following information can be obtained by calling the corresponding methods of `FunctionContext`:
| Method | Description |
| :--------------------------------------- | :---------------------------------------------------------------------- |
| `get_metric_group()` | Metric group for this parallel subtask. |
| `get_job_parameter(name, default_value)` | Global job parameter value associated with given key. |
```python
from pyflink.table import DataTypes, TableEnvironment
from pyflink.table.udf import FunctionContext, ScalarFunction, udf

class HashCode(ScalarFunction):

    def open(self, function_context: FunctionContext):
        # access the global "hashcode_factor" parameter
        # "12" would be the default value if the parameter does not exist
        self.factor = int(function_context.get_job_parameter("hashcode_factor", "12"))

    def eval(self, s: str):
        return hash(s) * self.factor

hash_code = udf(HashCode(), result_type=DataTypes.INT())

t_env = TableEnvironment.create(...)
t_env.get_config().set('pipeline.global-job-parameters', 'hashcode_factor:31')
t_env.create_temporary_system_function("hashCode", hash_code)
t_env.sql_query("SELECT myField, hashCode(myField) FROM MyTable")
```
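Several parameters can be supplied through the same option. A minimal sketch, assuming Flink's comma-separated `key:value` syntax for map-typed options (the `subtract_value` key is taken from the test below and serves only as an illustration):
```python
# pass several global job parameters at once; each entry is a key:value pair
t_env.get_config().set(
    'pipeline.global-job-parameters',
    'hashcode_factor:31,subtract_value:2')
```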
## Testing User-Defined Functions
Suppose you have defined a Python user-defined function as follows:


@@ -48,7 +48,7 @@ class MetricTests(PyFlinkTestCase):
self.assertEqual(MetricTests.print_metric_group_path(new_group), 'root.key.value')
def test_metric_not_enabled(self):
fc = FunctionContext(None)
fc = FunctionContext(None, None)
with self.assertRaises(RuntimeError):
fc.get_metric_group()

@@ -83,6 +83,7 @@ class BaseOperation(Operation):
else:
self.base_metric_group = None
self.func, self.user_defined_funcs = self.generate_func(serialized_fn)
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
def finish(self):
self._update_gauge(self.base_metric_group)
@@ -102,7 +103,7 @@ class BaseOperation(Operation):
def open(self):
for user_defined_func in self.user_defined_funcs:
if hasattr(user_defined_func, 'open'):
user_defined_func.open(FunctionContext(self.base_metric_group))
user_defined_func.open(FunctionContext(self.base_metric_group, self.job_parameters))
def close(self):
for user_defined_func in self.user_defined_funcs:
@@ -323,11 +324,12 @@ class AbstractStreamGroupAggregateOperation(BaseStatefulOperation):
self.state_cache_size = serialized_fn.state_cache_size
self.state_cleaning_enabled = serialized_fn.state_cleaning_enabled
self.data_view_specs = extract_data_view_specs(serialized_fn.udfs)
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
super(AbstractStreamGroupAggregateOperation, self).__init__(
serialized_fn, keyed_state_backend)
def open(self):
self.group_agg_function.open(FunctionContext(self.base_metric_group))
self.group_agg_function.open(FunctionContext(self.base_metric_group, self.job_parameters))
def close(self):
self.group_agg_function.close()

@@ -25,6 +25,11 @@ package org.apache.flink.fn_execution.v1;
option java_package = "org.apache.flink.fnexecution.v1";
option java_outer_classname = "FlinkFnApi";
message JobParameter {
string key = 1;
string value = 2;
}
// ------------------------------------------------------------------------
// Table API & SQL
// ------------------------------------------------------------------------
@@ -65,6 +70,7 @@ message UserDefinedFunctions {
bool metric_enabled = 2;
repeated OverWindow windows = 3;
bool profile_enabled = 4;
repeated JobParameter job_parameters = 5;
}
// Used to describe the info of over window in pandas batch over window aggregation
@@ -182,6 +188,7 @@ message UserDefinedAggregateFunctions {
GroupWindow group_window = 12;
bool profile_enabled = 13;
repeated JobParameter job_parameters = 14;
}
// A representation of the data schema.
@@ -362,11 +369,6 @@ message UserDefinedDataStreamFunction {
REVISE_OUTPUT = 100;
}
message JobParameter {
string key = 1;
string value = 2;
}
message RuntimeContext {
string task_name = 1;
string task_name_with_subtasks = 2;

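For reference, a minimal sketch of how the new top-level `JobParameter` message travels to the Python side, assuming the generated protobuf module `pyflink.fn_execution.flink_fn_execution_pb2`; the dict conversion mirrors the `BaseOperation` change above:
```python
from pyflink.fn_execution import flink_fn_execution_pb2 as flink_fn_api

# build a UserDefinedFunctions proto carrying a single job parameter
serialized_fn = flink_fn_api.UserDefinedFunctions()
serialized_fn.job_parameters.add(key='hashcode_factor', value='31')

# the runner flattens the repeated field into a plain dict
job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
assert job_parameters == {'hashcode_factor': '31'}
```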
@@ -26,7 +26,7 @@ import pytz
from pyflink.table import DataTypes, expressions as expr
from pyflink.table.expressions import call
from pyflink.table.udf import ScalarFunction, udf
from pyflink.table.udf import ScalarFunction, udf, FunctionContext
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, \
PyFlinkBatchTableTestCase
@@ -41,12 +41,15 @@ class UserDefinedFunctionTests(object):
def test_scalar_function(self):
# test metric disabled.
self.t_env.get_config().set('python.metric.enabled', 'false')
self.t_env.get_config().set('pipeline.global-job-parameters', 'subtract_value:2')
# test lambda function
add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
# test Python ScalarFunction
subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())
subtract_two = udf(SubtractWithParameters(), result_type=DataTypes.BIGINT())
# test callable function
add_one_callable = udf(CallablePlus(), result_type=DataTypes.BIGINT())
@@ -68,7 +71,7 @@
sink_table = generate_random_table_name()
sink_table_ddl = f"""
CREATE TABLE {sink_table}(a BIGINT, b BIGINT, c BIGINT, d BIGINT, e BIGINT, f BIGINT,
g BIGINT) WITH ('connector'='test-sink')
g BIGINT, h BIGINT) WITH ('connector'='test-sink')
"""
self.t_env.execute_sql(sink_table_ddl)
@@ -76,11 +79,16 @@
t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
t.where(add_one(t.b) <= 3).select(
add_one(t.a), subtract_one(t.b), add(t.a, t.c), add_one_callable(t.a),
add_one_partial(t.a), check_memory_limit(execution_mode), t.a) \
.execute_insert(sink_table).wait()
add_one(t.a),
subtract_one(t.b),
subtract_two(t.b),
add(t.a, t.c),
add_one_callable(t.a),
add_one_partial(t.a),
check_memory_limit(execution_mode),
t.a).execute_insert(sink_table).wait()
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[2, 1, 4, 2, 2, 1, 1]", "+I[4, 0, 12, 4, 4, 1, 3]"])
self.assert_equals(actual, ["+I[2, 1, 0, 4, 2, 2, 1, 1]", "+I[4, 0, -1, 12, 4, 4, 1, 3]"])
def test_chaining_scalar_function(self):
add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
@@ -1010,6 +1018,15 @@ class SubtractOne(ScalarFunction):
return i - 1
class SubtractWithParameters(ScalarFunction):
def open(self, function_context: FunctionContext):
self.subtract_value = int(function_context.get_job_parameter("subtract_value", "1"))
def eval(self, i):
return i - self.subtract_value
class SubtractWithMetrics(ScalarFunction, unittest.TestCase):
def open(self, function_context):

@@ -37,8 +37,9 @@ class FunctionContext(object):
and global job parameters, etc.
"""
def __init__(self, base_metric_group):
def __init__(self, base_metric_group, job_parameters):
self._base_metric_group = base_metric_group
self._job_parameters = job_parameters
def get_metric_group(self) -> MetricGroup:
"""
@@ -51,6 +52,18 @@
"metric with the 'python.metric.enabled' configuration.")
return self._base_metric_group
def get_job_parameter(self, key: str, default_value: str) -> str:
"""
Gets the global job parameter value associated with the given key as a string.
:param key: The key pointing to the associated value.
:param default_value: The default value which is returned in case the global job parameters
are null or there is no value associated with the given key.
:return: The job parameter value as a string, or the default value if the key is absent.
.. versionadded:: 1.17.0
"""
return self._job_parameters[key] if key in self._job_parameters else default_value
class UserDefinedFunction(abc.ABC):
"""

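A quick way to exercise the new constructor and accessor in isolation; a minimal sketch, where passing `None` for the metric group mirrors the updated `MetricTests` above:
```python
from pyflink.table.udf import FunctionContext

# no metric group, one job parameter
fc = FunctionContext(None, {'hashcode_factor': '31'})
assert fc.get_job_parameter('hashcode_factor', '12') == '31'
# the default is returned when the key is absent
assert fc.get_job_parameter('missing_key', '12') == '12'
```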
@@ -134,6 +134,7 @@ public enum ProtoUtils {
// function utilities
public static FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto(
RuntimeContext runtimeContext,
PythonFunctionInfo[] userDefinedFunctions,
boolean isMetricEnabled,
boolean isProfileEnabled) {
@@ -144,6 +145,16 @@
}
builder.setMetricEnabled(isMetricEnabled);
builder.setProfileEnabled(isProfileEnabled);
builder.addAllJobParameters(
runtimeContext.getExecutionConfig().getGlobalJobParameters().toMap().entrySet()
.stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()));
return builder.build();
}
@@ -259,8 +270,7 @@
.entrySet().stream()
.map(
entry ->
FlinkFnApi.UserDefinedDataStreamFunction
.JobParameter.newBuilder()
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
@@ -269,8 +279,7 @@
internalParameters.entrySet().stream()
.map(
entry ->
FlinkFnApi.UserDefinedDataStreamFunction
.JobParameter.newBuilder()
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())

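These two hunks follow from hoisting `JobParameter` out of `UserDefinedDataStreamFunction`: now that the message is defined at the top level of the proto file, the Table API and DataStream code paths share a single representation instead of the previously nested `UserDefinedDataStreamFunction.JobParameter`.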
@@ -253,6 +253,16 @@ public abstract class AbstractPythonStreamAggregateOperator
ProtoUtils.createUserDefinedAggregateFunctionProto(
aggregateFunctions[i], specs));
}
builder.addAllJobParameters(
getRuntimeContext().getExecutionConfig().getGlobalJobParameters().toMap().entrySet()
.stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()));
return builder.build();
}

@@ -150,6 +150,7 @@ public abstract class AbstractArrowPythonAggregateFunctionOperator
@Override
public FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto() {
return ProtoUtils.createUserDefinedFunctionsProto(
getRuntimeContext(),
pandasAggFunctions,
config.get(PYTHON_METRIC_ENABLED),
config.get(PYTHON_PROFILE_ENABLED));

@@ -36,6 +36,7 @@ import org.apache.flink.table.types.logical.RowType;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.stream.Collectors;
import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED;
@@ -263,6 +264,16 @@ public class BatchArrowPythonOverWindowAggregateFunctionOperator
}
builder.setMetricEnabled(config.get(PYTHON_METRIC_ENABLED));
builder.setProfileEnabled(config.get(PYTHON_PROFILE_ENABLED));
builder.addAllJobParameters(
getRuntimeContext().getExecutionConfig().getGlobalJobParameters().toMap().entrySet()
.stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()));
// add windows
for (int i = 0; i < lowerBoundary.length; i++) {
FlinkFnApi.OverWindow.Builder windowBuilder = FlinkFnApi.OverWindow.newBuilder();

@@ -119,6 +119,7 @@ public abstract class AbstractPythonScalarFunctionOperator
@Override
public FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto() {
return ProtoUtils.createUserDefinedFunctionsProto(
getRuntimeContext(),
scalarFunctions,
config.get(PYTHON_METRIC_ENABLED),
config.get(PYTHON_PROFILE_ENABLED));

@@ -119,6 +119,7 @@ public class EmbeddedPythonScalarFunctionOperator
interpreter.set(
"proto",
ProtoUtils.createUserDefinedFunctionsProto(
getRuntimeContext(),
scalarFunctions,
config.get(PYTHON_METRIC_ENABLED),
config.get(PYTHON_PROFILE_ENABLED))

@@ -111,6 +111,7 @@ public class EmbeddedPythonTableFunctionOperator extends AbstractEmbeddedStatele
interpreter.set(
"proto",
ProtoUtils.createUserDefinedFunctionsProto(
getRuntimeContext(),
new PythonFunctionInfo[] {tableFunction},
config.get(PYTHON_METRIC_ENABLED),
config.get(PYTHON_PROFILE_ENABLED))

@@ -157,6 +157,7 @@ public class PythonTableFunctionOperator
@Override
public FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto() {
return ProtoUtils.createUserDefinedFunctionsProto(
getRuntimeContext(),
new PythonFunctionInfo[] {tableFunction},
config.get(PYTHON_METRIC_ENABLED),
config.get(PYTHON_PROFILE_ENABLED));
