Skip to content

Commit 75dc6b4

Browse files
committed
Add a 'tags' endpoint to support modifying tags after a flow/run/step/task/artifact
has been produced
1 parent af4f332 commit 75dc6b4

File tree

4 files changed

+210
-48
lines changed

4 files changed

+210
-48
lines changed

services/data/postgres_async_db.py

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREAT
9595

9696
break # Break the retry loop
9797
except Exception as e:
98-
self.logger.exception("Exception occured")
98+
self.logger.exception("Exception occurred")
9999
if retries - i <= 1:
100100
raise e
101101
time.sleep(connection_retry_wait_time_seconds)
@@ -466,6 +466,10 @@ class AsyncFlowTablePostgres(AsyncPostgresTable):
466466
)
467467
_row_type = FlowRow
468468

469+
@staticmethod
470+
def get_filter_dict(flow_id: str):
471+
return {"flow_id": flow_id}
472+
469473
async def add_flow(self, flow: FlowRow):
470474
dict = {
471475
"flow_id": flow.flow_id,
@@ -476,7 +480,7 @@ async def add_flow(self, flow: FlowRow):
476480
return await self.create_record(dict)
477481

478482
async def get_flow(self, flow_id: str):
479-
filter_dict = {"flow_id": flow_id}
483+
filter_dict = self.get_filter_dict(flow_id)
480484
return await self.get_records(filter_dict=filter_dict, fetch_single=True)
481485

482486
async def get_all_flows(self):
@@ -523,9 +527,13 @@ async def add_run(self, run: RunRow):
523527
}
524528
return await self.create_record(dict)
525529

526-
async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
530+
@staticmethod
531+
def get_filter_dict(flow_id: str, run_id: str):
527532
key, value = translate_run_key(run_id)
528-
filter_dict = {"flow_id": flow_id, key: str(value)}
533+
return {"flow_id": flow_id, key: str(value)}
534+
535+
async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
536+
filter_dict = self.get_filter_dict(flow_id, run_id)
529537
return await self.get_records(filter_dict=filter_dict,
530538
fetch_single=True, expanded=expanded)
531539

@@ -534,9 +542,7 @@ async def get_all_runs(self, flow_id: str):
534542
return await self.get_records(filter_dict=filter_dict)
535543

536544
async def update_heartbeat(self, flow_id: str, run_id: str):
537-
run_key, run_value = translate_run_key(run_id)
538-
filter_dict = {"flow_id": flow_id,
539-
run_key: str(run_value)}
545+
filter_dict = self.get_filter_dict(flow_id, run_id)
540546
set_dict = {
541547
"last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
542548
}
@@ -589,19 +595,23 @@ async def add_step(self, step_object: StepRow):
589595
}
590596
return await self.create_record(dict)
591597

598+
@staticmethod
599+
def get_filter_dict(flow_id: str, run_id: str, step_name: str):
600+
run_id_key, run_id_value = translate_run_key(run_id)
601+
return {
602+
"flow_id": flow_id,
603+
run_id_key: run_id_value,
604+
"step_name": step_name,
605+
}
606+
592607
async def get_steps(self, flow_id: str, run_id: str):
593608
run_id_key, run_id_value = translate_run_key(run_id)
594609
filter_dict = {"flow_id": flow_id,
595610
run_id_key: run_id_value}
596611
return await self.get_records(filter_dict=filter_dict)
597612

598613
async def get_step(self, flow_id: str, run_id: str, step_name: str):
599-
run_id_key, run_id_value = translate_run_key(run_id)
600-
filter_dict = {
601-
"flow_id": flow_id,
602-
run_id_key: run_id_value,
603-
"step_name": step_name,
604-
}
614+
filter_dict = self.get_filter_dict(flow_id, run_id, step_name)
605615
return await self.get_records(filter_dict=filter_dict, fetch_single=True)
606616

607617

