@@ -125,6 +125,46 @@ class Answer(pydantic.BaseModel):
125125        )
126126
127127
128+ @pytest .mark .qualitative  
129+ def  test_async_parallel_requests (session ):
130+     async  def  parallel_requests ():
131+         model_opts  =  {ModelOption .STREAM : True }
132+         mot1  =  session .backend .generate_from_context (CBlock ("Say Hello." ), SimpleContext (), model_options = model_opts )
133+         mot2  =  session .backend .generate_from_context (CBlock ("Say Goodbye!" ), SimpleContext (), model_options = model_opts )
134+ 
135+         m1_val  =  None 
136+         m2_val  =  None 
137+         if  not  mot1 .is_computed ():
138+             m1_val  =  await  mot1 .astream ()
139+         if  not  mot2 .is_computed ():
140+             m2_val  =  await  mot2 .astream ()
141+ 
142+         assert  m1_val  is  not   None , "should be a string val after generation" 
143+         assert  m2_val  is  not   None , "should be a string val after generation" 
144+ 
145+         m1_final_val  =  await  mot1 .avalue ()
146+         m2_final_val  =  await  mot2 .avalue ()
147+ 
148+         # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response 
149+         # contains the full response. 
150+         assert  m1_final_val .startswith (m1_val ), "final val should contain the first streamed chunk" 
151+         assert  m2_final_val .startswith (m2_val ), "final val should contain the first streamed chunk" 
152+ 
153+         assert  m1_final_val  ==  mot1 .value 
154+         assert  m2_final_val  ==  mot2 .value 
155+     asyncio .run (parallel_requests ())
156+ 
157+ 
158+ @pytest .mark .qualitative  
159+ def  test_async_avalue (session ):
160+     async  def  avalue ():
161+         mot1  =  session .backend .generate_from_context (CBlock ("Say Hello." ), SimpleContext ())
162+         m1_final_val  =  await  mot1 .avalue ()
163+         assert  m1_final_val  is  not   None 
164+         assert  m1_final_val  ==  mot1 .value 
165+     asyncio .run (avalue ())
166+ 
167+ 
128168if  __name__  ==  "__main__" :
129169    import  pytest 
130170
0 commit comments