Skip to content

Commit df12f59

Browse files
committed
mppi critic addition
Signed-off-by: silanus23 <berkantali23@outlook.com>
1 parent a4bf04b commit df12f59

File tree

4 files changed

+294
-0
lines changed

4 files changed

+294
-0
lines changed

nav2_mppi_controller/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ add_library(mppi_critics SHARED
8484
src/critics/prefer_forward_critic.cpp
8585
src/critics/twirling_critic.cpp
8686
src/critics/velocity_deadband_critic.cpp
87+
src/critics/path_hug_critic.cpp
8788
)
8889
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND APPLE)
8990
# Apple Clang: use C++20 and optimization, omit -fconcepts

nav2_mppi_controller/critics.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,9 @@
4545
<description>mppi critic for restricting command velocities in deadband range</description>
4646
</class>
4747

48+
<class type="mppi::critics::PathHugCritic" base_class_type="mppi::critics::CriticFunction">
49+
<description>mppi critic for restricting command velocities in deadband range</description>
50+
</class>
51+
4852
</library>
4953
</class_libraries>
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) 2025, Berkan Tali
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may not use a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef NAV2_MPPI_CONTROLLER__CRITICS__PATH_HUG_CRITIC_HPP_
16+
#define NAV2_MPPI_CONTROLLER__CRITICS__PATH_HUG_CRITIC_HPP_
17+
18+
#include <Eigen/Dense>
19+
#include "nav2_mppi_controller/critic_function.hpp"
20+
#include "nav2_mppi_controller/models/state.hpp"
21+
#include "nav2_mppi_controller/models/path.hpp"
22+
#include "nav2_mppi_controller/models/trajectories.hpp"
23+
24+
namespace mppi::critics
25+
{
26+
27+
/**
28+
* @class mppi::critics::PathHugCritic
29+
* @brief Critic plugin for penalizing trajectories that deviate from the global path.
30+
*
31+
* This critic calculates a cost for each trajectory based on its distance from the
32+
* reference path. The cost is an accumulation of the minimum distances from sampled points
33+
* along the trajectory to the reference path. This encourages the controller to generate
34+
* trajectories that "hug" or closely follow the provided global plan.
35+
*/
36+
class PathHugCritic : public CriticFunction
37+
{
38+
public:
39+
/**
40+
* @brief Initialize the critic, loading parameters.
41+
*/
42+
void initialize() override;
43+
44+
/**
45+
* @brief Evaluate the critic score for all trajectories.
46+
* @param data The critic data object containing information about the trajectories, path, and costs.
47+
*/
48+
void score(CriticData & data) override;
49+
50+
private:
51+
void updatePathCache(const models::Path & path);
52+
53+
void computeDistancesToPathVectorized(
54+
const Eigen::ArrayXf & points_x,
55+
const Eigen::ArrayXf & points_y,
56+
Eigen::ArrayXf & distances);
57+
58+
// Parameters
59+
unsigned int power_;
60+
float weight_;
61+
double search_window_;
62+
int sample_stride_;
63+
64+
// Path cache for efficient computation
65+
size_t path_size_cache_{0};
66+
Eigen::ArrayXf path_x_cache_;
67+
Eigen::ArrayXf path_y_cache_;
68+
Eigen::ArrayXf segment_lengths_;
69+
Eigen::ArrayXf cumulative_distances_;
70+
Eigen::ArrayXf segment_dx_;
71+
Eigen::ArrayXf segment_dy_;
72+
Eigen::ArrayXf segment_len_sq_;
73+
74+
// Per-evaluation cache
75+
Eigen::ArrayXf cost_array_;
76+
Eigen::Array<Eigen::Index, Eigen::Dynamic, 1> closest_indices_;
77+
};
78+
79+
} // namespace mppi::critics
80+
81+
#endif // NAV2_MPPI_CONTROLLER__CRITICS__PATH_HUG_CRITIC_HPP_
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
// Copyright (c) 2025, Berkan Tali
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "nav2_mppi_controller/critics/path_hug_critic.hpp"
16+
17+
#include <Eigen/Dense>
18+
#include <math.h>
19+
#include <limits>
20+
21+
#include "nav2_util/geometry_utils.hpp"
22+
23+
namespace mppi::critics
24+
{
25+
26+
void PathHugCritic::initialize()
27+
{
28+
auto getParentParam = parameters_handler_->getParamGetter(parent_name_);
29+
auto getParam = parameters_handler_->getParamGetter(name_);
30+
getParam(power_, "cost_power", 1);
31+
getParam(weight_, "cost_weight", 10.0f);
32+
getParam(search_window_, "search_window", 2.0);
33+
getParam(sample_stride_, "sample_stride", 3);
34+
35+
RCLCPP_INFO(
36+
logger_,
37+
"PathHugCritic instantiated with power=%u, weight=%.2f, search_window=%.2f, stride=%d",
38+
power_, weight_, search_window_, sample_stride_);
39+
}
40+
41+
void PathHugCritic::score(CriticData & data)
42+
{
43+
if (!enabled_) {
44+
return;
45+
}
46+
if (data.path.x.size() < 2) {
47+
return;
48+
}
49+
50+
if (static_cast<Eigen::Index>(data.path.x.size()) !=
51+
static_cast<Eigen::Index>(path_size_cache_) ||
52+
(data.path.x.size() > 0 && path_x_cache_.size() == data.path.x.size() &&
53+
(data.path.x.array() != path_x_cache_.array()).any()))
54+
{
55+
updatePathCache(data.path);
56+
}
57+
58+
const auto & traj_x = data.trajectories.x;
59+
const auto & traj_y = data.trajectories.y;
60+
const Eigen::Index batch_size = traj_x.rows();
61+
const Eigen::Index traj_length = traj_x.cols();
62+
63+
// Pre-allocate arrays
64+
if (cost_array_.size() != batch_size) {
65+
cost_array_.resize(batch_size);
66+
closest_indices_.resize(batch_size);
67+
closest_indices_.setZero();
68+
}
69+
cost_array_.setZero();
70+
71+
// Sample trajectory points with stride
72+
const int effective_stride = std::max(1, std::min(sample_stride_, static_cast<int>(traj_length)));
73+
const int num_samples = (traj_length + effective_stride - 1) / effective_stride;
74+
75+
if (num_samples == 0) {return;}
76+
77+
for (int sample_idx = 0; sample_idx < num_samples; ++sample_idx) {
78+
const Eigen::Index traj_col = sample_idx * effective_stride;
79+
if (traj_col >= traj_length) {break;}
80+
81+
// Get all trajectory points at that step
82+
const auto & points_x = traj_x.col(traj_col);
83+
const auto & points_y = traj_y.col(traj_col);
84+
85+
computeDistancesToPathVectorized(points_x, points_y, cost_array_);
86+
}
87+
88+
// Normalize by the number of samples
89+
cost_array_ /= static_cast<float>(num_samples);
90+
91+
if (power_ > 1u) {
92+
data.costs += (cost_array_ * weight_).pow(power_);
93+
} else {
94+
data.costs += cost_array_ * weight_;
95+
}
96+
}
97+
98+
void PathHugCritic::updatePathCache(const models::Path & path)
99+
{
100+
path_size_cache_ = path.x.size();
101+
path_x_cache_ = path.x;
102+
path_y_cache_ = path.y;
103+
104+
const Eigen::Index path_size = path.x.size();
105+
if (path_size < 2) {return;}
106+
107+
if (segment_lengths_.size() != path_size - 1) {
108+
segment_lengths_.resize(path_size - 1);
109+
cumulative_distances_.resize(path_size);
110+
segment_dx_.resize(path_size - 1);
111+
segment_dy_.resize(path_size - 1);
112+
segment_len_sq_.resize(path_size - 1);
113+
}
114+
115+
cumulative_distances_(0) = 0.0;
116+
for (Eigen::Index i = 0; i < path_size - 1; ++i) {
117+
segment_dx_(i) = path.x(i + 1) - path.x(i);
118+
segment_dy_(i) = path.y(i + 1) - path.y(i);
119+
segment_len_sq_(i) = segment_dx_(i) * segment_dx_(i) + segment_dy_(i) * segment_dy_(i);
120+
segment_lengths_(i) = std::sqrt(segment_len_sq_(i));
121+
cumulative_distances_(i + 1) = cumulative_distances_(i) + segment_lengths_(i);
122+
}
123+
}
124+
125+
void PathHugCritic::computeDistancesToPathVectorized(
126+
const Eigen::ArrayXf & points_x,
127+
const Eigen::ArrayXf & points_y,
128+
Eigen::ArrayXf & distances)
129+
{
130+
const Eigen::Index batch_size = points_x.size();
131+
const Eigen::Index path_size = path_x_cache_.size();
132+
133+
for (Eigen::Index traj_idx = 0; traj_idx < batch_size; ++traj_idx) {
134+
const float px = points_x(traj_idx);
135+
const float py = points_y(traj_idx);
136+
137+
float min_dist_sq = std::numeric_limits<float>::max();
138+
139+
Eigen::Index start_idx = (traj_idx < closest_indices_.size()) ?
140+
std::max(0L, closest_indices_(traj_idx) - 2) : 0;
141+
142+
double distance_traversed = 0.0;
143+
Eigen::Index closest_seg_idx = start_idx;
144+
145+
for (Eigen::Index seg_idx = start_idx; seg_idx < path_size - 1; ++seg_idx) {
146+
if (distance_traversed > search_window_) {
147+
break;
148+
}
149+
150+
if (segment_len_sq_(seg_idx) < 1e-6) {
151+
continue;
152+
}
153+
154+
// Standard point-to-line-segment distance calculation
155+
const float dx_to_start = px - path_x_cache_(seg_idx);
156+
const float dy_to_start = py - path_y_cache_(seg_idx);
157+
const float dot = dx_to_start * segment_dx_(seg_idx) + dy_to_start * segment_dy_(seg_idx);
158+
const float t = std::clamp(dot / segment_len_sq_(seg_idx), 0.0f, 1.0f);
159+
const float proj_x = path_x_cache_(seg_idx) + t * segment_dx_(seg_idx);
160+
const float proj_y = path_y_cache_(seg_idx) + t * segment_dy_(seg_idx);
161+
const float dist_sq = (px - proj_x) * (px - proj_x) + (py - proj_y) * (py - proj_y);
162+
163+
if (dist_sq < min_dist_sq) {
164+
min_dist_sq = dist_sq;
165+
closest_seg_idx = seg_idx;
166+
}
167+
168+
distance_traversed += segment_lengths_(seg_idx);
169+
}
170+
171+
float final_min_dist = std::sqrt(min_dist_sq);
172+
173+
// Fallback if the search window found nothing
174+
if (min_dist_sq == std::numeric_limits<float>::max()) {
175+
// If search window fails, do a full search using the same accurate logic
176+
for (Eigen::Index seg_idx = 0; seg_idx < path_size - 1; ++seg_idx) {
177+
if (segment_len_sq_(seg_idx) < 1e-6) {continue;}
178+
179+
const float dx_to_start = px - path_x_cache_(seg_idx);
180+
const float dy_to_start = py - path_y_cache_(seg_idx);
181+
const float dot = dx_to_start * segment_dx_(seg_idx) + dy_to_start * segment_dy_(seg_idx);
182+
const float t = std::clamp(dot / segment_len_sq_(seg_idx), 0.0f, 1.0f);
183+
const float proj_x = path_x_cache_(seg_idx) + t * segment_dx_(seg_idx);
184+
const float proj_y = path_y_cache_(seg_idx) + t * segment_dy_(seg_idx);
185+
const float dist_sq = (px - proj_x) * (px - proj_x) + (py - proj_y) * (py - proj_y);
186+
187+
if (dist_sq < min_dist_sq) {
188+
min_dist_sq = dist_sq;
189+
closest_seg_idx = seg_idx;
190+
}
191+
}
192+
// Recalculate the final distance from the true minimum squared distance
193+
final_min_dist = std::sqrt(min_dist_sq);
194+
}
195+
196+
distances(traj_idx) += final_min_dist;
197+
if (traj_idx < closest_indices_.size()) {
198+
closest_indices_(traj_idx) = closest_seg_idx;
199+
}
200+
}
201+
}
202+
} // namespace mppi::critics
203+
204+
#include <pluginlib/class_list_macros.hpp>
205+
206+
PLUGINLIB_EXPORT_CLASS(
207+
mppi::critics::PathHugCritic,
208+
mppi::critics::CriticFunction)

0 commit comments

Comments
 (0)