@@ -651,36 +661,35 @@ async def add_task(self, task: TaskRow):
651661
}
652662
return await self.create_record(dict)
653663

654-
async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
664+
@staticmethod
665+
def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
655666
run_id_key, run_id_value = translate_run_key(run_id)
656-
filter_dict = {
667+
task_id_key, task_id_value = translate_task_key(task_id)
668+
return {
657669
"flow_id": flow_id,
658670
run_id_key: run_id_value,
659671
"step_name": step_name,
672+
task_id_key: task_id_value,
660673
}
661-
return await self.get_records(filter_dict=filter_dict)
662674

663-
async def get_task(self, flow_id: str, run_id: str, step_name: str,
664-
task_id: str, expanded: bool = False):
675+
async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
665676
run_id_key, run_id_value = translate_run_key(run_id)
666-
task_id_key, task_id_value = translate_task_key(task_id)
667677
filter_dict = {
668678
"flow_id": flow_id,
669679
run_id_key: run_id_value,
670680
"step_name": step_name,
671-
task_id_key: task_id_value,
672681
}
682+
return await self.get_records(filter_dict=filter_dict)
683+
684+
async def get_task(self, flow_id: str, run_id: str, step_name: str,
685+
task_id: str, expanded: bool = False):
686+
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
673687
return await self.get_records(filter_dict=filter_dict,
674688
fetch_single=True, expanded=expanded)
675689

676690
async def update_heartbeat(self, flow_id: str, run_id: str, step_name: str,
677691
task_id: str):
678-
run_key, run_value = translate_run_key(run_id)
679-
task_key, task_value = translate_task_key(task_id)
680-
filter_dict = {"flow_id": flow_id,
681-
run_key: str(run_value),
682-
"step_name": step_name,
683-
task_key: str(task_value)}
692+
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
684693
set_dict = {
685694
"last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
686695
}
@@ -757,23 +766,27 @@ async def add_metadata(
757766
}
758767
return await self.create_record(dict)
759768

769+
@staticmethod
770+
def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
771+
run_id_key, run_id_value = translate_run_key(run_id)
772+
task_id_key, task_id_value = translate_task_key(task_id)
773+
return {
774+
"flow_id": flow_id,
775+
run_id_key: run_id_value,
776+
"step_name": step_name,
777+
task_id_key: task_id_value,
778+
}
779+
760780
async def get_metadata_in_runs(self, flow_id: str, run_id: str):
761781
run_id_key, run_id_value = translate_run_key(run_id)
762782
filter_dict = {"flow_id": flow_id,
763783
run_id_key: run_id_value}
764784
return await self.get_records(filter_dict=filter_dict)
765785

766786
async def get_metadata(
767-
self, flow_id: str, run_id: int, step_name: str, task_id: str
787+
self, flow_id: str, run_id: str, step_name: str, task_id: str
768788
):
769-
run_id_key, run_id_value = translate_run_key(run_id)
770-
task_id_key, task_id_value = translate_task_key(task_id)
771-
filter_dict = {
772-
"flow_id": flow_id,
773-
run_id_key: run_id_value,
774-
"step_name": step_name,
775-
task_id_key: task_id_value,
776-
}
789+
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
777790
return await self.get_records(filter_dict=filter_dict)
778791

779792

@@ -856,7 +869,20 @@ async def add_artifact(
856869
}
857870
return await self.create_record(dict)
858871

859-
async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
872+
@staticmethod
873+
def get_filter_dict(
874+
flow_id: str, run_id: str, step_name: str, task_id: str, name: str):
875+
run_id_key, run_id_value = translate_run_key(run_id)
876+
task_id_key, task_id_value = translate_task_key(task_id)
877+
return {
878+
"flow_id": flow_id,
879+
run_id_key: run_id_value,
880+
"step_name": step_name,
881+
task_id_key: task_id_value,
882+
'"name"': name,
883+
}
884+
885+
async def get_artifacts_in_runs(self, flow_id: str, run_id: str):
860886
run_id_key, run_id_value = translate_run_key(run_id)
861887
filter_dict = {
862888
"flow_id": flow_id,
@@ -865,7 +891,7 @@ async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
865891
return await self.get_records(filter_dict=filter_dict,
866892
ordering=self.ordering)
867893

868-
async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str):
894+
async def get_artifact_in_steps(self, flow_id: str, run_id: str, step_name: str):
869895
run_id_key, run_id_value = translate_run_key(run_id)
870896
filter_dict = {
871897
"flow_id": flow_id,
@@ -876,7 +902,7 @@ async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str)
876902
ordering=self.ordering)
877903

