You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
flink/flink-python/pyflink/java_gateway.py

201 lines
8.3 KiB
Python

################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import importlib
import os
import shlex
import shutil
import struct
import tempfile
import time
from logging import WARN
from threading import RLock
from py4j.java_gateway import (java_import, logger, JavaGateway, GatewayParameters,
CallbackServerParameters)
from pyflink.find_flink_home import _find_flink_home
from pyflink.pyflink_gateway_server import launch_gateway_server_process
from pyflink.util.exceptions import install_exception_handler, install_py4j_hooks
# Process-wide JavaGateway instance; stays None until get_gateway() creates it.
_gateway = None
# Held by get_gateway() while it checks and creates _gateway. An RLock may be
# re-acquired by the thread that already holds it.
_lock = RLock()
def is_launch_gateway_disabled():
    """
    Tell whether launching the Java gateway from this process is disabled.

    Launching is disabled when the ``PYFLINK_GATEWAY_DISABLED`` environment
    variable is set to anything other than ``"0"``, ``"false"`` (compared
    case-insensitively) or the empty string.

    :return: True if the gateway must not be launched, False otherwise.
    """
    flag = os.environ.get("PYFLINK_GATEWAY_DISABLED")
    if flag is None:
        # Variable not set at all: launching is allowed.
        return False
    return flag.lower() not in ("0", "false", "")
def get_gateway():
    # type: () -> JavaGateway
    """
    Return the process-wide JavaGateway, creating and initializing it on first use.

    On the first call this either connects to the gateway whose port is given in
    the ``PYFLINK_GATEWAY_PORT`` environment variable or, when that variable is
    not set, launches a new gateway via ``launch_gateway()``. It then points the
    Java side at this process's callback server, imports the Flink classes into
    the JVM view, installs the exception handler and py4j hooks, and registers
    the ``PythonFunctionFactory`` and ``Watchdog`` objects on the entry point.

    :return: the shared JavaGateway instance.
    """
    global _gateway
    with _lock:
        if _gateway is not None:
            # Already created by an earlier call.
            return _gateway

        # Set the level to WARN to mute the noisy INFO level logs
        logger.level = WARN

        running_port = os.environ.get('PYFLINK_GATEWAY_PORT')
        if running_port is None:
            _gateway = launch_gateway()
        else:
            # A Java gateway is already running: connect to its port.
            connect_params = GatewayParameters(port=int(running_port), auto_convert=True)
            _gateway = JavaGateway(
                gateway_parameters=connect_params,
                callback_server_parameters=CallbackServerParameters(
                    port=0, daemonize=True, daemonize_connections=True))

        # Tell the Java side where this process's callback server is listening.
        cb_server = _gateway.get_callback_server()
        cb_address = cb_server.get_listening_address()
        cb_port = cb_server.get_listening_port()
        _gateway.jvm.org.apache.flink.client.python.PythonEnvUtils.resetCallbackClient(
            _gateway.java_gateway_server, cb_address, cb_port)

        # import the flink view
        import_flink_view(_gateway)
        install_exception_handler()
        install_py4j_hooks()

        # Register the Python-side objects on the gateway's entry point.
        entry_point = _gateway.entry_point
        entry_point.put("PythonFunctionFactory", PythonFunctionFactory())
        entry_point.put("Watchdog", Watchdog())
        return _gateway
