@@ -1606,16 +1606,22 @@ pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
16061606/// Represents the full input to [`fn behavior`].
16071607struct BehaviorInput {
16081608 behavior : Ident ,
1609+ generics : syn:: Generics ,
16091610 handlers : Vec < HandlerSpec > ,
16101611}
16111612
16121613impl syn:: parse:: Parse for BehaviorInput {
16131614 fn parse ( input : syn:: parse:: ParseStream ) -> syn:: Result < Self > {
16141615 let behavior: Ident = input. parse ( ) ?;
1616+ let generics: syn:: Generics = input. parse ( ) ?;
16151617 let _: Token ! [ , ] = input. parse ( ) ?;
16161618 let raw_handlers = input. parse_terminated ( HandlerSpec :: parse, Token ! [ , ] ) ?;
16171619 let handlers = raw_handlers. into_iter ( ) . collect ( ) ;
1618- Ok ( BehaviorInput { behavior, handlers } )
1620+ Ok ( BehaviorInput {
1621+ behavior,
1622+ generics,
1623+ handlers,
1624+ } )
16191625 }
16201626}
16211627
@@ -1634,20 +1640,118 @@ impl syn::parse::Parse for BehaviorInput {
16341640/// u64,
16351641/// );
16361642/// ```
1643+ ///
1644+ /// This macro also supports generic behaviors:
1645+ /// ```
1646+ /// hyperactor::behavior!(
1647+ /// TestBehavior<T>,
1648+ /// Message<T> { castable = true },
1649+ /// u64,
1650+ /// );
1651+ /// ```
16371652#[ proc_macro]
16381653pub fn behavior ( input : TokenStream ) -> TokenStream {
1639- let BehaviorInput { behavior, handlers } = parse_macro_input ! ( input as BehaviorInput ) ;
1654+ let BehaviorInput {
1655+ behavior,
1656+ generics,
1657+ handlers,
1658+ } = parse_macro_input ! ( input as BehaviorInput ) ;
16401659 let tys = HandlerSpec :: add_indexed ( handlers) ;
16411660
1661+ // Add bounds to generics for Named, Serialize, Deserialize
1662+ let mut bounded_generics = generics. clone ( ) ;
1663+ for param in bounded_generics. type_params_mut ( ) {
1664+ param. bounds . push ( syn:: parse_quote!( hyperactor:: Named ) ) ;
1665+ param. bounds . push ( syn:: parse_quote!( serde:: Serialize ) ) ;
1666+ param. bounds . push ( syn:: parse_quote!( std:: marker:: Send ) ) ;
1667+ param. bounds . push ( syn:: parse_quote!( std:: marker:: Sync ) ) ;
1668+ param. bounds . push ( syn:: parse_quote!( std:: fmt:: Debug ) ) ;
1669+ // Note: lifetime parameters are not *actually* hygienic.
1670+ // https://github.com/rust-lang/rust/issues/54727
1671+ let lifetime =
1672+ syn:: Lifetime :: new ( "'hyperactor_behavior_de" , proc_macro2:: Span :: mixed_site ( ) ) ;
1673+ param
1674+ . bounds
1675+ . push ( syn:: parse_quote!( for <#lifetime> serde:: Deserialize <#lifetime>) ) ;
1676+ }
1677+
1678+ // Split the generics for use in different contexts
1679+ let ( impl_generics, ty_generics, where_clause) = bounded_generics. split_for_impl ( ) ;
1680+
1681+ // Create a combined generics for the Binds impl that includes both A and the behavior's generics
1682+ let mut binds_generics = bounded_generics. clone ( ) ;
1683+ binds_generics. params . insert (
1684+ 0 ,
1685+ syn:: GenericParam :: Type ( syn:: TypeParam {
1686+ attrs : vec ! [ ] ,
1687+ ident : Ident :: new ( "A" , proc_macro2:: Span :: call_site ( ) ) ,
1688+ colon_token : None ,
1689+ bounds : Punctuated :: new ( ) ,
1690+ eq_token : None ,
1691+ default : None ,
1692+ } ) ,
1693+ ) ;
1694+ let ( binds_impl_generics, _, _) = binds_generics. split_for_impl ( ) ;
1695+
1696+ // Determine typename and typehash implementation based on whether we have generics
1697+ let type_params: Vec < _ > = bounded_generics. type_params ( ) . collect ( ) ;
1698+ let has_generics = !type_params. is_empty ( ) ;
1699+
1700+ let ( typename_impl, typehash_impl) = if has_generics {
1701+ // Create format string with placeholders for each generic parameter
1702+ let placeholders = vec ! [ "{}" ; type_params. len( ) ] . join ( ", " ) ;
1703+ let placeholders_format_string = format ! ( "<{}>" , placeholders) ;
1704+ let format_string = quote ! { concat!( std:: module_path!( ) , "::" , stringify!( #behavior) , #placeholders_format_string) } ;
1705+
1706+ let type_param_idents: Vec < _ > = type_params. iter ( ) . map ( |p| & p. ident ) . collect ( ) ;
1707+ (
1708+ quote ! {
1709+ hyperactor:: data:: intern_typename!( Self , #format_string, #( #type_param_idents) , * )
1710+ } ,
1711+ quote ! {
1712+ hyperactor:: cityhasher:: hash( Self :: typename( ) )
1713+ } ,
1714+ )
1715+ } else {
1716+ (
1717+ quote ! {
1718+ concat!( std:: module_path!( ) , "::" , stringify!( #behavior) )
1719+ } ,
1720+ quote ! {
1721+ static TYPEHASH : std:: sync:: LazyLock <u64 > = std:: sync:: LazyLock :: new( || {
1722+ hyperactor:: cityhasher:: hash( <#behavior as hyperactor:: data:: Named >:: typename( ) )
1723+ } ) ;
1724+ * TYPEHASH
1725+ } ,
1726+ )
1727+ } ;
1728+
1729+ let type_param_idents = generics. type_params ( ) . map ( |p| & p. ident ) . collect :: < Vec < _ > > ( ) ;
1730+
16421731 let expanded = quote ! {
16431732 #[ doc = "The generated behavior struct." ]
1644- #[ derive( Debug , hyperactor:: Named , serde:: Serialize , serde:: Deserialize ) ]
1645- pub struct #behavior;
1646- impl hyperactor:: actor:: Referable for #behavior { }
1733+ #[ derive( Debug , serde:: Serialize , serde:: Deserialize ) ]
1734+ pub struct #behavior #impl_generics #where_clause {
1735+ _phantom: std:: marker:: PhantomData <( #( #type_param_idents) , * ) >
1736+ }
16471737
1648- impl <A > hyperactor:: actor:: Binds <A > for #behavior
1738+ impl #impl_generics hyperactor:: Named for #behavior #ty_generics #where_clause {
1739+ fn typename( ) -> & ' static str {
1740+ #typename_impl
1741+ }
1742+
1743+ fn typehash( ) -> u64 {
1744+ #typehash_impl
1745+ }
1746+ }
1747+
1748+ impl #impl_generics hyperactor:: actor:: Referable for #behavior #ty_generics #where_clause { }
1749+
1750+ impl #binds_impl_generics hyperactor:: actor:: Binds <A > for #behavior #ty_generics
16491751 where
1650- A : hyperactor:: Actor #( + hyperactor:: Handler <#tys>) * {
1752+ A : hyperactor:: Actor #( + hyperactor:: Handler <#tys>) * ,
1753+ #where_clause
1754+ {
16511755 fn bind( ports: & hyperactor:: proc:: Ports <A >) {
16521756 #(
16531757 ports. bind:: <#tys>( ) ;
@@ -1656,7 +1760,7 @@ pub fn behavior(input: TokenStream) -> TokenStream {
16561760 }
16571761
16581762 #(
1659- impl hyperactor:: actor:: RemoteHandles <#tys> for #behavior { }
1763+ impl #impl_generics hyperactor:: actor:: RemoteHandles <#tys> for #behavior #ty_generics #where_clause { }
16601764 ) *
16611765 } ;
16621766
0 commit comments