diff --git a/.gitignore b/.gitignore index 6d83cfc..0a96db7 100755 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,9 @@ MANIFEST .idea pydgraph.iml +# VS Code +.vscode + # Python Virtual Environments venv .venv diff --git a/README.md b/README.md index 2fbaaf6..97e489a 100644 --- a/README.md +++ b/README.md @@ -295,6 +295,54 @@ request = txn.create_request(mutations=[mutation], commit_now=True) txn.do_request(request) ``` +### Committing a Transaction + +A transaction can be committed using the `Txn#commit()` method. If your transaction +consist solely of `Txn#query` or `Txn#queryWithVars` calls, and no calls to +`Txn#mutate`, then calling `Txn#commit()` is not necessary. + +An error is raised if another transaction(s) modify the same data concurrently that was +modified in the current transaction. It is up to the user to retry transactions +when they fail. + +```python +txn = client.txn() +try: + # ... + # Perform any number of queries and mutations + # ... + # and finally... + txn.commit() +except pydgraph.AbortedError: + # Retry or handle exception. +finally: + # Clean up. Calling this after txn.commit() is a no-op + # and hence safe. + txn.discard() +``` + +#### Using Transaction with Context Manager + +The Python context manager will automatically perform the "`commit`" action +after all queries and mutations have been done, and perform "`discard`" action +to clean the transaction. +When something goes wrong in the scope of context manager, "`commit`" will not +be called,and the "`discard`" action will be called to drop any potential changes. + +```python +with client.begin(read_only=False, best_effort=False) as txn: + # Do some queries or mutations here +``` + +or you can directly create a transaction from the `Txn` class. + +```python +with pydgraph.Txn(client, read_only=False, best_effort=False) as txn: + # Do some queries or mutations here +``` + +> `client.begin()` can only be used with "`with-as`" blocks, while `pydgraph.Txn` class can be directly called to instantiate a transaction object. + ### Running a Query You can run a query by calling `Txn#query(string)`. You will need to pass in a @@ -453,6 +501,28 @@ stub1.close() stub2.close() ``` +#### Use context manager to automatically clean resources + +Use function call: + +```python +with pydgraph.client_stub(SERVER_ADDR) as stub1: + with pydgraph.client_stub(SERVER_ADDR) as stub2: + client = pydgraph.DgraphClient(stub1, stub2) +``` + +Use class constructor: + +```python +with pydgraph.DgraphClientStub(SERVER_ADDR) as stub1: + with pydgraph.DgraphClientStub(SERVER_ADDR) as stub2: + client = pydgraph.DgraphClient(stub1, stub2) +``` + +Note: `client` should be used inside the "`with-as`" block. The resources related to +`client` will be automatically released outside the block and `client` is not usable +any more. + ### Setting Metadata Headers Metadata headers such as authentication tokens can be set through the metadata of gRPC methods. diff --git a/pydgraph/client.py b/pydgraph/client.py index c6fe9bb..56eb752 100755 --- a/pydgraph/client.py +++ b/pydgraph/client.py @@ -3,6 +3,7 @@ """Dgraph python client.""" +import contextlib import random from pydgraph import errors, txn, util @@ -151,9 +152,9 @@ def handle_alter_future(future): except Exception as error: DgraphClient._common_except_alter(error) - def txn(self, read_only=False, best_effort=False): + def txn(self, read_only=False, best_effort=False, **commit_kwargs): """Creates a transaction.""" - return txn.Txn(self, read_only=read_only, best_effort=best_effort) + return txn.Txn(self, read_only=read_only, best_effort=best_effort, **commit_kwargs) def any_client(self): """Returns a random gRPC client so that requests are distributed evenly among them.""" @@ -165,3 +166,23 @@ def add_login_metadata(self, metadata): return new_metadata new_metadata.extend(metadata) return new_metadata + + @contextlib.contextmanager + def begin(self, + read_only:bool=False, best_effort:bool=False, + timeout = None, metadata = None, credentials = None): + '''Start a managed transaction. + + Note + ---- + Only use this function in ``with-as`` blocks. + ''' + tx = self.txn(read_only=read_only, best_effort=best_effort) + try: + yield tx + if read_only == False and tx._finished == False: + tx.commit(timeout=timeout, metadata=metadata, credentials=credentials) + except Exception as e: + raise e + finally: + tx.discard() \ No newline at end of file diff --git a/pydgraph/client_stub.py b/pydgraph/client_stub.py index da432db..4137adb 100644 --- a/pydgraph/client_stub.py +++ b/pydgraph/client_stub.py @@ -3,6 +3,7 @@ """Stub for RPC request.""" +import contextlib import grpc from pydgraph.meta import VERSION @@ -29,6 +30,14 @@ def __init__(self, addr="localhost:9080", credentials=None, options=None): self.channel = grpc.secure_channel(addr, credentials, options) self.stub = api_grpc.DgraphStub(self.channel) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + if exc_type is not None: + raise exc_val def login(self, login_req, timeout=None, metadata=None, credentials=None): return self.stub.Login( @@ -118,3 +127,27 @@ def from_cloud(cloud_endpoint, api_key, options=None): options=options, ) return client_stub + +@contextlib.contextmanager +def client_stub(addr='localhost:9080', **kwargs): + """ Create a managed DgraphClientStub instance. + + Parameters + ---------- + addr : str, optional + credentials : ChannelCredentials, optional + options: List[Dict] + An optional list of key-value pairs (``channel_arguments`` + in gRPC Core runtime) to configure the channel. + + Note + ---- + Only use this function in ``with-as`` blocks. + """ + stub = DgraphClientStub(addr=addr, **kwargs) + try: + yield stub + except Exception as e: + raise e + finally: + stub.close() \ No newline at end of file diff --git a/pydgraph/txn.py b/pydgraph/txn.py index c5186ba..aca0912 100644 --- a/pydgraph/txn.py +++ b/pydgraph/txn.py @@ -30,7 +30,8 @@ class Txn(object): after calling commit. """ - def __init__(self, client, read_only=False, best_effort=False): + def __init__(self, client, read_only=False, best_effort=False, + timeout=None, metadata=None, credentials=None): if not read_only and best_effort: raise Exception( "Best effort transactions are only compatible with " @@ -45,6 +46,23 @@ def __init__(self, client, read_only=False, best_effort=False): self._mutated = False self._read_only = read_only self._best_effort = best_effort + self._commit_kwargs = { + "timeout": timeout, + "metadata": metadata, + "credentials": credentials + } + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.discard(**self._commit_kwargs) + raise exc_val + if self._read_only == False and self._finished == False: + self.commit(**self._commit_kwargs) + else: + self.discard(**self._commit_kwargs) def query( self, @@ -201,7 +219,7 @@ def handle_query_future(future): try: response = future.result() except Exception as error: - txn._common_except_mutate(error) + Txn._common_except_mutate(error) return response @@ -212,11 +230,11 @@ def handle_mutate_future(txn, future, commit_now): response = future.result() except Exception as error: try: - txn.discard(timeout=timeout, metadata=metadata, credentials=credentials) + txn.discard(**txn._commit_kwargs) except: # Ignore error - user should see the original error. pass - txn._common_except_mutate(error) + Txn._common_except_mutate(error) if commit_now: txn._finished = True diff --git a/tests/test_acct_upsert.py b/tests/test_acct_upsert.py index 49fca46..1212a5c 100644 --- a/tests/test_acct_upsert.py +++ b/tests/test_acct_upsert.py @@ -15,7 +15,7 @@ import pydgraph -from . import helper +from tests import helper CONCURRENCY = 5 FIRSTS = ["Paul", "Eric", "Jack", "John", "Martin"] diff --git a/tests/test_async.py b/tests/test_async.py index 645a76a..027c1e4 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -10,7 +10,7 @@ import pydgraph -from . import helper +from tests import helper class TestAsync(helper.ClientIntegrationTestCase): diff --git a/tests/test_client_stub.py b/tests/test_client_stub.py index 94abe5f..d3d1e1f 100644 --- a/tests/test_client_stub.py +++ b/tests/test_client_stub.py @@ -10,10 +10,8 @@ import unittest import pydgraph - from . import helper - class TestDgraphClientStub(helper.ClientIntegrationTestCase): """Tests client stub.""" @@ -70,10 +68,31 @@ def test_from_cloud(self): raise (e) +class TestDgraphClientStubContextManager(helper.ClientIntegrationTestCase): + def setUp(self): + pass + + def test_context_manager(self): + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + def test_context_manager_code_exception(self): + with self.assertRaises(AttributeError): + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + self.check_version(client_stub) # AttributeError: no such method + + def test_context_manager_function_wrapper(self): + with pydgraph.client_stub(addr=self.TEST_SERVER_ADDR) as client_stub: + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + def suite(): """Returns a test suite object.""" suite_obj = unittest.TestSuite() suite_obj.addTest(TestDgraphClientStub()) + suite_obj.addTest(TestDgraphClientStubContextManager()) return suite_obj diff --git a/tests/test_essentials.py b/tests/test_essentials.py index 2ee5f2a..6f06b5c 100644 --- a/tests/test_essentials.py +++ b/tests/test_essentials.py @@ -10,7 +10,7 @@ import logging import unittest -from . import helper +from tests import helper class TestEssentials(helper.ClientIntegrationTestCase): diff --git a/tests/test_queries.py b/tests/test_queries.py index d8684c9..80781ce 100755 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -13,7 +13,7 @@ import pydgraph -from . import helper +from tests import helper class TestQueries(helper.ClientIntegrationTestCase): diff --git a/tests/test_txn.py b/tests/test_txn.py index bff027b..237bc3e 100644 --- a/tests/test_txn.py +++ b/tests/test_txn.py @@ -11,7 +11,7 @@ import pydgraph -from . import helper +from tests import helper class TestTxn(helper.ClientIntegrationTestCase): @@ -610,10 +610,40 @@ def test_sp_star2(self): self.assertEqual([{"uid": uid1}], json.loads(resp.json).get("me")) +class TestContextManager(helper.ClientIntegrationTestCase): + def setUp(self): + self.stub = pydgraph.DgraphClientStub(self.TEST_SERVER_ADDR) + self.client = pydgraph.DgraphClient(self.stub) + self.q = ''' + { + company(func: type(x.Company), first: 10){ + expand(_all_) + } + } + ''' + def tearDown(self) -> None: + self.stub.close() + + def test_context_manager_by_contextlib(self): + with self.client.begin(read_only=True, best_effort=True) as tx: + response = tx.query(self.q) + self.assertIsNotNone(response) + data = json.loads(response.json) + print(data) + + def test_context_manager_by_class(self): + with pydgraph.Txn(self.client, read_only=True, best_effort=True) as tx: + response = tx.query(self.q) + self.assertIsNotNone(response) + data = json.loads(response.json) + print(data) + + def suite(): s = unittest.TestSuite() s.addTest(TestTxn()) s.addTest(TestSPStar()) + s.addTest(TestContextManager()) return s diff --git a/tests/test_upsert_block.py b/tests/test_upsert_block.py index 2629462..3548317 100644 --- a/tests/test_upsert_block.py +++ b/tests/test_upsert_block.py @@ -9,7 +9,7 @@ import logging import unittest -from . import helper +from tests import helper class TestUpsertBlock(helper.ClientIntegrationTestCase): diff --git a/tests/test_util.py b/tests/test_util.py index 8f5aa0e..5ab5644 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -10,7 +10,6 @@ from pydgraph import util - class TestUtil(unittest.TestCase): """Tests util utility functions."""