|
| 1 | +// Copyright (c) 2025, Berkan Tali |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "nav2_mppi_controller/critics/path_hug_critic.hpp" |
| 16 | + |
| 17 | +#include <Eigen/Dense> |
| 18 | +#include <math.h> |
| 19 | +#include <limits> |
| 20 | + |
| 21 | +#include "nav2_util/geometry_utils.hpp" |
| 22 | + |
| 23 | +namespace mppi::critics |
| 24 | +{ |
| 25 | + |
| 26 | +void PathHugCritic::initialize() |
| 27 | +{ |
| 28 | + auto getParentParam = parameters_handler_->getParamGetter(parent_name_); |
| 29 | + auto getParam = parameters_handler_->getParamGetter(name_); |
| 30 | + getParam(power_, "cost_power", 1); |
| 31 | + getParam(weight_, "cost_weight", 10.0f); |
| 32 | + getParam(search_window_, "search_window", 2.0); |
| 33 | + getParam(sample_stride_, "sample_stride", 3); |
| 34 | + |
| 35 | + RCLCPP_INFO( |
| 36 | + logger_, |
| 37 | + "PathHugCritic instantiated with power=%u, weight=%.2f, search_window=%.2f, stride=%d", |
| 38 | + power_, weight_, search_window_, sample_stride_); |
| 39 | +} |
| 40 | + |
| 41 | +void PathHugCritic::score(CriticData & data) |
| 42 | +{ |
| 43 | + if (!enabled_) { |
| 44 | + return; |
| 45 | + } |
| 46 | + if (data.path.x.size() < 2) { |
| 47 | + return; |
| 48 | + } |
| 49 | + |
| 50 | + if (static_cast<Eigen::Index>(data.path.x.size()) != |
| 51 | + static_cast<Eigen::Index>(path_size_cache_) || |
| 52 | + (data.path.x.size() > 0 && path_x_cache_.size() == data.path.x.size() && |
| 53 | + (data.path.x.array() != path_x_cache_.array()).any())) |
| 54 | + { |
| 55 | + updatePathCache(data.path); |
| 56 | + } |
| 57 | + |
| 58 | + const auto & traj_x = data.trajectories.x; |
| 59 | + const auto & traj_y = data.trajectories.y; |
| 60 | + const Eigen::Index batch_size = traj_x.rows(); |
| 61 | + const Eigen::Index traj_length = traj_x.cols(); |
| 62 | + |
| 63 | + // Pre-allocate arrays |
| 64 | + if (cost_array_.size() != batch_size) { |
| 65 | + cost_array_.resize(batch_size); |
| 66 | + closest_indices_.resize(batch_size); |
| 67 | + closest_indices_.setZero(); |
| 68 | + } |
| 69 | + cost_array_.setZero(); |
| 70 | + |
| 71 | + // Sample trajectory points with stride |
| 72 | + const int effective_stride = std::max(1, std::min(sample_stride_, static_cast<int>(traj_length))); |
| 73 | + const int num_samples = (traj_length + effective_stride - 1) / effective_stride; |
| 74 | + |
| 75 | + if (num_samples == 0) {return;} |
| 76 | + |
| 77 | + for (int sample_idx = 0; sample_idx < num_samples; ++sample_idx) { |
| 78 | + const Eigen::Index traj_col = sample_idx * effective_stride; |
| 79 | + if (traj_col >= traj_length) {break;} |
| 80 | + |
| 81 | + // Get all trajectory points at that step |
| 82 | + const auto & points_x = traj_x.col(traj_col); |
| 83 | + const auto & points_y = traj_y.col(traj_col); |
| 84 | + |
| 85 | + computeDistancesToPathVectorized(points_x, points_y, cost_array_); |
| 86 | + } |
| 87 | + |
| 88 | + // Normalize by the number of samples |
| 89 | + cost_array_ /= static_cast<float>(num_samples); |
| 90 | + |
| 91 | + if (power_ > 1u) { |
| 92 | + data.costs += (cost_array_ * weight_).pow(power_); |
| 93 | + } else { |
| 94 | + data.costs += cost_array_ * weight_; |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +void PathHugCritic::updatePathCache(const models::Path & path) |
| 99 | +{ |
| 100 | + path_size_cache_ = path.x.size(); |
| 101 | + path_x_cache_ = path.x; |
| 102 | + path_y_cache_ = path.y; |
| 103 | + |
| 104 | + const Eigen::Index path_size = path.x.size(); |
| 105 | + if (path_size < 2) {return;} |
| 106 | + |
| 107 | + if (segment_lengths_.size() != path_size - 1) { |
| 108 | + segment_lengths_.resize(path_size - 1); |
| 109 | + cumulative_distances_.resize(path_size); |
| 110 | + segment_dx_.resize(path_size - 1); |
| 111 | + segment_dy_.resize(path_size - 1); |
| 112 | + segment_len_sq_.resize(path_size - 1); |
| 113 | + } |
| 114 | + |
| 115 | + cumulative_distances_(0) = 0.0; |
| 116 | + for (Eigen::Index i = 0; i < path_size - 1; ++i) { |
| 117 | + segment_dx_(i) = path.x(i + 1) - path.x(i); |
| 118 | + segment_dy_(i) = path.y(i + 1) - path.y(i); |
| 119 | + segment_len_sq_(i) = segment_dx_(i) * segment_dx_(i) + segment_dy_(i) * segment_dy_(i); |
| 120 | + segment_lengths_(i) = std::sqrt(segment_len_sq_(i)); |
| 121 | + cumulative_distances_(i + 1) = cumulative_distances_(i) + segment_lengths_(i); |
| 122 | + } |
| 123 | +} |
| 124 | + |
| 125 | +void PathHugCritic::computeDistancesToPathVectorized( |
| 126 | + const Eigen::ArrayXf & points_x, |
| 127 | + const Eigen::ArrayXf & points_y, |
| 128 | + Eigen::ArrayXf & distances) |
| 129 | +{ |
| 130 | + const Eigen::Index batch_size = points_x.size(); |
| 131 | + const Eigen::Index path_size = path_x_cache_.size(); |
| 132 | + |
| 133 | + for (Eigen::Index traj_idx = 0; traj_idx < batch_size; ++traj_idx) { |
| 134 | + const float px = points_x(traj_idx); |
| 135 | + const float py = points_y(traj_idx); |
| 136 | + |
| 137 | + float min_dist_sq = std::numeric_limits<float>::max(); |
| 138 | + |
| 139 | + Eigen::Index start_idx = (traj_idx < closest_indices_.size()) ? |
| 140 | + std::max(0L, closest_indices_(traj_idx) - 2) : 0; |
| 141 | + |
| 142 | + double distance_traversed = 0.0; |
| 143 | + Eigen::Index closest_seg_idx = start_idx; |
| 144 | + |
| 145 | + for (Eigen::Index seg_idx = start_idx; seg_idx < path_size - 1; ++seg_idx) { |
| 146 | + if (distance_traversed > search_window_) { |
| 147 | + break; |
| 148 | + } |
| 149 | + |
| 150 | + if (segment_len_sq_(seg_idx) < 1e-6) { |
| 151 | + continue; |
| 152 | + } |
| 153 | + |
| 154 | + // Standard point-to-line-segment distance calculation |
| 155 | + const float dx_to_start = px - path_x_cache_(seg_idx); |
| 156 | + const float dy_to_start = py - path_y_cache_(seg_idx); |
| 157 | + const float dot = dx_to_start * segment_dx_(seg_idx) + dy_to_start * segment_dy_(seg_idx); |
| 158 | + const float t = std::clamp(dot / segment_len_sq_(seg_idx), 0.0f, 1.0f); |
| 159 | + const float proj_x = path_x_cache_(seg_idx) + t * segment_dx_(seg_idx); |
| 160 | + const float proj_y = path_y_cache_(seg_idx) + t * segment_dy_(seg_idx); |
| 161 | + const float dist_sq = (px - proj_x) * (px - proj_x) + (py - proj_y) * (py - proj_y); |
| 162 | + |
| 163 | + if (dist_sq < min_dist_sq) { |
| 164 | + min_dist_sq = dist_sq; |
| 165 | + closest_seg_idx = seg_idx; |
| 166 | + } |
| 167 | + |
| 168 | + distance_traversed += segment_lengths_(seg_idx); |
| 169 | + } |
| 170 | + |
| 171 | + float final_min_dist = std::sqrt(min_dist_sq); |
| 172 | + |
| 173 | + // Fallback if the search window found nothing |
| 174 | + if (min_dist_sq == std::numeric_limits<float>::max()) { |
| 175 | + // If search window fails, do a full search using the same accurate logic |
| 176 | + for (Eigen::Index seg_idx = 0; seg_idx < path_size - 1; ++seg_idx) { |
| 177 | + if (segment_len_sq_(seg_idx) < 1e-6) {continue;} |
| 178 | + |
| 179 | + const float dx_to_start = px - path_x_cache_(seg_idx); |
| 180 | + const float dy_to_start = py - path_y_cache_(seg_idx); |
| 181 | + const float dot = dx_to_start * segment_dx_(seg_idx) + dy_to_start * segment_dy_(seg_idx); |
| 182 | + const float t = std::clamp(dot / segment_len_sq_(seg_idx), 0.0f, 1.0f); |
| 183 | + const float proj_x = path_x_cache_(seg_idx) + t * segment_dx_(seg_idx); |
| 184 | + const float proj_y = path_y_cache_(seg_idx) + t * segment_dy_(seg_idx); |
| 185 | + const float dist_sq = (px - proj_x) * (px - proj_x) + (py - proj_y) * (py - proj_y); |
| 186 | + |
| 187 | + if (dist_sq < min_dist_sq) { |
| 188 | + min_dist_sq = dist_sq; |
| 189 | + closest_seg_idx = seg_idx; |
| 190 | + } |
| 191 | + } |
| 192 | + // Recalculate the final distance from the true minimum squared distance |
| 193 | + final_min_dist = std::sqrt(min_dist_sq); |
| 194 | + } |
| 195 | + |
| 196 | + distances(traj_idx) += final_min_dist; |
| 197 | + if (traj_idx < closest_indices_.size()) { |
| 198 | + closest_indices_(traj_idx) = closest_seg_idx; |
| 199 | + } |
| 200 | + } |
| 201 | +} |
| 202 | +} // namespace mppi::critics |
| 203 | + |
| 204 | +#include <pluginlib/class_list_macros.hpp> |
| 205 | + |
| 206 | +PLUGINLIB_EXPORT_CLASS( |
| 207 | + mppi::critics::PathHugCritic, |
| 208 | + mppi::critics::CriticFunction) |
0 commit comments