88from time import time
99from botocore .exceptions import ClientError
1010from approvaltests .approvals import verify_as_json
11+ from importlib import reload
1112
1213sys .modules ["trace_forwarder.connection" ] = MagicMock ()
1314sys .modules ["datadog_lambda.wrapper" ] = MagicMock ()
3435 enrich ,
3536 transform ,
3637 split ,
38+ extract_ddtags_from_message ,
3739)
3840from parsing import parse , parse_event_type
3941
@@ -130,12 +132,8 @@ def create_cloudwatch_log_event_from_data(data):
130132
131133
132134class TestLambdaFunctionEndToEnd (unittest .TestCase ):
133- @patch ("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get" )
134- @patch ("base_tags_cache.send_forwarder_internal_metrics" )
135135 @patch ("enhanced_lambda_metrics.LambdaTagsCache.get_cache_from_s3" )
136- def test_datadog_forwarder (
137- self , mock_get_s3_cache , mock_forward_metrics , cw_logs_tags_get
138- ):
136+ def test_datadog_forwarder (self , mock_get_s3_cache ):
139137 mock_get_s3_cache .return_value = (
140138 {
141139 "arn:aws:lambda:sa-east-1:601427279990:function:inferred-spans-python-dev-initsender" : [
@@ -149,15 +147,7 @@ def test_datadog_forwarder(
149147 time (),
150148 )
151149 context = Context ()
152- my_path = os .path .abspath (os .path .dirname (__file__ ))
153- path = os .path .join (my_path , "events/cloudwatch_logs.json" )
154-
155- with open (
156- path ,
157- "r" ,
158- ) as input_file :
159- input_data = input_file .read ()
160-
150+ input_data = self ._get_input_data ()
161151 event = {"awslogs" : {"data" : create_cloudwatch_log_event_from_data (input_data )}}
162152 os .environ ["DD_FETCH_LAMBDA_TAGS" ] = "True"
163153
@@ -170,7 +160,7 @@ def test_datadog_forwarder(
170160
171161 verify_as_json (transformed_events )
172162
173- metrics , logs , trace_payloads = split (transformed_events )
163+ _ , _ , trace_payloads = split (transformed_events )
174164 self .assertEqual (len (trace_payloads ), 1 )
175165
176166 trace_payload = json .loads (trace_payloads [0 ]["message" ])
@@ -204,6 +194,98 @@ def test_datadog_forwarder(
204194
205195 del os .environ ["DD_FETCH_LAMBDA_TAGS" ]
206196
197+ @patch ("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get" )
198+ def test_setting_service_tag_from_log_group_cache (self , cw_logs_tags_get ):
199+ reload (sys .modules ["settings" ])
200+ reload (sys .modules ["parsing" ])
201+ cw_logs_tags_get .return_value = ["service:log_group_service" ]
202+ context = Context ()
203+ input_data = self ._get_input_data ()
204+ event = {"awslogs" : {"data" : create_cloudwatch_log_event_from_data (input_data )}}
205+
206+ normalized_events = parse (event , context )
207+ enriched_events = enrich (normalized_events )
208+ transformed_events = transform (enriched_events )
209+
210+ _ , logs , _ = split (transformed_events )
211+ self .assertEqual (len (logs ), 16 )
212+ for log in logs :
213+ self .assertEqual (log ["service" ], "log_group_service" )
214+
215+ @patch .dict (os .environ , {"DD_TAGS" : "service:dd_tag_service" }, clear = True )
216+ @patch ("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get" )
217+ def test_service_override_from_dd_tags (self , cw_logs_tags_get ):
218+ reload (sys .modules ["settings" ])
219+ reload (sys .modules ["parsing" ])
220+ cw_logs_tags_get .return_value = ["service:log_group_service" ]
221+ context = Context ()
222+ input_data = self ._get_input_data ()
223+ event = {"awslogs" : {"data" : create_cloudwatch_log_event_from_data (input_data )}}
224+
225+ normalized_events = parse (event , context )
226+ enriched_events = enrich (normalized_events )
227+ transformed_events = transform (enriched_events )
228+
229+ _ , logs , _ = split (transformed_events )
230+ self .assertEqual (len (logs ), 16 )
231+ for log in logs :
232+ self .assertEqual (log ["service" ], "dd_tag_service" )
233+
234+ @patch ("lambda_cache.LambdaTagsCache.get" )
235+ @patch ("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get" )
236+ def test_overrding_service_tag_from_lambda_cache (
237+ self , lambda_tags_get , cw_logs_tags_get
238+ ):
239+ lambda_tags_get .return_value = ["service:lambda_service" ]
240+ cw_logs_tags_get .return_value = ["service:log_group_service" ]
241+
242+ context = Context ()
243+ input_data = self ._get_input_data ()
244+ event = {"awslogs" : {"data" : create_cloudwatch_log_event_from_data (input_data )}}
245+
246+ normalized_events = parse (event , context )
247+ enriched_events = enrich (normalized_events )
248+ transformed_events = transform (enriched_events )
249+
250+ _ , logs , _ = split (transformed_events )
251+ self .assertEqual (len (logs ), 16 )
252+ for log in logs :
253+ self .assertEqual (log ["service" ], "lambda_service" )
254+
255+ @patch .dict (os .environ , {"DD_TAGS" : "service:dd_tag_service" }, clear = True )
256+ @patch ("lambda_cache.LambdaTagsCache.get" )
257+ @patch ("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get" )
258+ def test_overrding_service_tag_from_lambda_cache_when_dd_tags_is_set (
259+ self , lambda_tags_get , cw_logs_tags_get
260+ ):
261+ lambda_tags_get .return_value = ["service:lambda_service" ]
262+ cw_logs_tags_get .return_value = ["service:log_group_service" ]
263+
264+ context = Context ()
265+ input_data = self ._get_input_data ()
266+ event = {"awslogs" : {"data" : create_cloudwatch_log_event_from_data (input_data )}}
267+
268+ normalized_events = parse (event , context )
269+ enriched_events = enrich (normalized_events )
270+ transformed_events = transform (enriched_events )
271+
272+ _ , logs , _ = split (transformed_events )
273+ self .assertEqual (len (logs ), 16 )
274+ for log in logs :
275+ self .assertEqual (log ["service" ], "lambda_service" )
276+
277+ def _get_input_data (self ):
278+ my_path = os .path .abspath (os .path .dirname (__file__ ))
279+ path = os .path .join (my_path , "events/cloudwatch_logs.json" )
280+
281+ with open (
282+ path ,
283+ "r" ,
284+ ) as input_file :
285+ input_data = input_file .read ()
286+
287+ return input_data
288+
207289
208290class TestLambdaFunctionExtractTracePayload (unittest .TestCase ):
209291 def test_extract_trace_payload_none_no_trace (self ):
@@ -234,5 +316,105 @@ def test_extract_trace_payload_valid_trace(self):
234316 )
235317
236318
319+ class TestMergeMessageTags (unittest .TestCase ):
320+ message_tags = '{"ddtags":"service:my_application_service,custom_tag_1:value1"}'
321+ custom_tags = "custom_tag_2:value2,service:my_custom_service"
322+
323+ def test_extract_ddtags_from_message_str (self ):
324+ event = {
325+ "message" : self .message_tags ,
326+ "ddtags" : self .custom_tags ,
327+ "service" : "my_service" ,
328+ }
329+
330+ extract_ddtags_from_message (event )
331+
332+ self .assertEqual (
333+ event ["ddtags" ],
334+ "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1" ,
335+ )
336+ self .assertEqual (
337+ event ["service" ],
338+ "my_application_service" ,
339+ )
340+
341+ def test_extract_ddtags_from_message_dict (self ):
342+ loaded_message_tags = json .loads (self .message_tags )
343+ event = {
344+ "message" : loaded_message_tags ,
345+ "ddtags" : self .custom_tags ,
346+ "service" : "my_service" ,
347+ }
348+
349+ extract_ddtags_from_message (event )
350+
351+ self .assertEqual (
352+ event ["ddtags" ],
353+ "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1" ,
354+ )
355+ self .assertEqual (
356+ event ["service" ],
357+ "my_application_service" ,
358+ )
359+
360+ def test_extract_ddtags_from_message_service_tag_setting (self ):
361+ loaded_message_tags = json .loads (self .message_tags )
362+ loaded_message_tags ["ddtags" ] = "," .join (
363+ [
364+ tag
365+ for tag in loaded_message_tags ["ddtags" ].split ("," )
366+ if not tag .startswith ("service:" )
367+ ]
368+ )
369+ event = {
370+ "message" : loaded_message_tags ,
371+ "ddtags" : self .custom_tags ,
372+ "service" : "my_custom_service" ,
373+ }
374+
375+ extract_ddtags_from_message (event )
376+
377+ self .assertEqual (
378+ event ["ddtags" ],
379+ "custom_tag_2:value2,service:my_custom_service,custom_tag_1:value1" ,
380+ )
381+ self .assertEqual (
382+ event ["service" ],
383+ "my_custom_service" ,
384+ )
385+
386+ def test_extract_ddtags_from_message_multiple_service_tag_values (self ):
387+ custom_tags = self .custom_tags + ",service:my_custom_service_2"
388+ event = {"message" : self .message_tags , "ddtags" : custom_tags }
389+
390+ extract_ddtags_from_message (event )
391+
392+ self .assertEqual (
393+ event ["ddtags" ],
394+ "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1" ,
395+ )
396+ self .assertEqual (
397+ event ["service" ],
398+ "my_application_service" ,
399+ )
400+
401+ def test_extract_ddtags_from_message_multiple_values_tag (self ):
402+ loaded_message_tags = json .loads (self .message_tags )
403+ loaded_message_tags ["ddtags" ] += ",custom_tag_3:value4"
404+ custom_tags = self .custom_tags + ",custom_tag_3:value3"
405+ event = {"message" : loaded_message_tags , "ddtags" : custom_tags }
406+
407+ extract_ddtags_from_message (event )
408+
409+ self .assertEqual (
410+ event ["ddtags" ],
411+ "custom_tag_2:value2,custom_tag_3:value3,service:my_application_service,custom_tag_1:value1,custom_tag_3:value4" ,
412+ )
413+ self .assertEqual (
414+ event ["service" ],
415+ "my_application_service" ,
416+ )
417+
418+
237419if __name__ == "__main__" :
238420 unittest .main ()
0 commit comments