diff --git a/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
new file mode 100644
index 000000000000..11ec02fbd54f
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A simple example demonstrating usage of the EnvoyRateLimiter with Vertex AI.
+"""
+
+import argparse
+import logging
+
+import apache_beam as beam
+from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+
+
+def run(argv=None):
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--project',
+      dest='project',
+      help='The Google Cloud project ID for Vertex AI.')
+  parser.add_argument(
+      '--location',
+      dest='location',
+      help='The Google Cloud location (e.g. us-central1) for Vertex AI.')
+  parser.add_argument(
+      '--endpoint_id',
+      dest='endpoint_id',
+      help='The ID of the Vertex AI endpoint.')
+  parser.add_argument(
+      '--rls_address',
+      dest='rls_address',
+      help='The address of the Envoy Rate Limit Service (e.g. localhost:8081).')
+
+  known_args, pipeline_args = parser.parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = True
+
+  # Initialize the EnvoyRateLimiter.
+  rate_limiter = EnvoyRateLimiter(
+      service_address=known_args.rls_address,
+      domain="mongo_cps",
+      descriptors=[{
+          "database": "users"
+      }],
+      namespace='example_pipeline')
+
+  # Initialize the VertexAIModelHandler with the rate limiter.
+  model_handler = VertexAIModelHandlerJSON(
+      endpoint_id=known_args.endpoint_id,
+      project=known_args.project,
+      location=known_args.location,
+      rate_limiter=rate_limiter)
+
+  # Input features for the model.
+  features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
+              [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]
+
+  with beam.Pipeline(options=pipeline_options) as p:
+    _ = (
+        p
+        | 'CreateInputs' >> beam.Create(features)
+        | 'RunInference' >> RunInference(model_handler)
+        | 'PrintPredictions' >> beam.Map(logging.info))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py b/sdks/python/apache_beam/io/components/rate_limiter.py
index 3de39ddd935b..5c3b36e8ab0a 100644
--- a/sdks/python/apache_beam/io/components/rate_limiter.py
+++ b/sdks/python/apache_beam/io/components/rate_limiter.py
@@ -62,7 +62,12 @@ def __init__(self, namespace: str = ""):
 
   @abc.abstractmethod
   def throttle(self, **kwargs) -> bool:
-    """Check if request should be throttled.
+    """Applies rate limiting to the request.
+
+    This method checks if the request is permitted by the rate limiting policy.
+    Depending on the implementation and configuration, it may block (sleep)
+    until the request is allowed, or return False if the configured retry
+    limit is exceeded.
 
     Args:
       **kwargs: Keyword arguments specific to the RateLimiter implementation.
@@ -78,8 +83,12 @@ def throttle(self, **kwargs) -> bool:
 
 
 class EnvoyRateLimiter(RateLimiter):
-  """
-  Rate limiter implementation that uses an external Envoy Rate Limit Service.
+  """Rate limiter implementation that uses an external Envoy Rate Limit Service.
+
+  This limiter connects to a gRPC Envoy Rate Limit Service (RLS) to determine
+  whether a request should be allowed. It supports defining a domain and a
+  list of descriptors that correspond to the rate limit configuration in the
+  RLS.
   """
   def __init__(
       self,
@@ -89,7 +98,7 @@ def __init__(
       timeout: float = 5.0,
       block_until_allowed: bool = True,
       retries: int = 3,
-      namespace: str = ""):
+      namespace: str = ''):
     """
     Args:
       service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
@@ -140,7 +149,15 @@ def init_connection(self):
     self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
 
   def throttle(self, hits_added: int = 1) -> bool:
-    """Calls the Envoy RLS to check for rate limits.
+    """Calls the Envoy RLS to apply rate limits.
+
+    Sends a rate limit request to the configured Envoy Rate Limit Service.
+    If 'block_until_allowed' is True, this method will sleep and retry
+    if the limit is exceeded, effectively blocking until the request is
+    permitted.
+
+    If 'block_until_allowed' is False, it will return False after the retry
+    limit is exceeded.
 
     Args:
       hits_added: Number of hits to add to the rate limit.
@@ -224,3 +241,16 @@ def throttle(self, hits_added: int = 1) -> bool:
             response.overall_code)
         break
     return throttled
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    if '_lock' in state:
+      del state['_lock']
+    if '_stub' in state:
+      del state['_stub']
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._lock = threading.Lock()
+    self._stub = None
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index d79565ee24da..ada7cb3237d4 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -56,6 +56,7 @@
 
 import apache_beam as beam
 from apache_beam.io.components.adaptive_throttler import ReactiveThrottler
+from apache_beam.io.components.rate_limiter import RateLimiter
 from apache_beam.utils import multi_process_shared
 from apache_beam.utils import retry
 from apache_beam.utils import shared
@@ -102,6 +103,11 @@ def __new__(cls, example, inference, model_id=None):
 PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
 
 
+class RateLimitExceeded(RuntimeError):
+  """Raised when the rate limit is exceeded for a batch of requests."""
+  pass
+
+
 class ModelMetadata(NamedTuple):
   model_id: str
   model_name: str
@@ -349,7 +355,8 @@ def __init__(
       *,
       window_ms: int = 1 * _MILLISECOND_TO_SECOND,
       bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
-      overload_ratio: float = 2):
+      overload_ratio: float = 2,
+      rate_limiter: Optional[RateLimiter] = None):
     """Initializes a ReactiveThrottler class for enabling client-side
     throttling for remote calls to an inference service. Also wraps provided
     calls to the service with retry logic.
