You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
flink/flink-python/pyflink/fn_execution/state_impl.py

1340 lines
53 KiB
Python

################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import base64
import collections
from abc import ABC, abstractmethod
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker.bundle_processor import SynchronousBagRuntimeState
from apache_beam.transforms import userstate
from enum import Enum
from functools import partial
from io import BytesIO
from typing import List, Tuple, Any, Dict, Collection, cast
from pyflink.datastream import ReduceFunction
from pyflink.datastream.functions import AggregateFunction
from pyflink.datastream.state import StateTtlConfig, MapStateDescriptor, OperatorStateStore
from pyflink.fn_execution.beam.beam_coders import FlinkCoder
from pyflink.fn_execution.coders import FieldCoder, MapCoder, from_type_info
from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor as pb2_StateDescriptor
from pyflink.fn_execution.internal_state import InternalKvState, N, InternalValueState, \
InternalListState, InternalReducingState, InternalMergingState, InternalAggregatingState, \
InternalMapState, InternalReadOnlyBroadcastState, InternalBroadcastState
class LRUCache(object):
"""
A simple LRUCache implementation used to manage the internal runtime state.
An internal runtime state is used to handle the data under a specific key of a "public" state.
So the number of the internal runtime states may keep growing during the streaming task
execution. To prevent the OOM caused by the unlimited growth, we introduce this LRUCache
to evict the inactive internal runtime states.
"""
def __init__(self, max_entries, default_entry):
self._max_entries = max_entries
self._default_entry = default_entry
self._cache = collections.OrderedDict()
self._on_evict = None
def get(self, key):
value = self._cache.pop(key, self._default_entry)
if value != self._default_entry:
# update the last access time
self._cache[key] = value
return value
def put(self, key, value):
self._cache[key] = value
while len(self._cache) > self._max_entries:
name, value = self._cache.popitem(last=False)
if self._on_evict is not None:
self._on_evict(name, value)
def evict(self, key):
value = self._cache.pop(key, self._default_entry)
if self._on_evict is not None:
self._on_evict(key, value)
def evict_all(self):
if self._on_evict is not None:
for item in self._cache.items():
self._on_evict(*item)
self._cache.clear()
def set_on_evict(self, func):
self._on_evict = func
def __len__(self):
return len(self._cache)
def __iter__(self):
return iter(self._cache.values())
def __contains__(self, key):
return key in self._cache
class SynchronousKvRuntimeState(InternalKvState, ABC):
"""
Base Class for partitioned State implementation.
"""
def __init__(self, name: str, remote_state_backend: 'RemoteKeyedStateBackend'):
self.name = name
self._remote_state_backend = remote_state_backend
self._internal_state = None
self.namespace = None
self._ttl_config = None
self._cache_type = SynchronousKvRuntimeState.CacheType.ENABLE_READ_WRITE_CACHE
def set_current_namespace(self, namespace: N) -> None:
if namespace == self.namespace:
return
if self.namespace is not None:
self._remote_state_backend.cache_internal_state(
self._remote_state_backend._encoded_current_key, self)
self.namespace = namespace
self._internal_state = None
def enable_time_to_live(self, ttl_config: StateTtlConfig):
self._ttl_config = ttl_config
if ttl_config.get_state_visibility() == StateTtlConfig.StateVisibility.NeverReturnExpired:
self._cache_type = SynchronousKvRuntimeState.CacheType.DISABLE_CACHE
elif ttl_config.get_update_type() == StateTtlConfig.UpdateType.OnReadAndWrite:
self._cache_type = SynchronousKvRuntimeState.CacheType.ENABLE_WRITE_CACHE
if self._cache_type != SynchronousKvRuntimeState.CacheType.ENABLE_READ_WRITE_CACHE:
# disable read cache
self._remote_state_backend._state_handler._state_cache._cache._max_entries = 0
@abstractmethod
def get_internal_state(self):
pass
class CacheType(Enum):
DISABLE_CACHE = 0
ENABLE_WRITE_CACHE = 1
ENABLE_READ_WRITE_CACHE = 2
class SynchronousBagKvRuntimeState(SynchronousKvRuntimeState, ABC):
"""
Base Class for State implementation backed by a :class:`SynchronousBagRuntimeState`.
"""
def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
super(SynchronousBagKvRuntimeState, self).__init__(name, remote_state_backend)
self._value_coder = value_coder
def get_internal_state(self):
if self._internal_state is None:
self._internal_state = self._remote_state_backend._get_internal_bag_state(
self.name, self.namespace, self._value_coder, self._ttl_config)
return self._internal_state
def _maybe_clear_write_cache(self):
if self._cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE or \
self._remote_state_backend._state_cache_size <= 0:
self._internal_state.commit()
self._internal_state._cleared = False
self._internal_state._added_elements = []
class SynchronousValueRuntimeState(SynchronousBagKvRuntimeState, InternalValueState):
"""
The runtime ValueState implementation backed by a :class:`SynchronousBagRuntimeState`.
"""
def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
super(SynchronousValueRuntimeState, self).__init__(name, value_coder, remote_state_backend)
def value(self):
for i in self.get_internal_state().read():
return i
return None
def update(self, value) -> None:
self.get_internal_state()
self._internal_state.clear()
self._internal_state.add(value)
self._maybe_clear_write_cache()
def clear(self) -> None:
self.get_internal_state().clear()
class SynchronousMergingRuntimeState(SynchronousBagKvRuntimeState, InternalMergingState, ABC):
"""
Base Class for MergingState implementation.
"""
def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
super(SynchronousMergingRuntimeState, self).__init__(
name, value_coder, remote_state_backend)
def merge_namespaces(self, target: N, sources: Collection[N]) -> None:
self._remote_state_backend.merge_namespaces(self, target, sources, self._ttl_config)
class SynchronousListRuntimeState(SynchronousMergingRuntimeState, InternalListState):
"""
The runtime ListState implementation backed by a :class:`SynchronousBagRuntimeState`.
"""
def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
super(SynchronousListRuntimeState, self).__init__(name, value_coder, remote_state_backend)
def add(self, v):
self.get_internal_state().add(v)
self._maybe_clear_write_cache()
def get(self):
return self.get_internal_state().read()
def add_all(self, values):
self.get_internal_state()._added_elements.extend(values)
self._maybe_clear_write_cache()
def update(self, values):
self.clear()
self.add_all(values)
self._maybe_clear_write_cache()
def clear(self):
self.get_internal_state().clear()
class SynchronousReducingRuntimeState(SynchronousMergingRuntimeState, InternalReducingState):
"""
The runtime ReducingState implementation backed by a :class:`SynchronousBagRuntimeState`.
"""
def __init__(self,
name: str,
value_coder,
remote_state_backend: 'RemoteKeyedStateBackend',
reduce_function: ReduceFunction):
super(SynchronousReducingRuntimeState, self).__init__(
name, value_coder, remote_state_backend)
self._reduce_function = reduce_function
def add(self, v):
current_value = self.get()
if current_value is None:
self._internal_state.add(v)
else:
self._internal_state.clear()
self._internal_state.add(self._reduce_function.reduce(current_value, v))
self._maybe_clear_write_cache()
def get(self):
for i in self.get_internal_state().read():
return i
return None
def clear(self):
self.get_internal_state().clear()
class SynchronousAggregatingRuntimeState(SynchronousMergingRuntimeState, InternalAggregatingState):
"""
The runtime AggregatingState implementation backed by a :class:`SynchronousBagRuntimeState`.
"""
def __init__(self,
name: str,
value_coder,
remote_state_backend: 'RemoteKeyedStateBackend',
agg_function: AggregateFunction):
super(SynchronousAggregatingRuntimeState, self).__init__(
name, value_coder, remote_state_backend)
self._agg_function = agg_function
def add(self, v):
if v is None:
self.clear()
return
accumulator = self._get_accumulator()
if accumulator is None:
accumulator = self._agg_function.create_accumulator()
accumulator = self._agg_function.add(v, accumulator)
self._internal_state.clear()
self._internal_state.add(accumulator)
self._maybe_clear_write_cache()
def get(self):
accumulator = self._get_accumulator()
if accumulator is None:
return None
else:
return self._agg_function.get_result(accumulator)
def _get_accumulator(self):
for i in self.get_internal_state().read():
return i
return None
def clear(self):
self.get_internal_state().clear()
class CachedMapState(LRUCache):
def __init__(self, max_entries):
super(CachedMapState, self).__init__(max_entries, None)
self._all_data_cached = False
self._cached_keys = set()
def on_evict(key, value):
if value[0]:
self._cached_keys.remove(key)
self._all_data_cached = False
self.set_on_evict(on_evict)
def set_all_data_cached(self):
self._all_data_cached = True
def is_all_data_cached(self):
return self._all_data_cached
def put(self, key, exists_and_value):
if exists_and_value[0]:
self._cached_keys.add(key)
super(CachedMapState, self).put(key, exists_and_value)
def get_cached_keys(self):
return self._cached_keys
class IterateType(Enum):
ITEMS = 0
KEYS = 1
VALUES = 2
class IteratorToken(Enum):
"""
The token indicates the status of current underlying iterator. It can also be a UUID,
which represents an iterator on the Java side.
"""
NOT_START = 0
FINISHED = 1
def create_cache_iterator(cache_dict, iterate_type, iterated_keys=None):
if iterated_keys is None:
iterated_keys = []
if iterate_type == IterateType.KEYS:
for key, (exists, value) in cache_dict.items():
if not exists or key in iterated_keys:
continue
yield key, key
elif iterate_type == IterateType.VALUES:
for key, (exists, value) in cache_dict.items():
if not exists or key in iterated_keys:
continue
yield key, value
elif iterate_type == IterateType.ITEMS:
for key, (exists, value) in cache_dict.items():
if not exists or key in iterated_keys:
continue
yield key, (key, value)
else:
raise Exception("Unsupported iterate type: %s" % iterate_type)
class CachingMapStateHandler(object):
# GET request flags
GET_FLAG = 0
ITERATE_FLAG = 1
CHECK_EMPTY_FLAG = 2
# GET response flags
EXIST_FLAG = 0
IS_NONE_FLAG = 1
NOT_EXIST_FLAG = 2
IS_EMPTY_FLAG = 3
NOT_EMPTY_FLAG = 4
# APPEND request flags
DELETE = 0
SET_NONE = 1
SET_VALUE = 2
def __init__(self, caching_state_handler, max_cached_map_key_entries):
self._state_cache = caching_state_handler._state_cache
self._underlying = caching_state_handler._underlying
self._context = caching_state_handler._context
self._max_cached_map_key_entries = max_cached_map_key_entries
self._cached_iterator_num = 0
def _get_cache_token(self):
if not self._state_cache.is_cache_enabled():
return None
if self._context.user_state_cache_token:
return self._context.user_state_cache_token
else:
return self._context.bundle_cache_token
def blocking_get(self, state_key, map_key, map_key_encoder, map_value_decoder):
cache_token = self._get_cache_token()
if not cache_token:
# cache disabled / no cache token, request from remote directly
return self._get_raw(state_key, map_key, map_key_encoder, map_value_decoder)
# lookup cache first
cache_state_key = self._convert_to_cache_key(state_key)
cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
if cached_map_state is None:
# request from remote
exists, value = self._get_raw(state_key, map_key, map_key_encoder, map_value_decoder)
cached_map_state = CachedMapState(self._max_cached_map_key_entries)
cached_map_state.put(map_key, (exists, value))
self._state_cache.put((cache_state_key, cache_token), cached_map_state)
return exists, value
else:
cached_value = cached_map_state.get(map_key)
if cached_value is None:
if cached_map_state.is_all_data_cached():
return False, None
# request from remote
exists, value = self._get_raw(
state_key, map_key, map_key_encoder, map_value_decoder)
cached_map_state.put(map_key, (exists, value))
return exists, value
else:
return cached_value
def lazy_iterator(self, state_key, iterate_type, map_key_decoder, map_value_decoder,
iterated_keys):
cache_token = self._get_cache_token()
if cache_token:
# check if the data in the read cache can be used
cache_state_key = self._convert_to_cache_key(state_key)
cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
if cached_map_state and cached_map_state.is_all_data_cached():
return create_cache_iterator(
cached_map_state._cache, iterate_type, iterated_keys)
# request from remote
last_iterator_token = IteratorToken.NOT_START
current_batch, iterator_token = self._iterate_raw(
state_key, iterate_type,
last_iterator_token,
map_key_decoder,
map_value_decoder)
if cache_token and \
iterator_token == IteratorToken.FINISHED and \
iterate_type != IterateType.KEYS and \
self._max_cached_map_key_entries >= len(current_batch):
# Special case: all the data of the map state is contained in current batch,
# and can be stored in the cached map state.
cached_map_state = CachedMapState(self._max_cached_map_key_entries)
cache_state_key = self._convert_to_cache_key(state_key)
for key, value in current_batch.items():
cached_map_state.put(key, (True, value))
cached_map_state.set_all_data_cached()
self._state_cache.put((cache_state_key, cache_token), cached_map_state)
return self._lazy_remote_iterator(
state_key,
iterate_type,
map_key_decoder,
map_value_decoder,
iterated_keys,
iterator_token,
current_batch)
def _lazy_remote_iterator(
self,
state_key,
iterate_type,
map_key_decoder,
map_value_decoder,
iterated_keys,
iterator_token,
current_batch):
if iterate_type == IterateType.KEYS:
while True:
for key in current_batch:
if key in iterated_keys:
continue
yield key, key
if iterator_token == IteratorToken.FINISHED:
break
current_batch, iterator_token = self._iterate_raw(
state_key,
iterate_type,
iterator_token,
map_key_decoder,
map_value_decoder)
elif iterate_type == IterateType.VALUES:
while True:
for key, value in current_batch.items():
if key in iterated_keys:
continue
yield key, value
if iterator_token == IteratorToken.FINISHED:
break
current_batch, iterator_token = self._iterate_raw(
state_key,
iterate_type,
iterator_token,
map_key_decoder,
map_value_decoder)
elif iterate_type == IterateType.ITEMS:
while True:
for key, value in current_batch.items():
if key in iterated_keys:
continue
yield key, (key, value)
if iterator_token == IteratorToken.FINISHED:
break
current_batch, iterator_token = self._iterate_raw(
state_key,
iterate_type,
iterator_token,
map_key_decoder,
map_value_decoder)
else:
raise Exception("Unsupported iterate type: %s" % iterate_type)
def extend(self, state_key, items: List[Tuple[int, Any, Any]],
map_key_encoder, map_value_encoder):
cache_token = self._get_cache_token()
if cache_token:
# Cache lookup
cache_state_key = self._convert_to_cache_key(state_key)
cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
if cached_map_state is None:
cached_map_state = CachedMapState(self._max_cached_map_key_entries)
self._state_cache.put((cache_state_key, cache_token), cached_map_state)
for request_flag, map_key, map_value in items:
if request_flag == self.DELETE:
cached_map_state.put(map_key, (False, None))
elif request_flag == self.SET_NONE:
cached_map_state.put(map_key, (True, None))
elif request_flag == self.SET_VALUE:
cached_map_state.put(map_key, (True, map_value))
else:
raise Exception("Unknown flag: " + str(request_flag))
return self._append_raw(
state_key,
items,
map_key_encoder,
map_value_encoder)
def check_empty(self, state_key):
cache_token = self._get_cache_token()
if cache_token:
# Cache lookup
cache_state_key = self._convert_to_cache_key(state_key)
cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
if cached_map_state is not None:
if cached_map_state.is_all_data_cached() and \
len(cached_map_state.get_cached_keys()) == 0:
return True
elif len(cached_map_state.get_cached_keys()) > 0:
return False
return self._check_empty_raw(state_key)
def clear(self, state_key):
self.clear_read_cache(state_key)
return self._underlying.clear(state_key)
def clear_read_cache(self, state_key):
cache_token = self._get_cache_token()
if cache_token:
cache_key = self._convert_to_cache_key(state_key)
self._state_cache.invalidate((cache_key, cache_token))
def get_cached_iterators_num(self):
return self._cached_iterator_num
def _inc_cached_iterators_num(self):
self._cached_iterator_num += 1
def _dec_cached_iterators_num(self):
self._cached_iterator_num -= 1
def reset_cached_iterators_num(self):
self._cached_iterator_num = 0
def _check_empty_raw(self, state_key):
output_stream = coder_impl.create_OutputStream()
output_stream.write_byte(self.CHECK_EMPTY_FLAG)
continuation_token = output_stream.get()
data, response_token = self._underlying.get_raw(state_key, continuation_token)
if data[0] == self.IS_EMPTY_FLAG:
return True
elif data[0] == self.NOT_EMPTY_FLAG:
return False
else:
raise Exception("Unknown response flag: " + str(data[0]))
def _get_raw(self, state_key, map_key, map_key_encoder, map_value_decoder):
output_stream = coder_impl.create_OutputStream()
output_stream.write_byte(self.GET_FLAG)
map_key_encoder(map_key, output_stream)
continuation_token = output_stream.get()
data, response_token = self._underlying.get_raw(state_key, continuation_token)
input_stream = coder_impl.create_InputStream(data)
result_flag = input_stream.read_byte()
if result_flag == self.EXIST_FLAG:
return True, map_value_decoder(input_stream)
elif result_flag == self.IS_NONE_FLAG:
return True, None
elif result_flag == self.NOT_EXIST_FLAG:
return False, None
else:
raise Exception("Unknown response flag: " + str(result_flag))
def _iterate_raw(self, state_key, iterate_type, iterator_token,
map_key_decoder, map_value_decoder):
output_stream = coder_impl.create_OutputStream()
output_stream.write_byte(self.ITERATE_FLAG)
output_stream.write_byte(iterate_type.value)
if not isinstance(iterator_token, IteratorToken):
# The iterator token represents a Java iterator
output_stream.write_bigendian_int32(len(iterator_token))
output_stream.write(iterator_token)
else:
output_stream.write_bigendian_int32(0)
continuation_token = output_stream.get()
data, response_token = self._underlying.get_raw(state_key, continuation_token)
if len(response_token) != 0:
# The new iterator token is an UUID which represents a cached iterator at Java
# side.
new_iterator_token = response_token
if iterator_token == IteratorToken.NOT_START:
# This is the first request but not the last request of current state.
# It means there is a new iterator has been created and cached at Java side.
self._inc_cached_iterators_num()
else:
new_iterator_token = IteratorToken.FINISHED
if iterator_token != IteratorToken.NOT_START:
# This is not the first request but the last request of current state.
# It means the cached iterator created at Java side has been removed as
# current iteration has finished.
self._dec_cached_iterators_num()
input_stream = coder_impl.create_InputStream(data)
if iterate_type == IterateType.ITEMS or iterate_type == IterateType.VALUES:
# decode both key and value
current_batch = {}
while input_stream.size() > 0:
key = map_key_decoder(input_stream)
is_not_none = input_stream.read_byte()
if is_not_none:
value = map_value_decoder(input_stream)
else:
value = None
current_batch[key] = value
else:
# only decode key
current_batch = []
while input_stream.size() > 0:
key = map_key_decoder(input_stream)
current_batch.append(key)
return current_batch, new_iterator_token
def _append_raw(self, state_key, items, map_key_encoder, map_value_encoder):
output_stream = coder_impl.create_OutputStream()
output_stream.write_bigendian_int32(len(items))
for request_flag, map_key, map_value in items:
output_stream.write_byte(request_flag)
# Not all the coder impls will serialize the length of bytes when we set the "nested"
# param to "True", so we need to encode the length of bytes manually.
tmp_out = coder_impl.create_OutputStream()
map_key_encoder(map_key, tmp_out)
serialized_data = tmp_out.get()
output_stream.write_bigendian_int32(len(serialized_data))
output_stream.write(serialized_data)
if request_flag == self.SET_VALUE:
tmp_out = coder_impl.create_OutputStream()
map_value_encoder(map_value, tmp_out)
serialized_data = tmp_out.get()
output_stream.write_bigendian_int32(len(serialized_data))
output_stream.write(serialized_data)
return self._underlying.append_raw(state_key, output_stream.get())
@staticmethod
def _convert_to_cache_key(state_key):
return state_key.SerializeToString()
class RemovableConcatIterator(collections.abc.Iterator):
def __init__(self, internal_map_state, first, second):
self._first = first
self._second = second
self._first_not_finished = True
self._internal_map_state = internal_map_state
self._mod_count = self._internal_map_state._mod_count
self._last_key = None
def __next__(self):
self._check_modification()
if self._first_not_finished:
try:
self._last_key, element = next(self._first)
return element
except StopIteration:
self._first_not_finished = False
return self.__next__()
else:
self._last_key, element = next(self._second)
return element
def remove(self):
"""
Remove the last element returned by this iterator.
"""
if self._last_key is None:
raise Exception("You need to call the '__next__' method before calling "
"this method.")
self._check_modification()
# Bypass the 'remove' method of the map state to avoid triggering the commit of the write
# cache.
if self._internal_map_state._cleared:
del self._internal_map_state._write_cache[self._last_key]
if len(self._internal_map_state._write_cache) == 0:
self._internal_map_state._is_empty = True
else:
self._internal_map_state._write_cache[self._last_key] = (False, None)
self._mod_count += 1
self._internal_map_state._mod_count += 1
self._last_key = None
def _check_modification(self):
if self._mod_count != self._internal_map_state._mod_count:
raise Exception("Concurrent modification detected. "
"You can not modify the map state when iterating it except using the "
"'remove' method of this iterator.")
class InternalSynchronousMapRuntimeState(object):
def __init__(self,
map_state_handler: CachingMapStateHandler,
state_key,
map_key_coder,
map_value_coder,
max_write_cache_entries):
self._map_state_handler = map_state_handler
self._state_key = state_key
self._map_key_coder = map_key_coder
if isinstance(map_key_coder, FieldCoder):
map_key_coder_impl = FlinkCoder(map_key_coder).get_impl()
else:
map_key_coder_impl = map_key_coder.get_impl()
self._map_key_encoder, self._map_key_decoder = \
self._get_encoder_and_decoder(map_key_coder_impl)
self._map_value_coder = map_value_coder
if isinstance(map_value_coder, FieldCoder):
map_value_coder_impl = FlinkCoder(map_value_coder).get_impl()
else:
map_value_coder_impl = map_value_coder.get_impl()
self._map_value_encoder, self._map_value_decoder = \
self._get_encoder_and_decoder(map_value_coder_impl)
self._write_cache = dict()
self._max_write_cache_entries = max_write_cache_entries
self._is_empty = None
self._cleared = False
self._mod_count = 0
def get(self, map_key):
if self._is_empty:
return None
if map_key in self._write_cache:
exists, value = self._write_cache[map_key]
if exists:
return value
else:
return None
if self._cleared:
return None
exists, value = self._map_state_handler.blocking_get(
self._state_key, map_key, self._map_key_encoder, self._map_value_decoder)
if exists:
return value
else:
return None
def put(self, map_key, map_value):
self._write_cache[map_key] = (True, map_value)
self._is_empty = False
self._mod_count += 1
if len(self._write_cache) >= self._max_write_cache_entries:
self.commit()
def put_all(self, dict_value):
for map_key, map_value in dict_value:
self._write_cache[map_key] = (True, map_value)
self._is_empty = False
self._mod_count += 1
if len(self._write_cache) >= self._max_write_cache_entries:
self.commit()
def remove(self, map_key):
if self._is_empty:
return
if self._cleared:
del self._write_cache[map_key]
if len(self._write_cache) == 0:
self._is_empty = True
else:
self._write_cache[map_key] = (False, None)
self._is_empty = None
self._mod_count += 1
if len(self._write_cache) >= self._max_write_cache_entries:
self.commit()
def contains(self, map_key):
if self._is_empty:
return False
if self.get(map_key) is None:
return False
else:
return True
def is_empty(self):
if self._is_empty is None:
if len(self._write_cache) > 0:
self.commit()
self._is_empty = self._map_state_handler.check_empty(self._state_key)
return self._is_empty
def clear(self):
self._cleared = True
self._is_empty = True
self._mod_count += 1
self._write_cache.clear()
def items(self):
return RemovableConcatIterator(
self,
self.write_cache_iterator(IterateType.ITEMS),
self.remote_data_iterator(IterateType.ITEMS))
def keys(self):
return RemovableConcatIterator(
self,
self.write_cache_iterator(IterateType.KEYS),
self.remote_data_iterator(IterateType.KEYS))
def values(self):
return RemovableConcatIterator(
self,
self.write_cache_iterator(IterateType.VALUES),
self.remote_data_iterator(IterateType.VALUES))
def commit(self):
to_await = None
if self._cleared:
to_await = self._map_state_handler.clear(self._state_key)
if self._write_cache:
append_items = []
for map_key, (exists, value) in self._write_cache.items():
if exists:
if value is not None:
append_items.append(
(CachingMapStateHandler.SET_VALUE, map_key, value))
else:
append_items.append((CachingMapStateHandler.SET_NONE, map_key, None))
else:
append_items.append((CachingMapStateHandler.DELETE, map_key, None))
self._write_cache.clear()
to_await = self._map_state_handler.extend(
self._state_key, append_items, self._map_key_encoder, self._map_value_encoder)
if to_await:
to_await.get()
self._write_cache.clear()
self._cleared = False
self._mod_count += 1
def write_cache_iterator(self, iterate_type):
return create_cache_iterator(self._write_cache, iterate_type)
def remote_data_iterator(self, iterate_type):
if self._cleared or self._is_empty:
return iter([])
else:
return self._map_state_handler.lazy_iterator(
self._state_key,
iterate_type,
self._map_key_decoder,
self._map_value_decoder,
self._write_cache)
@staticmethod
def _get_encoder_and_decoder(coder):
encoder = partial(coder.encode_to_stream, nested=True)
decoder = partial(coder.decode_from_stream, nested=True)
return encoder, decoder
class SynchronousMapRuntimeState(SynchronousKvRuntimeState, InternalMapState):
def __init__(self,
name: str,
map_key_coder,
map_value_coder,
remote_state_backend: 'RemoteKeyedStateBackend'):
super(SynchronousMapRuntimeState, self).__init__(name, remote_state_backend)
self._map_key_coder = map_key_coder
self._map_value_coder = map_value_coder
def get_internal_state(self):
if self._internal_state is None:
self._internal_state = self._remote_state_backend._get_internal_map_state(
self.name,
self.namespace,
self._map_key_coder,
self._map_value_coder,
self._ttl_config,
self._cache_type)
return self._internal_state
def get(self, key):
return self.get_internal_state().get(key)
def put(self, key, value):
self.get_internal_state().put(key, value)
def put_all(self, dict_value):
self.get_internal_state().put_all(dict_value)
def remove(self, key):
self.get_internal_state().remove(key)
def contains(self, key):
return self.get_internal_state().contains(key)
def items(self):
return self.get_internal_state().items()
def keys(self):
return self.get_internal_state().keys()
def values(self):
return self.get_internal_state().values()
def is_empty(self):
return self.get_internal_state().is_empty()
def clear(self):
self.get_internal_state().clear()
class RemoteKeyedStateBackend(object):
"""
A keyed state backend provides methods for managing keyed state.
"""
MERGE_NAMESAPCES_MARK = "merge_namespaces"
def __init__(self,
state_handler,
key_coder,
namespace_coder,
state_cache_size,
map_state_read_cache_size,
map_state_write_cache_size):
self._state_handler = state_handler
self._map_state_handler = CachingMapStateHandler(
state_handler, map_state_read_cache_size)
self._key_coder_impl = key_coder.get_impl()
self.namespace_coder = namespace_coder
if namespace_coder:
self._namespace_coder_impl = namespace_coder.get_impl()
else:
self._namespace_coder_impl = None
self._state_cache_size = state_cache_size
self._map_state_write_cache_size = map_state_write_cache_size
self._all_states = {} # type: Dict[str, SynchronousKvRuntimeState]
self._internal_state_cache = LRUCache(self._state_cache_size, None)
self._internal_state_cache.set_on_evict(
lambda key, value: self.commit_internal_state(value))
self._current_key = None
self._encoded_current_key = None
self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
transform_id="clear_iterators",
side_input_id="clear_iterators",
key=self._encoded_current_key))
def get_list_state(self, name, element_coder, ttl_config=None):
return self._wrap_internal_bag_state(
name,
element_coder,
SynchronousListRuntimeState,
SynchronousListRuntimeState,
ttl_config)
def get_value_state(self, name, value_coder, ttl_config=None):
return self._wrap_internal_bag_state(
name,
value_coder,
SynchronousValueRuntimeState,
SynchronousValueRuntimeState,
ttl_config)
def get_map_state(self, name, map_key_coder, map_value_coder, ttl_config=None):
if name in self._all_states:
self.validate_map_state(name, map_key_coder, map_value_coder)
return self._all_states[name]
map_state = SynchronousMapRuntimeState(name, map_key_coder, map_value_coder, self)
if ttl_config is not None:
map_state.enable_time_to_live(ttl_config)
self._all_states[name] = map_state
return map_state
def get_reducing_state(self, name, coder, reduce_function, ttl_config=None):
return self._wrap_internal_bag_state(
name,
coder,
SynchronousReducingRuntimeState,
partial(SynchronousReducingRuntimeState, reduce_function=reduce_function),
ttl_config)
def get_aggregating_state(self, name, coder, agg_function, ttl_config=None):
return self._wrap_internal_bag_state(
name,
coder,
SynchronousAggregatingRuntimeState,
partial(SynchronousAggregatingRuntimeState, agg_function=agg_function),
ttl_config)
def validate_state(self, name, coder, expected_type):
if name in self._all_states:
state = self._all_states[name]
if not isinstance(state, expected_type):
raise Exception("The state name '%s' is already in use and not a %s."
% (name, expected_type))
if state._value_coder != coder:
raise Exception("State name corrupted: %s" % name)
def validate_map_state(self, name, map_key_coder, map_value_coder):
if name in self._all_states:
state = self._all_states[name]
if not isinstance(state, SynchronousMapRuntimeState):
raise Exception("The state name '%s' is already in use and not a map state."
% name)
if state._map_key_coder != map_key_coder or \
state._map_value_coder != map_value_coder:
raise Exception("State name corrupted: %s" % name)
def _wrap_internal_bag_state(
self, name, element_coder, wrapper_type, wrap_method, ttl_config):
if name in self._all_states:
self.validate_state(name, element_coder, wrapper_type)
return self._all_states[name]
wrapped_state = wrap_method(name, element_coder, self)
if ttl_config is not None:
wrapped_state.enable_time_to_live(ttl_config)
self._all_states[name] = wrapped_state
return wrapped_state
def _get_internal_bag_state(self, name, namespace, element_coder, ttl_config):
encoded_namespace = self._encode_namespace(namespace)
cached_state = self._internal_state_cache.get(
(name, self._encoded_current_key, encoded_namespace))
if cached_state is not None:
return cached_state
# The created internal state would not be put into the internal state cache
# at once. The internal state cache is only updated when the current key changes.
# The reason is that the state cache size may be smaller that the count of activated
# state (i.e. the state with current key).
if isinstance(element_coder, FieldCoder):
element_coder = FlinkCoder(element_coder)
state_spec = userstate.BagStateSpec(name, element_coder)
internal_state = self._create_bag_state(state_spec, encoded_namespace, ttl_config)
return internal_state
def _get_internal_map_state(
self, name, namespace, map_key_coder, map_value_coder, ttl_config, cache_type):
encoded_namespace = self._encode_namespace(namespace)
cached_state = self._internal_state_cache.get(
(name, self._encoded_current_key, encoded_namespace))
if cached_state is not None:
return cached_state
internal_map_state = self._create_internal_map_state(
name, encoded_namespace, map_key_coder, map_value_coder, ttl_config, cache_type)
return internal_map_state
def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace, ttl_config) \
-> userstate.AccumulatingRuntimeState:
if isinstance(state_spec, userstate.BagStateSpec):
bag_state = SynchronousBagRuntimeState(
self._state_handler,
state_key=self.get_bag_state_key(
state_spec.name, self._encoded_current_key, encoded_namespace, ttl_config),
value_coder=state_spec.coder)
return bag_state
else:
raise NotImplementedError(state_spec)
def _create_internal_map_state(
self, name, encoded_namespace, map_key_coder, map_value_coder, ttl_config, cache_type):
# Currently the `beam_fn_api.proto` does not support MapState, so we use the
# the `MultimapSideInput` message to mark the state as a MapState for now.
state_proto = pb2_StateDescriptor()
state_proto.state_name = name
if ttl_config is not None:
state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
state_key = beam_fn_api_pb2.StateKey(
multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
transform_id="",
window=encoded_namespace,
side_input_id=base64.b64encode(state_proto.SerializeToString()),
key=self._encoded_current_key))
if cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE:
write_cache_size = 0
else:
write_cache_size = self._map_state_write_cache_size
return InternalSynchronousMapRuntimeState(
self._map_state_handler,
state_key,
map_key_coder,
map_value_coder,
write_cache_size)
def _encode_namespace(self, namespace):
if namespace is not None:
encoded_namespace = self._namespace_coder_impl.encode(namespace)
else:
encoded_namespace = b''
return encoded_namespace
def cache_internal_state(self, encoded_key, internal_kv_state: SynchronousKvRuntimeState):
encoded_old_namespace = self._encode_namespace(internal_kv_state.namespace)
self._internal_state_cache.put(
(internal_kv_state.name, encoded_key, encoded_old_namespace),
internal_kv_state.get_internal_state())
def set_current_key(self, key):
if key == self._current_key:
return
encoded_old_key = self._encoded_current_key
for state_name, state_obj in self._all_states.items():
if self._state_cache_size > 0:
# cache old internal state
self.cache_internal_state(encoded_old_key, state_obj)
state_obj.namespace = None
state_obj._internal_state = None
self._current_key = key
self._encoded_current_key = self._key_coder_impl.encode(self._current_key)
def get_current_key(self):
return self._current_key
def commit(self):
for internal_state in self._internal_state_cache:
self.commit_internal_state(internal_state)
for name, state in self._all_states.items():
if (name, self._encoded_current_key, self._encode_namespace(state.namespace)) \
not in self._internal_state_cache:
self.commit_internal_state(state._internal_state)
def clear_cached_iterators(self):
if self._map_state_handler.get_cached_iterators_num() > 0:
self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key
self._map_state_handler.clear(self._clear_iterator_mark)
def merge_namespaces(self, state: SynchronousMergingRuntimeState, target, sources, ttl_config):
for source in sources:
state.set_current_namespace(source)
self.commit_internal_state(state.get_internal_state())
state.set_current_namespace(target)
self.commit_internal_state(state.get_internal_state())
encoded_target_namespace = self._encode_namespace(target)
encoded_namespaces = [self._encode_namespace(source) for source in sources]
self.clear_state_cache(state, [encoded_target_namespace] + encoded_namespaces)
state_key = self.get_bag_state_key(
state.name, self._encoded_current_key, encoded_target_namespace, ttl_config)
state_key.bag_user_state.transform_id = self.MERGE_NAMESAPCES_MARK
encoded_namespaces_writer = BytesIO()
encoded_namespaces_writer.write(len(sources).to_bytes(4, 'big'))
for encoded_namespace in encoded_namespaces:
encoded_namespaces_writer.write(encoded_namespace)
sources_bytes = encoded_namespaces_writer.getvalue()
to_await = self._map_state_handler._underlying.append_raw(state_key, sources_bytes)
if to_await:
to_await.get()
def clear_state_cache(self, state: SynchronousMergingRuntimeState, encoded_namespaces):
name = state.name
for encoded_namespace in encoded_namespaces:
if (name, self._encoded_current_key, encoded_namespace) in self._internal_state_cache:
# commit and clear the write cache
self._internal_state_cache.evict(
(name, self._encoded_current_key, encoded_namespace))
# currently all the SynchronousMergingRuntimeState is based on bag state
state_key = self.get_bag_state_key(
name, self._encoded_current_key, encoded_namespace, None)
# clear the read cache, the read cache is shared between map state handler and bag
# state handler. So we can use the map state handler instead.
self._map_state_handler.clear_read_cache(state_key)
def get_bag_state_key(self, name, encoded_key, encoded_namespace, ttl_config):
from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor
state_proto = StateDescriptor()
state_proto.state_name = name
if ttl_config is not None:
state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
return beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
transform_id="",
window=encoded_namespace,
user_state_id=base64.b64encode(state_proto.SerializeToString()),
key=encoded_key))
@staticmethod
def commit_internal_state(internal_state):
if internal_state is not None:
internal_state.commit()
# reset the status of the internal state to reuse the object cross bundle
if isinstance(internal_state, SynchronousBagRuntimeState):
internal_state._cleared = False
internal_state._added_elements = []
class SynchronousReadOnlyBroadcastRuntimeState(InternalReadOnlyBroadcastState):
def __init__(self, name: str, internal_map_state: "InternalSynchronousMapRuntimeState"):
self._name = name
self._internal_map_state = internal_map_state
def get(self, key):
return self._internal_map_state.get(key)
def contains(self, key) -> bool:
return self._internal_map_state.contains(key)
def items(self):
return self._internal_map_state.items()
def keys(self):
return self._internal_map_state.keys()
def values(self):
return self._internal_map_state.values()
def is_empty(self):
return self._internal_map_state.is_empty()
def clear(self):
return self._internal_map_state.clear()
class SynchronousBroadcastRuntimeState(
SynchronousReadOnlyBroadcastRuntimeState, InternalBroadcastState
):
def __init__(self, name: str, internal_map_state: "InternalSynchronousMapRuntimeState"):
super(SynchronousBroadcastRuntimeState, self).__init__(name, internal_map_state)
def put(self, key, value):
self._internal_map_state.put(key, value)
def put_all(self, dict_value):
self._internal_map_state.put_all(dict_value)
def remove(self, key):
self._internal_map_state.remove(key)
def commit(self):
self._internal_map_state.commit()
def to_read_only_broadcast_state(self) -> "SynchronousReadOnlyBroadcastRuntimeState":
return SynchronousReadOnlyBroadcastRuntimeState(self._name, self._internal_map_state)
class OperatorStateBackend(OperatorStateStore, ABC):
@abstractmethod
def commit(self):
pass
class RemoteOperatorStateBackend(OperatorStateBackend):
def __init__(
self, state_handler, state_cache_size, map_state_read_cache_size, map_state_write_cache_size
):
self._state_handler = state_handler
self._state_cache_size = state_cache_size
# NOTE: if user stores a state into a class member, that state actually won't be actually
# evicted from memory (because its counter > 0)
self._state_cache = LRUCache(state_cache_size, None)
self._state_cache.set_on_evict(lambda _, state: state.commit())
self._map_state_read_cache_size = map_state_read_cache_size
self._map_state_write_cache_size = map_state_write_cache_size
self._map_state_handler = CachingMapStateHandler(state_handler, map_state_read_cache_size)
def get_broadcast_state(
self, state_descriptor: MapStateDescriptor
) -> 'SynchronousBroadcastRuntimeState':
state_name = state_descriptor.name
map_coder = cast(MapCoder, from_type_info(state_descriptor.type_info)) # type: MapCoder
key_coder = map_coder._key_coder
value_coder = map_coder._value_coder
if state_name in self._state_cache:
self._validate_broadcast_state(state_name, key_coder, value_coder)
return self._state_cache.get(state_name)
state_proto = pb2_StateDescriptor()
state_proto.state_name = state_name
# Currently, MultimapKeysSideInput is used for BroadcastState
state_key = beam_fn_api_pb2.StateKey(
multimap_keys_side_input=beam_fn_api_pb2.StateKey.MultimapKeysSideInput(
transform_id="",
window=bytes(),
side_input_id=base64.b64encode(state_proto.SerializeToString()),
)
)
internal_map_state = InternalSynchronousMapRuntimeState(
self._map_state_handler,
state_key,
key_coder,
value_coder,
self._map_state_write_cache_size,
)
broadcast_state = SynchronousBroadcastRuntimeState(
state_descriptor.name, internal_map_state
)
self._state_cache.put(state_name, broadcast_state)
return broadcast_state
def commit(self):
for state in self._state_cache:
cast(SynchronousBroadcastRuntimeState, state).commit()
def _validate_broadcast_state(self, name, key_coder, value_coder):
if name in self._state_cache:
state = cast(SynchronousBroadcastRuntimeState, self._state_cache.get(name))
if (
key_coder != state._internal_map_state._map_key_coder
or value_coder != state._internal_map_state._map_value_coder
):
raise Exception("State name corrupted: %s" % name)