def launch_gateway():
    # type: () -> JavaGateway
    """
    Launch the JVM gateway server process and connect to it.

    The server is started with the arguments taken from the ``SUBMIT_ARGS``
    environment variable (``"local"`` when unset). The path of a connection-info
    file is handed to it through the ``_PYFLINK_CONN_INFO_PATH`` environment
    variable; this function waits for that file to appear and reads the gateway
    port from its first 4 bytes (network byte order, unsigned).

    :return: a JavaGateway connected to the newly launched server.
    :raises Exception: if launching the gateway is disabled, see
                       ``is_launch_gateway_disabled()``.
    :raises RuntimeError: if the server process exits without writing the
                          connection-info file.
    """
    if is_launch_gateway_disabled():
        raise Exception("It's launching the PythonGatewayServer during Python UDF execution "
                        "which is unexpected. It usually happens when the job codes are "
                        "in the top level of the Python script file and are not enclosed in a "
                        "`if __name__ == '__main__'` statement.")
    args = ['-c', 'org.apache.flink.client.python.PythonGatewayServer']
    submit_args = os.environ.get("SUBMIT_ARGS", "local")
    args += shlex.split(submit_args)

    # Create a temporary directory where the gateway server should write the connection information.
    conn_info_dir = tempfile.mkdtemp()
    try:
        # Reserve a unique file name, then remove the file again: its later
        # (re)appearance is the signal that the server has written its port.
        fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
        os.close(fd)
        os.unlink(conn_info_file)

        # Return value is unused; NOTE(review): presumably called for its effect on
        # the environment, which is copied right below - confirm.
        _find_flink_home()

        env = dict(os.environ)
        env["_PYFLINK_CONN_INFO_PATH"] = conn_info_file

        p = launch_gateway_server_process(env, args)

        # poll() returns None only while the process is still running. Compare with
        # None explicitly: a truthiness test would treat exit status 0 as "running"
        # and wait forever if the server exits cleanly without writing the file.
        while p.poll() is None and not os.path.isfile(conn_info_file):
            time.sleep(0.1)

        if not os.path.isfile(conn_info_file):
            # Be lenient when collecting diagnostics so that the real failure is
            # reported even if stderr is not piped or is not valid UTF-8.
            raw_stderr = p.stderr.read() if p.stderr is not None else b""
            stderr_info = raw_stderr.decode('utf-8', errors='replace')
            raise RuntimeError(
                "Java gateway process exited before sending its port number.\nStderr:\n"
                + stderr_info
            )

        with open(conn_info_file, "rb") as info:
            gateway_port = struct.unpack("!I", info.read(4))[0]
    finally:
        shutil.rmtree(conn_info_dir)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auto_convert=True),
        callback_server_parameters=CallbackServerParameters(
            port=0, daemonize=True, daemonize_connections=True))
    return gateway
def import_flink_view(gateway):
    """
    Import the Java classes and packages used by PyFlink into the gateway's JVM view.

    Each name is imported exactly once, in the order listed below.

    :param gateway: gateway connected to the Java gateway server; the imports are
                    registered on ``gateway.jvm``.
    """
    # Import the classes used by PyFlink
    flink_imports = (
        "org.apache.flink.table.api.*",
        "org.apache.flink.table.legacy.api.*",
        "org.apache.flink.table.api.config.*",
        "org.apache.flink.table.api.java.*",
        "org.apache.flink.table.api.bridge.java.*",
        "org.apache.flink.table.api.dataview.*",
        "org.apache.flink.table.catalog.*",
        "org.apache.flink.table.descriptors.*",
        "org.apache.flink.table.legacy.descriptors.*",
        "org.apache.flink.table.descriptors.python.*",
        "org.apache.flink.table.expressions.*",
        "org.apache.flink.table.sources.*",
        "org.apache.flink.table.legacy.sources.*",
        "org.apache.flink.table.sinks.*",
        "org.apache.flink.table.legacy.sinks.*",
        "org.apache.flink.table.types.*",
        "org.apache.flink.table.types.logical.*",
        "org.apache.flink.table.legacy.types.logical.*",
        "org.apache.flink.table.util.python.*",
        "org.apache.flink.api.common.python.*",
        "org.apache.flink.api.common.typeinfo.TypeInformation",
        "org.apache.flink.api.common.typeinfo.Types",
        "org.apache.flink.api.java.ExecutionEnvironment",
        "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
        "org.apache.flink.python.util.PythonDependencyUtils",
        "org.apache.flink.python.PythonOptions",
        "org.apache.flink.client.python.PythonGatewayServer",
        "org.apache.flink.streaming.api.functions.python.*",
        "org.apache.flink.streaming.api.operators.python.process.*",
        "org.apache.flink.streaming.api.operators.python.embedded.*",
        "org.apache.flink.streaming.api.typeinfo.python.*",
    )
    jvm_view = gateway.jvm
    for java_name in flink_imports:
        java_import(jvm_view, java_name)
class PythonFunctionFactory(object):
    """
    Factory that creates PythonFunction objects for Java jobs.
    """

    def getPythonFunction(self, moduleName, objectName):
        """
        Look up a Python UDF by module and attribute name and return its Java form.

        :param moduleName: name of the Python module to import.
        :param objectName: name of the attribute inside that module.
        :return: the result of calling ``_java_user_defined_function()`` on the
                 attribute found.
        """
        module = importlib.import_module(moduleName)
        wrapper = getattr(module, objectName)
        return wrapper._java_user_defined_function()

    class Java:
        # Java interface this Python object is declared to implement.
        implements = ["org.apache.flink.client.python.PythonFunctionFactory"]
class Watchdog(object):
    """
    Object provided to the Java side so that it can check whether its parent
    process is alive.
    """

    def ping(self):
        """
        Block for 10 seconds and then answer the liveness check.

        :return: always True.
        """
        pause_seconds = 10
        time.sleep(pause_seconds)
        return True

    class Java:
        # Java interface this Python object is declared to implement.
        implements = ["org.apache.flink.client.python.PythonGatewayServer$Watchdog"]