@@ -372,6 +379,7 @@ def __init__(
       overload_ratio: the target ratio between requests sent and successful
         requests. This is "K" in the formula in
         https://landing.google.com/sre/book/chapters/handling-overload.html.
+      rate_limiter: A RateLimiter object for setting a global rate limit.
     """
     # Configure ReactiveThrottler for client-side throttling behavior.
     self.throttler = ReactiveThrottler(
@@ -383,6 +391,9 @@ def __init__(
     self.logger = logging.getLogger(namespace)
     self.num_retries = num_retries
     self.retry_filter = retry_filter
+    self._rate_limiter = rate_limiter
+    self._shared_rate_limiter = None
+    self._shared_handle = shared.Shared()
 
   def __init_subclass__(cls):
     if cls.load_model is not RemoteModelHandler.load_model:
@@ -431,6 +442,19 @@ def run_inference(
 
     Returns:
       An Iterable of Predictions.
""" + if self._rate_limiter: + if self._shared_rate_limiter is None: + + def init_limiter(): + return self._rate_limiter + + self._shared_rate_limiter = self._shared_handle.acquire(init_limiter) + + if not self._shared_rate_limiter.throttle(hits_added=len(batch)): + raise RateLimitExceeded( + "Rate Limit Exceeded, " + "Could not process this batch.") + self.throttler.throttle() try: diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 574e71de89ce..e6865a13ef8f 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2071,6 +2071,67 @@ def run_inference(self, responses.append(model.predict(example)) return responses + def test_run_inference_with_rate_limiter(self): + class FakeRateLimiter(base.RateLimiter): + def __init__(self): + super().__init__(namespace='test_namespace') + + def throttle(self, hits_added=1): + self.requests_counter.inc() + return True + + limiter = FakeRateLimiter() + + with TestPipeline() as pipeline: + examples = [1, 5] + + class ConcreteRemoteModelHandler(base.RemoteModelHandler): + def create_client(self): + return FakeModel() + + def request(self, batch, model, inference_args=None): + return [model.predict(example) for example in batch] + + model_handler = ConcreteRemoteModelHandler( + rate_limiter=limiter, namespace='test_namespace') + + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference(model_handler) + + expected = [2, 6] + assert_that(actual, equal_to(expected)) + + result = pipeline.run() + result.wait_until_finish() + + metrics_filter = MetricsFilter().with_name( + 'RatelimitRequestsTotal').with_namespace('test_namespace') + metrics = result.metrics().query(metrics_filter) + self.assertGreaterEqual(metrics['counters'][0].committed, 0) + + def test_run_inference_with_rate_limiter_exceeded(self): + class FakeRateLimiter(base.RateLimiter): + def __init__(self): + super().__init__(namespace='test_namespace') + + def throttle(self, hits_added=1): + return False + + class ConcreteRemoteModelHandler(base.RemoteModelHandler): + def create_client(self): + return FakeModel() + + def request(self, batch, model, inference_args=None): + return [model.predict(example) for example in batch] + + model_handler = ConcreteRemoteModelHandler( + rate_limiter=FakeRateLimiter(), + namespace='test_namespace', + num_retries=0) + + with self.assertRaises(base.RateLimitExceeded): + model_handler.run_inference([1], FakeModel()) + if __name__ == '__main__': unittest.main()