878904
async def get_artifact_in_task(
879-
self, flow_id: str, run_id: int, step_name: str, task_id: int
905+
self, flow_id: str, run_id: str, step_name: str, task_id: str
880906
):
881907
run_id_key, run_id_value = translate_run_key(run_id)
882908
task_id_key, task_id_value = translate_task_key(task_id)
@@ -890,16 +916,8 @@ async def get_artifact_in_task(
890916
ordering=self.ordering)
891917

892918
async def get_artifact(
893-
self, flow_id: str, run_id: int, step_name: str, task_id: int, name: str
919+
self, flow_id: str, run_id: str, step_name: str, task_id: str, name: str
894920
):
895-
run_id_key, run_id_value = translate_run_key(run_id)
896-
task_id_key, task_id_value = translate_task_key(task_id)
897-
filter_dict = {
898-
"flow_id": flow_id,
899-
run_id_key: run_id_value,
900-
"step_name": step_name,
901-
task_id_key: task_id_value,
902-
'"name"': name,
903-
}
921+
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id, name)
904922
return await self.get_records(filter_dict=filter_dict,
905923
fetch_single=True, ordering=self.ordering)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from services.data import TaskRow
2+
from services.data.db_utils import DBResponse
3+
from services.data.postgres_async_db import AsyncPostgresDB
4+
from services.metadata_service.api.utils import format_response, \
5+
handle_exceptions
6+
import json
7+
from aiohttp import web
8+
9+
import asyncio
10+
11+
12+
class TagApi(object):
13+
lock = asyncio.Lock()
14+
15+
def __init__(self, app):
16+
app.router.add_route(
17+
"POST",
18+
"/tags",
19+
self.update_tags,
20+
)
21+
self._db = AsyncPostgresDB.get_instance()
22+
23+
def _get_table(self, type):
24+
if type == 'flow':
25+
return self._db.flow_table_postgres
26+
elif type == 'run':
27+
return self._db.run_table_postgres
28+
elif type == 'step':
29+
return self._db.step_table_postgres
30+
elif type == 'task':
31+
return self._db.task_table_postgres
32+
elif type == 'artifact':
33+
return self._db.artifact_table_postgres
34+
else:
35+
raise ValueError("cannot find table for type %s" % type)
36+
37+
@handle_exceptions
38+
@format_response
39+
async def update_tags(self, request):
40+
"""
41+
---
42+
description: Update user-tags for objects
43+
tags:
44+
- Tags
45+
parameters:
46+
- name: "body"
47+
in: "body"
48+
description: "body"
49+
required: true
50+
schema:
51+
type: array
52+
items:
53+
type: object
54+
required:
55+
- object_type
56+
- id
57+
- tag
58+
- operation
59+
properties:
60+
object_type:
61+
type: string
62+
enum: [flow, run, step, task, artifact]
63+
id:
64+
type: string
65+
operation:
66+
type: string
67+
enum: [add, remove]
68+
tag:
69+
type: string
70+
user:
71+
type: string
72+
produces:
73+
- application/json
74+
responses:
75+
"202":
76+
description: successful operation. Return newly registered task
77+
"404":
78+
description: not found
79+
"500":
80+
description: internal server error
81+
"""
82+
body = await request.json()
83+
results = []
84+
for o in body:
85+
try:
86+
table = self._get_table(o['object_type'])
87+
pathspec = o['id'].split('/')
88+
# Do some basic verification
89+
if o['object_type'] == 'flow' and len(pathspec) != 1:
90+
raise ValueError("invalid flow specification: %s" % o['id'])
91+
elif o['object_type'] == 'run' and len(pathspec) != 2:
92+
raise ValueError("invalid run specification: %s" % o['id'])
93+
elif o['object_type'] == 'step' and len(pathspec) != 3:
94+
raise ValueError("invalid step specification: %s" % o['id'])
95+
elif o['object_type'] == 'task' and len(pathspec) != 4:
96+
raise ValueError("invalid task specification: %s" % o['id'])
97+
elif o['object_type'] == 'artifact' and len(pathspec) != 5:
98+
raise ValueError("invalid artifact specification: %s" % o['id'])
99+
obj_filter = table.get_filter_dict(*pathspec)
100+
except ValueError as e:
101+
return web.Response(status=400, body=json.dumps(
102+
{"message": "invalid input: %s" % str(e)}))
103+
104+
# Now we can get the object
105+
obj = await table.get_records(
106+
filter_dict=obj_filter, fetch_single=True, expanded=True)
107+
if obj.response_code != 200:
108+
return web.Response(status=obj.response_code, body=json.dumps(
109+
{"message": "could not get object %s: %s" % (o['id'], obj.body)}))
110+
111+
# At this point do some checks and update the tags
112+
obj = obj.body
113+
modified = False
114+
if o['operation'] == 'add':
115+
# This is the only error we fail hard on; adding a tag that is
116+
# in system tag
117+
if o['tag'] in obj['system_tags']:
118+
return web.Response(status=405, body=json.dumps(
119+
{"message": "tag %s is already a system tag and can't be added to %s"
120+
% (o['tag'], o['id'])}))
121+
if o['tag'] not in obj['tags']:
122+
modified = True
123+
obj['tags'].append(o['tag'])
124+
elif o['operation'] == 'remove':
125+
if o['tag'] in obj['tags']:
126+
modified = True
127+
obj['tags'].remove(o['tag'])
128+
else:
129+
return web.Response(status=400, body=json.dumps(
130+
{"message": "invalid tag operation %s" % o['operation']}))
131+
if modified:
132+
# We save the value back
133+
result = await table.update_row(filter_dict=obj_filter, update_dict={
134+
'tags': "'%s'" % json.dumps(obj['tags'])})
135+
if result.response_code != 200:
136+
return web.Response(status=result.response_code, body=json.dumps(
137+
{"message": "error updating tags for %s: %s" % (o['id'], result.body)}))
138+
results.append(obj)
139+
140+
return DBResponse(response_code=200, body=results)

services/metadata_service/api/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def format_response(func):
2020
@wraps(func)
2121
async def wrapper(*args, **kwargs):
2222
db_response = await func(*args, **kwargs)
23+
if isinstance(db_response, web.Response):
24+
return db_response
2325
return web.Response(status=db_response.response_code,
2426
body=json.dumps(db_response.body),
2527
headers=MultiDict(

services/metadata_service/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .api.task import TaskApi
1212
from .api.artifact import ArtificatsApi
1313
from .api.admin import AuthApi
14+
from .api.tag import TagApi
1415

1516
from .api.metadata import MetadataApi
1617
from services.data.postgres_async_db import AsyncPostgresDB
@@ -30,6 +31,7 @@ def app(loop=None, db_conf: DBConfiguration = None):
3031
MetadataApi(app)
3132
ArtificatsApi(app)
3233
AuthApi(app)
34+
TagApi(app)
3335
setup_swagger(app)
3436
return app
3537

0 commit comments

Comments
 (0)