Merge pull request #29695: Enable keys values multimap protocol based on runner capabilities.
robertwb authored Dec 11, 2023
2 parents f934230 + 5c879cc commit bdaec7a
Showing 13 changed files with 63 additions and 30 deletions.
@@ -1675,6 +1675,11 @@ message StandardRunnerProtocols {
// https://s.apache.org/beam-fn-api-control-data-embedding
CONTROL_RESPONSE_ELEMENTS_EMBEDDING = 6
[(beam_urn) = "beam:protocol:control_response_elements_embedding:v1"];

// Indicates that this runner can handle the multimap_keys_values_side_input
// style read of a multimap side input.
MULTIMAP_KEYS_VALUES_SIDE_INPUT = 7
[(beam_urn) = "beam:protocol:multimap_keys_values_side_input:v1"];
}
}
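
A runner advertises this protocol by including the URN above in its declared capabilities, and SDK harnesses gate the new read path on its presence. A minimal sketch of that check, using the `common_urns` accessor added later in this diff (the capability set here is illustrative; in the harness it comes from the runner via the provision API):

```python
from apache_beam.portability import common_urns

BULK_READ_URN = common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn

def supports_bulk_multimap_read(runner_capabilities):
    # A runner opts in by listing the protocol URN among its capabilities.
    return BULK_READ_URN in runner_capabilities

print(supports_bulk_multimap_read(frozenset([BULK_READ_URN])))  # True
```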

@@ -33,6 +33,7 @@
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
@@ -168,6 +169,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT> runner =
new FnApiDoFnRunner<>(
context.getPipelineOptions(),
context.getRunnerCapabilities(),
context.getShortIdMap(),
context.getBeamFnStateClient(),
context.getPTransformId(),
@@ -336,6 +338,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,

FnApiDoFnRunner(
PipelineOptions pipelineOptions,
Set<String> runnerCapabilities,
ShortIdMap shortIds,
BeamFnStateClient beamFnStateClient,
String pTransformId,
@@ -740,6 +743,7 @@ private ByteString encodeProgress(double value) throws IOException {
this.stateAccessor =
new FnApiStateAccessor(
pipelineOptions,
runnerCapabilities,
pTransformId,
processBundleInstructionId,
cacheTokens,
@@ -26,14 +26,17 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.construction.BeamUrns;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VoidCoder;
@@ -74,6 +77,7 @@
})
public class FnApiStateAccessor<K> implements SideInputReader, StateBinder {
private final PipelineOptions pipelineOptions;
private final Set<String> runnerCapabilities;
private final Map<StateKey, Object> stateKeyObjectCache;
private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
private final BeamFnStateClient beamFnStateClient;
@@ -91,6 +95,7 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder {

public FnApiStateAccessor(
PipelineOptions pipelineOptions,
Set<String> runnerCapabilities,
String ptransformId,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
@@ -103,6 +108,7 @@ public FnApiStateAccessor(
Supplier<K> currentKeySupplier,
Supplier<BoundedWindow> currentWindowSupplier) {
this.pipelineOptions = pipelineOptions;
this.runnerCapabilities = runnerCapabilities;
this.stateKeyObjectCache = Maps.newHashMap();
this.sideInputSpecMap = sideInputSpecMap;
this.beamFnStateClient = beamFnStateClient;
@@ -238,7 +244,11 @@ public ResultT get() {
processBundleInstructionId.get(),
key,
((KvCoder) sideInputSpec.getCoder()).getKeyCoder(),
((KvCoder) sideInputSpec.getCoder()).getValueCoder()));
((KvCoder) sideInputSpec.getCoder()).getValueCoder(),
runnerCapabilities.contains(
BeamUrns.getUrn(
RunnerApi.StandardRunnerProtocols.Enum
.MULTIMAP_KEYS_VALUES_SIDE_INPUT))));
default:
throw new IllegalStateException(
String.format(
@@ -56,17 +56,6 @@ public class MultimapSideInput<K, V> implements MultimapView<K, V> {
private volatile Function<ByteString, Iterable<V>> bulkReadResult;
private final boolean useBulkRead;

public MultimapSideInput(
Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
StateKey stateKey,
Coder<K> keyCoder,
Coder<V> valueCoder) {
// TODO(robertwb): Plumb the value of useBulkRead from runner capabilities.
this(cache, beamFnStateClient, instructionId, stateKey, keyCoder, valueCoder, false);
}

public MultimapSideInput(
Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
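
The deleted convenience constructor hard-coded `useBulkRead` to false (note the resolved TODO); callers now pass the flag explicitly based on runner capabilities. A rough sketch of the two read strategies the flag selects between — the `state.*` helpers here are hypothetical stand-ins for Beam Fn API state requests, not real methods:

```python
def read_multimap(state, use_bulk_read):
    # Hypothetical helpers: get_keys_and_values, get_keys, get_values.
    if use_bulk_read:
        # One keys-values request: keys arrive together with their values,
        # so the whole map can be cached and later per-key lookups avoid
        # further round trips to the runner.
        return dict(state.get_keys_and_values())
    # Fallback: one request to enumerate keys, then one request per key.
    return {key: state.get_values(key) for key in state.get_keys()}
```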
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/portability/common_urns.py
@@ -34,6 +34,7 @@
StandardPTransforms = beam_runner_api_pb2_urns.StandardPTransforms
StandardRequirements = beam_runner_api_pb2_urns.StandardRequirements
StandardResourceHints = beam_runner_api_pb2_urns.StandardResourceHints
StandardRunnerProtocols = beam_runner_api_pb2_urns.StandardRunnerProtocols
StandardSideInputTypes = beam_runner_api_pb2_urns.StandardSideInputTypes
StandardUserStateTypes = beam_runner_api_pb2_urns.StandardUserStateTypes
ExpansionMethods = external_transforms_pb2_urns.ExpansionMethods
@@ -73,6 +74,7 @@
monitoring_info_labels = MonitoringInfo.MonitoringInfoLabels

protocols = StandardProtocols.Enum
runner_protocols = StandardRunnerProtocols.Enum
requirements = StandardRequirements.Enum

displayData = StandardDisplayData.DisplayData
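
With this alias in place, the new URN is reachable from Python as a generated constant; for example (URN string as declared in the proto change above):

```python
from apache_beam.portability import common_urns

# The generated accessor resolves to the URN string from the proto.
assert (
    common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn ==
    'beam:protocol:multimap_keys_values_side_input:v1')
```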
@@ -87,6 +87,10 @@
# Time-based flush is enabled in the fn_api_runner by default.
DATA_BUFFER_TIME_LIMIT_MS = 1000

FNAPI_RUNNER_CAPABILITIES = frozenset([
common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn,
])

_LOGGER = logging.getLogger(__name__)

T = TypeVar('T')
@@ -363,6 +367,7 @@ def __init__(self,
self.data_conn = self.data_plane_handler
state_cache = StateCache(STATE_CACHE_SIZE_MB * MB_TO_BYTES)
self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
FNAPI_RUNNER_CAPABILITIES,
SingletonStateHandlerFactory(
sdk_worker.GlobalCachingStateHandler(state_cache, state)),
data_plane.InMemoryDataChannelFactory(
@@ -433,6 +438,7 @@ def GetProvisionInfo(self, request, context=None):
info.control_endpoint.CopyFrom(worker.control_api_service_descriptor())
else:
info = self._base_info
info.runner_capabilities[:] = FNAPI_RUNNER_CAPABILITIES
return beam_provision_api_pb2.GetProvisionInfoResponse(info=info)
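
Populating `runner_capabilities` in the provision info is how out-of-process workers learn what the runner supports. A hedged sketch of that runner-side step, with the gRPC servicer wiring omitted:

```python
from apache_beam.portability.api import beam_provision_api_pb2

def provision_info_with_capabilities(base_info, capabilities):
    # Copy the base info and attach the capability URNs; workers read
    # this list and enable optional protocols such as the bulk multimap
    # side-input read.
    info = beam_provision_api_pb2.ProvisionInfo()
    info.CopyFrom(base_info)
    info.runner_capabilities[:] = sorted(capabilities)
    return info
```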


@@ -663,7 +669,8 @@ def start_worker(self):
self.control_address,
state_cache_size=self._state_cache_size,
data_buffer_time_limit_ms=self._data_buffer_time_limit_ms,
worker_id=self.worker_id)
worker_id=self.worker_id,
runner_capabilities=FNAPI_RUNNER_CAPABILITIES)
self.worker_thread = threading.Thread(
name='run_worker', target=self.worker.run)
self.worker_thread.daemon = True
17 changes: 14 additions & 3 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -920,6 +920,7 @@ class BundleProcessor(object):
""" A class for processing bundles of elements. """

def __init__(self,
runner_capabilities, # type: FrozenSet[str]
process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
state_handler, # type: sdk_worker.CachingStateHandler
data_channel_factory, # type: data_plane.DataChannelFactory
@@ -930,11 +931,14 @@ def __init__(self,
"""Initialize a bundle processor.
Args:
runner_capabilities (``FrozenSet[str]``): The set of capabilities of the
runner with which we will be interacting
process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
a description of the stage that this ``BundleProcessor`` is to execute.
state_handler (CachingStateHandler).
data_channel_factory (``data_plane.DataChannelFactory``).
"""
self.runner_capabilities = runner_capabilities
self.process_bundle_descriptor = process_bundle_descriptor
self.state_handler = state_handler
self.data_channel_factory = data_channel_factory
@@ -976,12 +980,14 @@ def create_execution_tree(
):
# type: (...) -> collections.OrderedDict[str, operations.DoOperation]
transform_factory = BeamTransformFactory(
self.runner_capabilities,
descriptor,
self.data_channel_factory,
self.counter_factory,
self.state_sampler,
self.state_handler,
self.data_sampler)
self.data_sampler,
)

self.timers_info = transform_factory.extract_timers_info()

@@ -1267,13 +1273,15 @@ class ExecutionContext:
class BeamTransformFactory(object):
"""Factory for turning transform_protos into executable operations."""
def __init__(self,
runner_capabilities, # type: FrozenSet[str]
descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
data_channel_factory, # type: data_plane.DataChannelFactory
counter_factory, # type: counters.CounterFactory
state_sampler, # type: statesampler.StateSampler
state_handler, # type: sdk_worker.CachingStateHandler
data_sampler, # type: Optional[data_sampler.DataSampler]
):
self.runner_capabilities = runner_capabilities
self.descriptor = descriptor
self.data_channel_factory = data_channel_factory
self.counter_factory = counter_factory
@@ -1699,8 +1707,11 @@ def _create_pardo_operation(
transform_id,
tag,
si,
input_tags_to_coders[tag]) for tag,
si in tagged_side_inputs
input_tags_to_coders[tag],
use_bulk_read=(
common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn
in factory.runner_capabilities))
for (tag, si) in tagged_side_inputs
]
else:
side_input_maps = []
@@ -267,7 +267,7 @@ def test_disabled_by_default(self):
"""
descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
descriptor.pcollections['a'].unique_name = 'a'
_ = BundleProcessor(descriptor, None, None)
_ = BundleProcessor(set(), descriptor, None, None)
self.assertEqual(len(descriptor.transforms), 0)

def test_can_sample(self):
@@ -301,7 +301,7 @@ def test_can_sample(self):
# Create and process a fake bundle. The instruction id doesn't matter
# here.
processor = BundleProcessor(
descriptor, None, None, data_sampler=data_sampler)
set(), descriptor, None, None, data_sampler=data_sampler)
processor.process_bundle('instruction_id')

samples = data_sampler.wait_for_samples([PCOLLECTION_ID])
@@ -377,7 +377,7 @@ def test_can_sample_exceptions(self):
# Create and process a fake bundle. The instruction id doesn't matter
# here.
processor = BundleProcessor(
descriptor, None, None, data_sampler=data_sampler)
set(), descriptor, None, None, data_sampler=data_sampler)

with self.assertRaisesRegex(RuntimeError, 'expected exception'):
processor.process_bundle('instruction_id')
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -286,7 +286,7 @@ def test_extracts_transform_id_during_exceptions(self):

# Create and process a fake bundle. The instruction id doesn't matter
# here.
processor = BundleProcessor(descriptor, None, None)
processor = BundleProcessor(set(), descriptor, None, None)

with self.assertRaisesRegex(RuntimeError, 'expected exception'):
processor.process_bundle('instruction_id')
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -174,6 +174,7 @@ def __init__(
# Unrecoverable SDK harness initialization error (if any)
# that should be reported to the runner when processing the first bundle.
deferred_exception=None, # type: Optional[Exception]
runner_capabilities=frozenset(), # type: FrozenSet[str]
):
# type: (...) -> None
self._alive = True
@@ -202,6 +203,7 @@ def __init__(
self._state_cache, credentials)
self._profiler_factory = profiler_factory
self.data_sampler = data_sampler
self.runner_capabilities = runner_capabilities

def default_factory(id):
# type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
@@ -212,6 +214,7 @@ def default_factory(id):
self._fns = KeyedDefaultDict(default_factory)
# BundleProcessor cache across all workers.
self._bundle_processor_cache = BundleProcessorCache(
self.runner_capabilities,
state_handler_factory=self._state_handler_factory,
data_channel_factory=self._data_channel_factory,
fns=self._fns,
@@ -419,12 +422,14 @@ class BundleProcessorCache(object):

def __init__(
self,
runner_capabilities, # type: FrozenSet[str]
state_handler_factory, # type: StateHandlerFactory
data_channel_factory, # type: data_plane.DataChannelFactory
fns, # type: MutableMapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
data_sampler=None, # type: Optional[data_sampler.DataSampler]
):
# type: (...) -> None
self.runner_capabilities = runner_capabilities
self.fns = fns
self.state_handler_factory = state_handler_factory
self.data_channel_factory = data_channel_factory
Expand Down Expand Up @@ -485,6 +490,7 @@ def get(self, instruction_id, bundle_descriptor_id):

# Make sure we instantiate the processor while not holding the lock.
processor = bundle_processor.BundleProcessor(
self.runner_capabilities,
self.fns[bundle_descriptor_id],
self.state_handler_factory.create_state_handler(
self.fns[bundle_descriptor_id].state_api_service_descriptor),
10 changes: 5 additions & 5 deletions sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -103,10 +103,9 @@ def create_harness(environment, dry_run=False):
pickle_library = sdk_pipeline_options.view_as(SetupOptions).pickle_library
pickler.set_library(pickle_library)

if 'SEMI_PERSISTENT_DIRECTORY' in environment:
semi_persistent_directory = environment['SEMI_PERSISTENT_DIRECTORY']
else:
semi_persistent_directory = None
semi_persistent_directory = environment.get('SEMI_PERSISTENT_DIRECTORY', None)
runner_capabilities = frozenset(
environment.get('RUNNER_CAPABILITIES', '').split())

_LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
_worker_id = environment.get('WORKER_ID', None)
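
On the worker side, capabilities arrive as a single whitespace-separated string in the `RUNNER_CAPABILITIES` environment variable, which the `.split()` above relies on. For example, under that assumed contract:

```python
environment = {
    'RUNNER_CAPABILITIES': 'beam:protocol:multimap_keys_values_side_input:v1'
}
runner_capabilities = frozenset(
    environment.get('RUNNER_CAPABILITIES', '').split())
assert ('beam:protocol:multimap_keys_values_side_input:v1'
        in runner_capabilities)
```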
@@ -167,7 +166,8 @@ def create_harness(environment, dry_run=False):
sdk_pipeline_options.view_as(ProfilingOptions)),
enable_heap_dump=enable_heap_dump,
data_sampler=data_sampler,
deferred_exception=deferred_exception)
deferred_exception=deferred_exception,
runner_capabilities=runner_capabilities)
return fn_log_handler, sdk_harness, sdk_pipeline_options


8 changes: 4 additions & 4 deletions sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -126,7 +126,7 @@ def test_fn_registration(self):

def test_inactive_bundle_processor_returns_empty_progress_response(self):
bundle_processor = mock.MagicMock()
bundle_processor_cache = BundleProcessorCache(None, None, {})
bundle_processor_cache = BundleProcessorCache(None, None, None, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
split_request = beam_fn_api_pb2.InstructionRequest(
@@ -153,7 +153,7 @@ def test_inactive_bundle_processor_returns_empty_progress_response(self):

def test_failed_bundle_processor_returns_failed_progress_response(self):
bundle_processor = mock.MagicMock()
bundle_processor_cache = BundleProcessorCache(None, None, {})
bundle_processor_cache = BundleProcessorCache(None, None, None, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)

@@ -172,7 +172,7 @@ def test_failed_bundle_processor_returns_failed_progress_response(self):

def test_inactive_bundle_processor_returns_empty_split_response(self):
bundle_processor = mock.MagicMock()
bundle_processor_cache = BundleProcessorCache(None, None, {})
bundle_processor_cache = BundleProcessorCache(None, None, None, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
split_request = beam_fn_api_pb2.InstructionRequest(
@@ -258,7 +258,7 @@ def test_harness_monitoring_infos_and_metadata(self):

def test_failed_bundle_processor_returns_failed_split_response(self):
bundle_processor = mock.MagicMock()
bundle_processor_cache = BundleProcessorCache(None, None, {})
bundle_processor_cache = BundleProcessorCache(None, None, None, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)

1 change: 0 additions & 1 deletion sdks/python/gen_protos.py
@@ -527,7 +527,6 @@ def generate_proto_files(force=False):
generate_init_files_lite(PYTHON_OUTPUT_PATH)
for proto_package in proto_packages:
generate_urn_files(proto_package, PYTHON_OUTPUT_PATH)

generate_init_files_full(PYTHON_OUTPUT_PATH)


