@@ -59,7 +59,6 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
5959from typing_extensions import Any , Self
6060
6161
62- from mesa_frames .concrete .polars .agentset import AgentSetPolars
6362from mesa_frames .concrete .agents import AgentsDF
6463from mesa_frames .abstract .agents import AgentContainer , AgentSetDF
6564from mesa_frames .abstract .mixin import CopyMixin , DataFrameMixin
@@ -77,12 +76,13 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
7776 Series ,
7877 SpaceCoordinate ,
7978 SpaceCoordinates ,
79+ AgentLike
8080)
8181
8282ESPG = int
8383
8484
85- AgentLike = Union [ AgentSetPolars , pl . DataFrame ]
85+
8686
8787if TYPE_CHECKING :
8888 from mesa_frames .concrete .model import ModelDF
@@ -1050,36 +1050,98 @@ def move_to(
10501050 include_center : bool = True ,
10511051 shuffle : bool = True
10521052 ) -> None :
1053+ """
1054+ Move agents to new positions based on neighborhood ranking.
1055+
1056+ This method determines each agent's potential moves by computing their
1057+ local neighborhood (with optional inclusion of the center cell). For each
1058+ agent, the method ranks all possible moves according to specified attribute(s)
1059+ and rank order(s). If multiple agents contend for the same cell, the method
1060+ applies a tie-breaking approach. Agents can optionally be processed in a
1061+ randomized order to break ties. The final position of each agent is then
1062+ updated in-place.
1063+
1064+ Parameters
1065+ ----------
1066+ agents : AgentLike
1067+ A DataFrame-like structure containing agent information. Must include
1068+ at least the following columns:
1069+ - ``agent_id``: a unique identifier for each agent
1070+ - ``dim_0``, ``dim_1``: the current positions of agents
1071+ - Optionally ``vision`` if ``radius`` is not provided
1072+ attr_names : str or list of str
1073+ The name(s) of the attribute(s) used for ranking the neighborhood cells.
1074+ If multiple attributes are provided, each should have a corresponding
1075+ entry in ``rank_order``.
1076+ rank_order : str or list of str, optional
1077+ The ranking order for each attribute. Accepts:
1078+ - ``"max"`` (default) for descending order
1079+ - ``"min"`` for ascending order
1080+
1081+ If a single string is provided, it is applied to all attributes in
1082+ ``attr_names``.
1083+ radius : int or pl.Series, optional
1084+ The radius (or per-agent radii) defining the neighborhood around agents.
1085+ If not provided, this method attempts to use the ``vision`` column from
1086+ ``agents``. If ``vision`` is not found, a ``ValueError`` is raised.
1087+ include_center : bool, optional
1088+ If ``True`` (default), the agent's current position is included in its
1089+ neighborhood.
1090+ shuffle : bool, optional
1091+ If ``True`` (default), the order of agents is randomized to break ties.
1092+ If ``False``, agents are processed in the order they appear in the data.
1093+
1094+ Returns
1095+ -------
1096+ None
1097+ This method updates agent positions in-place based on the computed best moves.
1098+ """
1099+ # Ensure attr_names and rank_order are lists of the same length
10531100 if isinstance (attr_names , str ):
10541101 attr_names = [attr_names ]
10551102 if isinstance (rank_order , str ):
10561103 rank_order = [rank_order ] * len (attr_names )
10571104 if len (attr_names ) != len (rank_order ):
10581105 raise ValueError ("attr_names and rank_order must have the same length" )
1106+
1107+ # Handle the neighborhood radius
10591108 if radius is None :
10601109 if "vision" in agents .columns :
10611110 radius = agents ["vision" ]
10621111 else :
1063- raise ValueError ("radius must be specified if agents do not have a 'vision' attribute" )
1112+ raise ValueError (
1113+ "radius must be specified if agents do not have a 'vision' attribute"
1114+ )
1115+
1116+ # Get neighborhood and join with cell information
10641117 neighborhood = self .get_neighborhood (
1065- radius = radius ,
1066- agents = agents ,
1118+ radius = radius ,
1119+ agents = agents ,
10671120 include_center = include_center
10681121 )
10691122 neighborhood = neighborhood .join (self .cells , on = ["dim_0" , "dim_1" ])
1123+
1124+ # Determine the agent identifier column
1125+ agent_id_col = "agent_id" if "agent_id" in agents .columns else "unique_id"
1126+
1127+ # Add a column to identify the center agent
1128+ join_result = neighborhood .join (
1129+ agents .select (["dim_0" , "dim_1" , agent_id_col ]),
1130+ left_on = ["dim_0_center" , "dim_1_center" ],
1131+ right_on = ["dim_0" , "dim_1" ]
1132+ )
1133+
10701134 neighborhood = neighborhood .with_columns (
1071- agent_id_center = neighborhood .join (
1072- agents .pos ,
1073- left_on = ["dim_0_center" , "dim_1_center" ],
1074- right_on = ["dim_0" , "dim_1" ],
1075- )["unique_id" ]
1135+ agent_id_center = join_result [agent_id_col ]
10761136 )
1137+
1138+ # Determine the processing order of agents
10771139 if shuffle :
10781140 agent_order = (
10791141 neighborhood
10801142 .unique (subset = ["agent_id_center" ], keep = "first" )
10811143 .select ("agent_id_center" )
1082- .sample (fraction = 1.0 , seed = self .model .random .integers (0 , 2 ** 31 - 1 ))
1144+ .sample (fraction = 1.0 , seed = self .model .random .integers (0 , 2 ** 31 - 1 ))
10831145 .with_row_index ("agent_order" )
10841146 )
10851147 else :
@@ -1089,16 +1151,24 @@ def move_to(
10891151 .with_row_index ("agent_order" )
10901152 .select (["agent_id_center" , "agent_order" ])
10911153 )
1154+
1155+ # Join the processing order with the neighborhood
10921156 neighborhood = neighborhood .join (agent_order , on = "agent_id_center" )
1157+
1158+ # Prepare sorting columns and order
10931159 sort_cols = []
10941160 sort_desc = []
10951161 for attr , order in zip (attr_names , rank_order ):
10961162 sort_cols .append (attr )
10971163 sort_desc .append (order .lower () == "max" )
1164+
1165+ # Sort the neighborhood cells by specified attributes and then by location
10981166 neighborhood = neighborhood .sort (
10991167 sort_cols + ["radius" , "dim_0" , "dim_1" ],
11001168 descending = sort_desc + [False , False , False ]
11011169 )
1170+
1171+ # Join to track if another agent has blocked a cell
11021172 neighborhood = neighborhood .join (
11031173 agent_order .select (
11041174 pl .col ("agent_id_center" ).alias ("agent_id" ),
@@ -1107,39 +1177,71 @@ def move_to(
11071177 on = "agent_id" ,
11081178 how = "left" ,
11091179 ).rename ({"agent_id" : "blocking_agent_id" })
1180+
1181+ # Iteratively select the best moves
11101182 best_moves = pl .DataFrame ()
1111- while len (best_moves ) < len (agents ):
1183+ max_iterations = min (len (agents ) * 2 , 1000 ) # Safeguard against infinite loops
1184+ iteration_count = 0
1185+
1186+ while len (best_moves ) < len (agents ) and iteration_count < max_iterations :
1187+ iteration_count += 1
1188+
1189+ # Count how many times each (dim_0, dim_1) is being claimed
11121190 neighborhood = neighborhood .with_columns (
11131191 priority = pl .col ("agent_order" ).cum_count ().over (["dim_0" , "dim_1" ])
11141192 )
1193+
11151194 new_best_moves = (
11161195 neighborhood .group_by ("agent_id_center" , maintain_order = True )
11171196 .first ()
11181197 .unique (subset = ["dim_0" , "dim_1" ], keep = "first" , maintain_order = True )
11191198 )
1120- condition = pl .col ("blocking_agent_id" ).is_null () | (
1121- pl .col ("blocking_agent_id" ) == pl .col ("agent_id_center" )
1199+
1200+ condition = (
1201+ pl .col ("blocking_agent_id" ).is_null ()
1202+ | (pl .col ("blocking_agent_id" ) == pl .col ("agent_id_center" ))
11221203 )
1204+
11231205 if len (best_moves ) > 0 :
11241206 condition = condition | pl .col ("blocking_agent_id" ).is_in (
11251207 best_moves ["agent_id_center" ]
11261208 )
1209+
11271210 condition = condition & (pl .col ("priority" ) == 1 )
11281211 new_best_moves = new_best_moves .filter (condition )
1212+
11291213 if len (new_best_moves ) == 0 :
11301214 break
1215+
11311216 best_moves = pl .concat ([best_moves , new_best_moves ])
1217+
1218+ # Update neighborhood to exclude agents that already have a move
1219+ # and cells that are already claimed
11321220 neighborhood = neighborhood .filter (
11331221 ~ pl .col ("agent_id_center" ).is_in (best_moves ["agent_id_center" ])
11341222 )
11351223 neighborhood = neighborhood .join (
11361224 best_moves .select (["dim_0" , "dim_1" ]), on = ["dim_0" , "dim_1" ], how = "anti"
11371225 )
1226+
1227+ # Move agents to their new positions
11381228 if len (best_moves ) > 0 :
1139- self .move_agents (
1140- best_moves .sort ("agent_order" )["agent_id_center" ],
1141- best_moves .sort ("agent_order" ).select (["dim_0" , "dim_1" ])
1142- )
1229+ try :
1230+ self .move_agents (
1231+ best_moves .sort ("agent_order" )["agent_id_center" ],
1232+ best_moves .sort ("agent_order" ).select (["dim_0" , "dim_1" ])
1233+ )
1234+ except Exception as e :
1235+ # Check if the agent exists in the model
1236+ available_agents = set (self .model .agents [agent_id_col ].to_list ()) if hasattr (self .model , 'agents' ) else set ()
1237+ missing_agents = [a for a in best_moves ["agent_id_center" ].to_list () if a not in available_agents ]
1238+
1239+ if missing_agents and available_agents :
1240+ raise ValueError (f"Some agents are not present in the model: { missing_agents } " )
1241+ else :
1242+ raise ValueError (f"Error moving agents: { e } " )
1243+
1244+
11431245
11441246
11451247 @property
0 commit comments