diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6fd0576..e5830b1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: run: sudo apt update - name: Install dependencies - run: sudo apt install -y libopencv-dev libyaml-cpp-dev libprotobuf-dev libprotoc-dev protobuf-compiler + run: sudo apt install -y libopencv-dev libyaml-cpp-dev libprotobuf-dev libprotoc-dev protobuf-compiler libeigen3-dev - name: Build run: | diff --git a/readme.md b/readme.md index 0c4a390..05b8a86 100644 --- a/readme.md +++ b/readme.md @@ -13,7 +13,7 @@ The assumption is then that the input is a sequence of images. The program outpu Prerequisites: ``` -sudo apt-get install -y build-essential libopencv-dev libyaml-cpp-dev libprotobuf-dev libprotoc-dev protobuf-compiler +sudo apt-get install -y libopencv-dev libyaml-cpp-dev libprotobuf-dev libprotoc-dev protobuf-compiler libeigen3-dev ``` Tested on Ubuntu 20.04. @@ -56,6 +56,17 @@ For more details about the underlying method and the interpretation of the resul Here is a sketch of what roughly is happening for those who don't like to read much ![](doc/cost_matrix_view.png) +## LIDAR place recognition + +The code also supports matching ScanContext sequences. +To do place recognition in this case, you currently need three steps: +1. Convert numpy arrays into protobuf, use `python/convert_numpy_to_scan_context.py` +2. Compute the scan context cost matrix, use `./src/apps/cost_matrix_based_matching/compute_scan_context_cost_matrix queryFeaturesDir referenceFeaturesDir outputFilename` from the `build` folder +3. Run online place recognition with the precomputed cost matrix, use `python/run_scan_context_matching.py` + + +More general support will come later. + ## Parent project This repository is a continuation of my previous works [vpr_relocalization](https://github.com/PRBonn/vpr_relocalization) and [online_place_recognition](https://github.com/PRBonn/online_place_recognition). diff --git a/src/apps/cost_matrix_based_matching/CMakeLists.txt b/src/apps/cost_matrix_based_matching/CMakeLists.txt index 0e62f93..e1c92b1 100644 --- a/src/apps/cost_matrix_based_matching/CMakeLists.txt +++ b/src/apps/cost_matrix_based_matching/CMakeLists.txt @@ -37,3 +37,20 @@ target_link_libraries(localization_by_hashing ${OpenCV_LIBS} ) +add_executable(scan_context_matching scan_context_matching.cpp) +target_link_libraries(scan_context_matching + cost_matrix + config_parser + online_database + successor_manager + online_localizer + scan_context_relocalizer + glog::glog +) + +add_executable(compute_scan_context_cost_matrix compute_scan_context_cost_matrix.cpp) +target_link_libraries(compute_scan_context_cost_matrix + cost_matrix + glog::glog +) + diff --git a/src/apps/cost_matrix_based_matching/compute_scan_context_cost_matrix.cpp b/src/apps/cost_matrix_based_matching/compute_scan_context_cost_matrix.cpp new file mode 100644 index 0000000..f104607 --- /dev/null +++ b/src/apps/cost_matrix_based_matching/compute_scan_context_cost_matrix.cpp @@ -0,0 +1,35 @@ +/* By O. Vysotska in 2023 */ + +#include "database/cost_matrix.h" +#include "database/cost_matrix_database.h" +#include "features/feature_factory.h" +#include "tools/config_parser/config_parser.h" + +#include + +#include + +int main(int argc, char *argv[]) { + // TODO(olga) Add gflags support.
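+  // Example invocation (paths are illustrative):
+  //   ./compute_scan_context_cost_matrix scan_context_query/ scan_context_ref/ sc_costs.CostMatrix.pb
+  // Both feature directories are expected to contain ScanContext features stored as
+  // *.Feature.pb protos, e.g. as produced by python/convert_numpy_to_scan_context.py.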
+ google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; + LOG(INFO) << "===== Compute Cost Matrix for Scan Context ====\n"; + + if (argc < 4) { + LOG(ERROR) << "Not enough input parameters."; + LOG(INFO) << "Proper usage: ./compute_scan_context_cost_matrix " + "queryFeaturesDir referenceFeaturesDir outputFilename"; + exit(0); + } + + std::string queryFeaturesDir = argv[1]; + std::string refFeaturesDir = argv[2]; + std::string outputFilename = argv[3]; + + const localization::database::CostMatrix costMatrix( + queryFeaturesDir, refFeaturesDir, localization::features::Scan_Context); + costMatrix.storeToProto(outputFilename); + + LOG(INFO) << "Cost matrix is saved to" << outputFilename; + return 0; +} \ No newline at end of file diff --git a/src/apps/cost_matrix_based_matching/online_localizer_lsh.cpp b/src/apps/cost_matrix_based_matching/online_localizer_lsh.cpp index 975de22..805b857 100644 --- a/src/apps/cost_matrix_based_matching/online_localizer_lsh.cpp +++ b/src/apps/cost_matrix_based_matching/online_localizer_lsh.cpp @@ -63,7 +63,7 @@ int main(int argc, char *argv[]) { if (argc < 2) { printf("[ERROR] Not enough input parameters.\n"); - printf("Proper usage: ./cost_matrix_based_matching_lsh config_file.yaml\n"); + printf("Proper usage: ./online_localizer_lsh config_file.yaml\n"); exit(0); } @@ -77,7 +77,7 @@ int main(int argc, char *argv[]) { /*refFeaturesDir=*/parser.path2ref, /*type=*/loc::features::FeatureType::Cnn_Feature, /*bufferSize=*/parser.bufferSize, - /*costMatrixFile=*/parser.costMatrix); + /*costMatrixFile=*/parser.costMatrix, /*invert=*/true); auto relocalizer = std::make_unique( /*onlineDatabase=*/database.get(), diff --git a/src/apps/cost_matrix_based_matching/scan_context_matching.cpp b/src/apps/cost_matrix_based_matching/scan_context_matching.cpp new file mode 100644 index 0000000..0e60128 --- /dev/null +++ b/src/apps/cost_matrix_based_matching/scan_context_matching.cpp @@ -0,0 +1,55 @@ +/* By O. 
Vysotska in 2023 */ + +#include "database/cost_matrix.h" +#include "database/online_database.h" +#include "features/feature_factory.h" +#include "online_localizer/online_localizer.h" +#include "relocalizers/scan_context_relocalizer.h" +#include "successor_manager/successor_manager.h" +#include "tools/config_parser/config_parser.h" + +#include + +#include +#include + +namespace loc = localization; + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; + LOG(INFO) << "===== Matching ScanContext sequences ====\n"; + + if (argc < 2) { + LOG(ERROR) << "Not enough input parameters."; + LOG(INFO) << "Proper usage: ./scan_context_matching " + "config_file.yaml"; + exit(0); + } + + std::string config_file = argv[1]; + ConfigParser parser; + parser.parseYaml(config_file); + parser.print(); + + const auto database = std::make_unique( + parser.path2qu, parser.path2ref, loc::features::FeatureType::Scan_Context, + parser.bufferSize, parser.costMatrix, false); + + const auto relocalizer = + std::make_unique( + parser.path2ref, database.get()); + + const auto successorManager = + std::make_unique( + database.get(), relocalizer.get(), parser.fanOut); + loc::online_localizer::OnlineLocalizer localizer{ + successorManager.get(), parser.expansionRate, parser.nonMatchCost}; + const loc::online_localizer::Matches imageMatches = + localizer.findMatchesTill(parser.querySize); + loc::online_localizer::storeMatchesAsProto(imageMatches, + parser.matchingResult); + + LOG(INFO) << "Done."; + return 0; +} \ No newline at end of file diff --git a/src/localization/database/cost_matrix.cpp b/src/localization/database/cost_matrix.cpp index 929d334..a8cc591 100644 --- a/src/localization/database/cost_matrix.cpp +++ b/src/localization/database/cost_matrix.cpp @@ -28,18 +28,20 @@ CostMatrix::CostMatrix(const std::string &queryFeaturesDir, const std::vector refFeaturesFiles = listProtoDir(refFeaturesDir, ".Feature"); - std::cerr << "Query features" << queryFeaturesFiles.size() << std::endl; - std::cerr << "ref features" << refFeaturesFiles.size() << std::endl; + LOG(INFO) << "Query features size " << queryFeaturesFiles.size(); + LOG(INFO) << "Reference features size " << refFeaturesFiles.size(); + LOG(INFO) << "Computing cost matrix.
This may take some time..."; costs_.reserve(queryFeaturesFiles.size()); - for (const auto &queryFile : queryFeaturesFiles) { - auto queryFeature = createFeature(type, queryFile); + for (int fileIdx = 0; fileIdx < queryFeaturesFiles.size(); ++fileIdx) { + auto queryFeature = createFeature(type, queryFeaturesFiles[fileIdx]); std::vector row; row.reserve(refFeaturesFiles.size()); for (const auto &refFile : refFeaturesFiles) { const auto refFeature = createFeature(type, refFile); row.push_back(queryFeature->computeSimilarityScore(*refFeature)); } + LOG(INFO) << "Computed row values for query image " << fileIdx; costs_.push_back(row); } rows_ = costs_.size(); @@ -76,17 +78,22 @@ void CostMatrix::loadFromTxt(const std::string &filename, int rows, int cols) { double CostMatrix::at(int row, int col) const { CHECK(row >= 0 && row < rows_) << "Row outside range " << row; CHECK(col >= 0 && col < cols_) << "Col outside range " << col; + if (inverseCosts_) { + return getInverseCost(row, col); + } return costs_[row][col]; } double CostMatrix::getInverseCost(int row, int col) const { - const double value = this->at(row, col); + const double value = costs_[row][col]; if (std::abs(value) < kEpsilon) { return std::numeric_limits::max(); } - if (value < 0){ - LOG(WARNING) << "The cost value for row:" << row << " and col:" << col <<" is < 0: " << value<< ". This should not be like this. I will make a positive value of it for now. But please check your values"; - + if (value < 0) { + LOG(WARNING) << "The cost value for row:" << row << " and col:" << col + << " is < 0: " << value + << ". This should not happen. Using the absolute value for " + "now, but please check your values"; } return 1. / std::abs(value); } @@ -112,4 +119,25 @@ void CostMatrix::loadFromProto(const std::string &filename) { LOG(INFO) << "Read cost matrix with " << rows_ << " rows and " << cols_ << " cols."; } + +void CostMatrix::storeToProto(const std::string &protoFilename) const { + image_sequence_localizer::CostMatrix costMatrixProto; + costMatrixProto.set_cols(cols_); + costMatrixProto.set_rows(rows_); + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + costMatrixProto.add_values(costs_[r][c]); + } + } + + std::fstream out(protoFilename, + std::ios::out | std::ios::trunc | std::ios::binary); + if (!costMatrixProto.SerializeToOstream(&out)) { + LOG(ERROR) << "Couldn't open the file " << protoFilename; + LOG(ERROR) << "The path is NOT saved."; + return; + } + out.close(); + LOG(INFO) << "The cost matrix was written to " << protoFilename; +} } // namespace localization::database diff --git a/src/localization/database/cost_matrix.h b/src/localization/database/cost_matrix.h index d2a2db3..b7957ea 100644 --- a/src/localization/database/cost_matrix.h +++ b/src/localization/database/cost_matrix.h @@ -9,6 +9,7 @@ #include namespace localization::database { + class CostMatrix { public: using Matrix = std::vector>; @@ -21,12 +22,18 @@ class CostMatrix { void loadFromTxt(const std::string &filename, int rows, int cols); + // TODO(olga): This should be removed once iFeature doesn't have an explicit + // similarity score requirement. + void inverseCosts(bool inverse = true) { inverseCosts_ = inverse; } + void loadFromProto(const std::string &filename); + void storeToProto(const std::string &protoFilename) const; const Matrix &getCosts() const { return costs_; } double at(int row, int col) const; // Computes 1/value.
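+  // Near-zero values map to the maximum double value; negative values are
+  // inverted by their absolute value and a warning is logged.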
double getInverseCost(int row, int col) const; + int rows() const { return rows_; } int cols() const { return cols_; } @@ -34,6 +41,7 @@ class CostMatrix { Matrix costs_; int rows_ = 0; int cols_ = 0; + bool inverseCosts_ = true; }; } // namespace localization::database diff --git a/src/localization/database/cost_matrix_database.cpp b/src/localization/database/cost_matrix_database.cpp index f31da53..850bb59 100644 --- a/src/localization/database/cost_matrix_database.cpp +++ b/src/localization/database/cost_matrix_database.cpp @@ -37,7 +37,7 @@ CostMatrixDatabase::CostMatrixDatabase(const std::string &costMatrixFile) : costMatrix_(CostMatrix(costMatrixFile)) {} double CostMatrixDatabase::getCost(int quId, int refId) { - return costMatrix_.getInverseCost(quId, refId); + return costMatrix_.at(quId, refId); } } // namespace localization::database diff --git a/src/localization/database/online_database.cpp b/src/localization/database/online_database.cpp index 094c966..d152598 100644 --- a/src/localization/database/online_database.cpp +++ b/src/localization/database/online_database.cpp @@ -35,6 +35,7 @@ #include #include +#include namespace localization::database { namespace { @@ -54,7 +55,7 @@ addFeatureIfNeeded(features::FeatureBuffer &featureBuffer, OnlineDatabase::OnlineDatabase(const std::string &queryFeaturesDir, const std::string &refFeaturesDir, features::FeatureType type, int bufferSize, - const std::string &costMatrixFile) + const std::string &costMatrixFile, bool invert) : quFeaturesNames_{listProtoDir(queryFeaturesDir, ".Feature")}, refFeaturesNames_{listProtoDir(refFeaturesDir, ".Feature")}, featureType_{type}, refBuffer_{std::make_unique( @@ -64,6 +65,7 @@ OnlineDatabase::OnlineDatabase(const std::string &queryFeaturesDir, LOG_IF(FATAL, refFeaturesNames_.empty()) << "Reference features are not set."; if (!costMatrixFile.empty()) { precomputedCosts_ = CostMatrix(costMatrixFile); + precomputedCosts_->inverseCosts(invert); } } @@ -83,7 +85,7 @@ double OnlineDatabase::computeMatchingCost(int quId, int refId) { double OnlineDatabase::getCost(int quId, int refId) { if (precomputedCosts_) { - return precomputedCosts_->getInverseCost(quId, refId); + return precomputedCosts_->at(quId, refId); } // Check if the cost was computed before. 
auto rowIter = costs_.find(quId); diff --git a/src/localization/database/online_database.h b/src/localization/database/online_database.h index bf16b62..4749609 100644 --- a/src/localization/database/online_database.h +++ b/src/localization/database/online_database.h @@ -46,7 +46,8 @@ class OnlineDatabase : public iDatabase { public: OnlineDatabase(const std::string &queryFeaturesDir, const std::string &refFeaturesDir, features::FeatureType type, - int bufferSize, const std::string &costMatrixFile = ""); + int bufferSize, const std::string &costMatrixFile = "", + bool invert = true); inline int refSize() override { return refFeaturesNames_.size(); } double getCost(int quId, int refId) override; diff --git a/src/localization/features/CMakeLists.txt b/src/localization/features/CMakeLists.txt index 680acc7..d1537b4 100644 --- a/src/localization/features/CMakeLists.txt +++ b/src/localization/features/CMakeLists.txt @@ -5,12 +5,26 @@ target_link_libraries(cnn_feature glog::glog ) +find_package (Eigen3 3.3 REQUIRED NO_MODULE) + +add_library(scan_context scan_context.cpp) +target_link_libraries(scan_context + glog::glog + cxx_flags + Eigen3::Eigen +) + add_library(feature_factory feature_factory.cpp) target_link_libraries(feature_factory PUBLIC - cnn_feature + cnn_feature + scan_context glog::glog + ) add_library(feature_buffer feature_buffer.cpp) target_link_libraries(feature_buffer glog::glog cxx_flags) + + + diff --git a/src/localization/features/cnn_feature.h b/src/localization/features/cnn_feature.h index 706e6e2..ffbc286 100644 --- a/src/localization/features/cnn_feature.h +++ b/src/localization/features/cnn_feature.h @@ -50,7 +50,7 @@ class CnnFeature : public iFeature { */ double score2cost(double score) const override; - using iFeature::bits; + // using iFeature::bits; using iFeature::dimensions; protected: diff --git a/src/localization/features/feature_factory.cpp b/src/localization/features/feature_factory.cpp index 040cb70..0bfc5b8 100644 --- a/src/localization/features/feature_factory.cpp +++ b/src/localization/features/feature_factory.cpp @@ -23,8 +23,10 @@ #include "feature_factory.h" #include "cnn_feature.h" +#include "features/scan_context.h" #include +#include namespace localization::features { @@ -34,6 +36,9 @@ std::unique_ptr createFeature(FeatureType type, case Cnn_Feature: { return std::make_unique(featureFilename); } + case Scan_Context: { + return std::make_unique(featureFilename); + } } LOG(FATAL) << "Unknown feature type"; } diff --git a/src/localization/features/feature_factory.h b/src/localization/features/feature_factory.h index 66f0c83..d2bf7fd 100644 --- a/src/localization/features/feature_factory.h +++ b/src/localization/features/feature_factory.h @@ -28,9 +28,7 @@ namespace localization::features { -enum FeatureType { - Cnn_Feature, -}; +enum FeatureType { Cnn_Feature, Scan_Context }; std::unique_ptr createFeature(FeatureType type, const std::string &featureFilename); diff --git a/src/localization/features/ifeature.h b/src/localization/features/ifeature.h index 318d825..1628214 100644 --- a/src/localization/features/ifeature.h +++ b/src/localization/features/ifeature.h @@ -35,6 +35,8 @@ namespace localization::features { /** * @brief Interface class for features. */ +// TODO(olga): Make this easier, the feature comparison should returncost +// directly. Don't care about the similarity. 
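+// Currently a feature exposes a similarity (computeSimilarityScore) plus a separate
+// conversion to a graph cost (score2cost); CostMatrix stores raw similarities and
+// inverts them on access unless inverseCosts(false) is set (see cost_matrix.h).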
class iFeature { public: /** @@ -48,6 +50,7 @@ class iFeature { virtual double computeSimilarityScore(const iFeature &rhs) const = 0; /** * @brief Transforms similarity into the weights/cost for the graph. + The bigger the similarity the smaller the cost should be. * @param[in] score The score * diff --git a/src/localization/features/scan_context.cpp b/src/localization/features/scan_context.cpp new file mode 100644 index 0000000..831a37d --- /dev/null +++ b/src/localization/features/scan_context.cpp @@ -0,0 +1,108 @@ +#include "scan_context.h" +#include "localization_protos.pb.h" + +#include +#include +#include + +#include +#include +#include +#include + +namespace localization::features { + +double cosine_distance(const Eigen::VectorXd &lhs, const Eigen::VectorXd &rhs) { + if (std::abs(lhs.norm()) < 1e-09 || std::abs(rhs.norm()) < 1e-09) { + return 0; + } + return 1 - lhs.dot(rhs) / (lhs.norm() * rhs.norm()); +} + +double computeDistanceBetweenGrids(const Eigen::MatrixXd &lhs, + const Eigen::MatrixXd &rhs) { + CHECK(lhs.cols() > 0) << "Invalid grid size."; + CHECK(lhs.cols() == rhs.cols()) << "The number of columns is not the same."; + CHECK(lhs.rows() == rhs.rows()) << "The number of rows is not the same."; + + Eigen::VectorXd columnErrors(lhs.cols()); + for (int c = 0; c < lhs.cols(); ++c) { + columnErrors[c] = cosine_distance(lhs.col(c), rhs.col(c)); + } + return columnErrors.sum() / lhs.cols(); +} + +double distanceBetweenRingKeys(const std::vector &rhs, + const std::vector &lhs) { + double minDistance = std::numeric_limits::max(); + for (int rightRingKey = 0; rightRingKey < rhs.size(); ++rightRingKey) { + for (int leftRingKey = 0; leftRingKey < lhs.size(); ++leftRingKey) { + double dist = cosine_distance(rhs[rightRingKey], lhs[leftRingKey]); + if (dist < minDistance) { + minDistance = dist; + } + } + } + return minDistance; +} + +Eigen::VectorXd computeRingKey(const Eigen::MatrixXd &grid) { + Eigen::VectorXd ringKey{grid.rows()}; + for (int r = 0; r < grid.rows(); ++r) { + Eigen::VectorXd row = grid.row(r); + ringKey[r] = (row.array() > 0).count() / static_cast(grid.cols()); + } + return ringKey; +} + +ScanContext::ScanContext(const std::string &filename) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + image_sequence_localizer::ScanContext scanContextProto; + std::fstream input(filename, std::ios::in | std::ios::binary); + if (!scanContextProto.ParseFromIstream(&input)) { + LOG(FATAL) << "Failed to parse feature_proto file: " << filename; + } + + for (int shift = 0; shift < scanContextProto.grids_size(); ++shift) { + const int rows = scanContextProto.grids(shift).rows(); + const int cols = scanContextProto.grids(shift).cols(); + + Eigen::MatrixXd grid(rows, cols); + std::vector row; + for (int idx = 0; idx < scanContextProto.grids(shift).values_size(); + ++idx) { + int r = idx / cols; + int c = idx - r * cols; + grid(r, c) = scanContextProto.grids(shift).values(idx); + } + grids.push_back(grid); + } + type = "ScanContext"; +} + +double ScanContext::computeSimilarityScore(const iFeature &rhs) const { + CHECK(this->type == rhs.type) << "Features are not the same type"; + const ScanContext &otherFeature = static_cast(rhs); + std::vector distances; + distances.reserve(this->grids.size() * otherFeature.grids.size()); + + for (const Eigen::MatrixXd &grid : this->grids) { + for (const Eigen::MatrixXd &otherGrid : otherFeature.grids) { + // TODO(olga): Rotation invariance can go here. fliping matrix by columns. 
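+      // Each feature already stores several column-shifted grids (ScanContext.grids
+      // in localization_protos.proto), so taking the minimum distance over all grid
+      // pairs below gives a coarse shift/rotation invariance.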
+ distances.push_back(computeDistanceBetweenGrids(grid, otherGrid)); + } + } + return *std::min_element(distances.begin(), distances.end()); +} + +double ScanContext::score2cost(double score) const { return score; } + +std::vector ScanContext::computeRingKeys() const { + std::vector ringKeys; + for (int shift = 0; shift < grids.size(); ++shift) { + ringKeys.push_back(computeRingKey(grids[shift])); + } + return ringKeys; +} + +} // namespace localization::features diff --git a/src/localization/features/scan_context.h b/src/localization/features/scan_context.h new file mode 100644 index 0000000..310ef46 --- /dev/null +++ b/src/localization/features/scan_context.h @@ -0,0 +1,31 @@ +#ifndef LOCALIZATION_FEATURES_SCAN_CONTEXT_H_ +#define LOCALIZATION_FEATURES_SCAN_CONTEXT_H_ + +#include "features/ifeature.h" + +#include + +#include + +namespace localization::features { + +double computeDistanceBetweenGrids(const Eigen::MatrixXd &lhs, + const Eigen::MatrixXd &rhs); +Eigen::VectorXd computeRingKey(const Eigen::MatrixXd &grid); + +double distanceBetweenRingKeys(const std::vector &rhs, + const std::vector &lhs); + +class ScanContext : public iFeature { +public: + explicit ScanContext(const std::string &filename); + double computeSimilarityScore(const iFeature &rhs) const override; + double score2cost(double score) const override; + std::vector computeRingKeys() const; + + std::vector grids; + // TODO (olga): Probably can store RingKeys +}; +} // namespace localization::features + +#endif // LOCALIZATION_FEATURES_SCAN_CONTEXT_H_ \ No newline at end of file diff --git a/src/localization/online_localizer/online_localizer.cpp b/src/localization/online_localizer/online_localizer.cpp index 36d9b4c..6141cb8 100644 --- a/src/localization/online_localizer/online_localizer.cpp +++ b/src/localization/online_localizer/online_localizer.cpp @@ -91,27 +91,9 @@ Matches OnlineLocalizer::findMatchesTill(int queryId) { return getCurrentPath(); } -void OnlineLocalizer::writeOutExpanded(const std::string &filename) const { - image_sequence_localizer::Patch patch; - for (const auto &node : expandedRecently_) { - image_sequence_localizer::Patch::Element *element = patch.add_elements(); - element->set_row(node.quId); - element->set_col(node.refId); - element->set_similarity_value(node.idvCost); - } - std::fstream out(filename, - std::ios::out | std::ios::trunc | std::ios::binary); - if (!patch.SerializeToOstream(&out)) { - LOG(ERROR) << "Couldn't open the file" << filename; - return; - } - out.close(); - LOG(INFO) << "Wrote patch " << filename; -} - // frontier picking up routine void OnlineLocalizer::matchImage(int quId) { - expandedRecently_.clear(); + // expandedRecently_.clear(); std::unordered_set children; if (needReloc_) { @@ -152,7 +134,7 @@ void OnlineLocalizer::matchImage(int quId) { } } for (const Node &n : children) { - expandedRecently_.insert(n); + expandedNodes_.insert(n); } } @@ -379,7 +361,7 @@ void OnlineLocalizer::visualize() const { return; } // _vis->drawFrontier(frontier_); - _vis->drawExpansion(expandedRecently_); + _vis->drawExpansion(expandedNodes_); std::vector path = getCurrentPath(); std::reverse(path.begin(), path.end()); _vis->drawPath(path); diff --git a/src/localization/online_localizer/online_localizer.h b/src/localization/online_localizer/online_localizer.h index 9a95d42..70d2913 100644 --- a/src/localization/online_localizer/online_localizer.h +++ b/src/localization/online_localizer/online_localizer.h @@ -50,6 +50,7 @@ class OnlineLocalizer { ~OnlineLocalizer() {} Matches 
findMatchesTill(int queryId); + const NodeSet &showExpandedNodes() const { return expandedNodes_; } void writeOutExpanded(const std::string &filename) const; protected: @@ -85,7 +86,7 @@ class OnlineLocalizer { successor_manager::SuccessorManager *successorManager_ = nullptr; iLocVisualizer::Ptr _vis = nullptr; - NodeSet expandedRecently_; + NodeSet expandedNodes_; }; } // namespace localization::online_localizer diff --git a/src/localization/relocalizers/CMakeLists.txt b/src/localization/relocalizers/CMakeLists.txt index 8930d31..f9b5e4a 100644 --- a/src/localization/relocalizers/CMakeLists.txt +++ b/src/localization/relocalizers/CMakeLists.txt @@ -12,3 +12,15 @@ target_link_libraries(default_relocalizer cxx_flags glog::glog ) + +find_package (Eigen3 3.3 REQUIRED NO_MODULE) + +add_library(scan_context_relocalizer scan_context_relocalizer.cpp) +target_link_libraries(scan_context_relocalizer + online_database + scan_context + Eigen3::Eigen + ${OpenCV_LIBS} + cxx_flags + glog::glog +) diff --git a/src/localization/relocalizers/lsh_cv_hashing.cpp b/src/localization/relocalizers/lsh_cv_hashing.cpp index 4b5da87..fa09c0c 100644 --- a/src/localization/relocalizers/lsh_cv_hashing.cpp +++ b/src/localization/relocalizers/lsh_cv_hashing.cpp @@ -94,13 +94,13 @@ std::vector LshCvHashing::getCandidates(int quId) { Timer timer; timer.start(); - std::vector candidates; - candidates = hashFeature(feature); + std::vector candidates = hashFeature(feature); timer.stop(); LOG(INFO) << "Hash retrieval time"; timer.print_elapsed_time(TimeExt::MSec); LOG(INFO) << "Candidates size: " << candidates.size(); + return candidates; } } // namespace localization::relocalizers \ No newline at end of file diff --git a/src/localization/relocalizers/scan_context_relocalizer.cpp b/src/localization/relocalizers/scan_context_relocalizer.cpp new file mode 100644 index 0000000..cf2efc5 --- /dev/null +++ b/src/localization/relocalizers/scan_context_relocalizer.cpp @@ -0,0 +1,108 @@ +#include "relocalizers/scan_context_relocalizer.h" +#include "database/list_dir.h" +#include "features/feature_factory.h" +#include "features/scan_context.h" + +#include +#include + +#include +#include +#include + +namespace localization::relocalizers { + +constexpr int kMaxCandidateNum = 5; + +ScanContextRelocalizer::ScanContextRelocalizer( + const std::string &features_dir, database::OnlineDatabase *database) { + CHECK(!features_dir.empty()) << "Feature directory is not set"; + CHECK(database) << "Online database is not set"; + + database_ = database; + + std::vector featureFiles = + database::listProtoDir(features_dir, ".Feature"); + + LOG(INFO) << "I have found " << featureFiles.size() << " feature files"; + + ringKeys_.reserve(featureFiles.size()); + for (const auto &file : featureFiles) { + ringKeys_.push_back(features::ScanContext(file).computeRingKeys()); + } + + // indexParam_ = new cv::flann::LshIndexParams(25, 20, 2); + indexParam_ = new cv::flann::KDTreeIndexParams(25); + matcherPtr_ = + cv::Ptr(new cv::FlannBasedMatcher(indexParam_)); + + LOG(INFO) << "Starting adding features for hashing"; + cv::Mat features; + // Turn ringKeys to CV::Mat objects + for (const auto &featureRingKeys : ringKeys_) { + for (const auto &shift : featureRingKeys) { + cv::Mat shiftRingKeyMat; + eigen2cv(shift, shiftRingKeyMat); + shiftRingKeyMat.convertTo(shiftRingKeyMat, CV_32F); + features.push_back(shiftRingKeyMat.t()); + } + } + matcherPtr_->add(features); + + LOG(INFO) << "Started training..."; + matcherPtr_->train(); + LOG(INFO) << "Finished traning."; +} + 
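+// Exhaustive alternative to the FLANN-based lookup: ranks every reference feature
+// by its minimal ring-key cosine distance to the query and returns the closest
+// kNumOfCandidates entries.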
+std::vector ScanContextRelocalizer::getBruteForceCandidates( + const features::ScanContext &queryScanContext) const { + std::vector> distancesPerFeature; + distancesPerFeature.reserve(ringKeys_.size()); + for (int refId = 0; refId < ringKeys_.size(); ++refId) { + const auto dist = features::distanceBetweenRingKeys( + queryScanContext.computeRingKeys(), ringKeys_[refId]); + distancesPerFeature.push_back({refId, dist}); + } + std::sort(distancesPerFeature.begin(), distancesPerFeature.end(), + [](const auto &left, const auto &right) { + return left.second < right.second; + }); + + const int kNumOfCandidates = 5; + std::vector candidates; + for (int k = 0; + k < std::min(kNumOfCandidates, (int)distancesPerFeature.size()); ++k) { + candidates.push_back(distancesPerFeature[k].first); + } + return candidates; +} + +std::vector ScanContextRelocalizer::getCandidates(int quId) { + const features::iFeature &queryFeature = database_->getQueryFeature(quId); + CHECK(queryFeature.type == "ScanContext") << "Not a ScanContext feature"; + const features::ScanContext &queryScanContext = + static_cast(queryFeature); + const auto queryRingKeys = queryScanContext.computeRingKeys(); + + // TODO (olga): This can be done without the for loop. + std::set matchedIds; + for (const auto &shift : queryRingKeys) { + std::vector> matches; + cv::Mat shiftRingKeyMat; + eigen2cv(shift, shiftRingKeyMat); + shiftRingKeyMat.convertTo(shiftRingKeyMat, CV_32F); + matcherPtr_->knnMatch(shiftRingKeyMat.t(), matches, kMaxCandidateNum); + + for (int k = 0; k < matches.size(); ++k) { + for (const auto &match : matches[k]) { + matchedIds.insert( + match.trainIdx / + queryRingKeys + .size()); // This only considers that the there are the same + // number of shifts in the trained dataset + } + } + } + return {matchedIds.begin(), matchedIds.end()}; +} +} // namespace localization::relocalizers \ No newline at end of file diff --git a/src/localization/relocalizers/scan_context_relocalizer.h b/src/localization/relocalizers/scan_context_relocalizer.h new file mode 100644 index 0000000..fb7919d --- /dev/null +++ b/src/localization/relocalizers/scan_context_relocalizer.h @@ -0,0 +1,31 @@ +#ifndef LOCALIZATION_RELOCALIZERS_SCAN_CONTEXT_RELOCALIZER_H_ +#define LOCALIZATION_RELOCALIZERS_SCAN_CONTEXT_RELOCALIZER_H_ + +#include "database/online_database.h" +#include "features/scan_context.h" +#include "relocalizers/irelocalizer.h" + +#include "opencv2/features2d/features2d.hpp" +#include + +#include +#include + +namespace localization::relocalizers { + +class ScanContextRelocalizer : public iRelocalizer { +public: + ScanContextRelocalizer(const std::string &features_dir, + database::OnlineDatabase *database); + std::vector getCandidates(int quId) override; + +private: + std::vector + getBruteForceCandidates(const features::ScanContext &queryScanContext) const; + std::vector> ringKeys_; + database::OnlineDatabase *database_ = nullptr; + cv::Ptr matcherPtr_; + cv::Ptr indexParam_; +}; +} // namespace localization::relocalizers +#endif // LOCALIZATION_RELOCALIZERS_SCAN_CONTEXT_RELOCALIZER_H_ \ No newline at end of file diff --git a/src/localization/successor_manager/successor_manager.cpp b/src/localization/successor_manager/successor_manager.cpp index 1ab198e..2254ee3 100644 --- a/src/localization/successor_manager/successor_manager.cpp +++ b/src/localization/successor_manager/successor_manager.cpp @@ -30,7 +30,6 @@ #include #include -#include using std::vector; namespace localization::successor_manager { @@ -166,7 +165,7 @@ 
SuccessorManager::getSuccessorsIfLost(const Node &node) { double succ_cost = database_->getCost(succ_qu_id, candId); succ.set(succ_qu_id, candId, succ_cost); _successors.insert(succ); - succ.print(); + // succ.print(); } } return _successors; diff --git a/src/localization/tools/config_parser/config_parser.cpp b/src/localization/tools/config_parser/config_parser.cpp index fba2395..98df588 100644 --- a/src/localization/tools/config_parser/config_parser.cpp +++ b/src/localization/tools/config_parser/config_parser.cpp @@ -31,175 +31,178 @@ using std::string; bool ConfigParser::parse(const std::string &iniFile) { - std::ifstream in(iniFile.c_str()); - if (!in) { - printf("[ERROR][ConfigParser] The file \"%s\" cannot be opened.\n", - iniFile.c_str()); - return false; - } - while (!in.eof()) { - string line; - std::getline(in, line); - if (line.empty() || line[0] == '#') { - // it should be a comment - continue; - } - std::stringstream ss(line); - while (!ss.eof()) { - string header; - ss >> header; - if (header == "path2qu") { - ss >> header; // reads "=" - ss >> path2qu; - continue; - } - - if (header == "path2ref") { - ss >> header; // reads "=" - ss >> path2ref; - continue; - } - - if (header == "querySize") { - ss >> header; // reads "=" - ss >> querySize; - } - - if (header == "nonMatchCost") { - ss >> header; // reads "=" - ss >> nonMatchCost; - continue; - } - - if (header == "expansionRate") { - ss >> header; // reads "=" - ss >> expansionRate; - continue; - } - if (header == "fanOut") { - ss >> header; // reads "=" - ss >> fanOut; - continue; - } - if (header == "bufferSize") { - ss >> header; // reads "=" - ss >> bufferSize; - continue; - } - - if (header == "path2quImg") { - ss >> header; // reads "=" - ss >> path2quImg; - continue; - } - - if (header == "path2refImg") { - ss >> header; // reads "=" - ss >> path2refImg; - continue; - } - if (header == "imgExt") { - ss >> header; // reads "=" - ss >> imgExt; - continue; - } - if (header == "costMatrix") { - ss >> header; // reads "=" - ss >> costMatrix; - continue; - } - if (header == "costOutputName") { - ss >> header; // reads "=" - ss >> costOutputName; - continue; - } - if (header == "simPlaces") { - ss >> header; // reads "=" - ss >> simPlaces; - continue; - } - } // end of line parsing - } // end of file - return true; + std::ifstream in(iniFile.c_str()); + if (!in) { + printf("[ERROR][ConfigParser] The file \"%s\" cannot be opened.\n", + iniFile.c_str()); + return false; + } + while (!in.eof()) { + string line; + std::getline(in, line); + if (line.empty() || line[0] == '#') { + // it should be a comment + continue; + } + std::stringstream ss(line); + while (!ss.eof()) { + string header; + ss >> header; + if (header == "path2qu") { + ss >> header; // reads "=" + ss >> path2qu; + continue; + } + + if (header == "path2ref") { + ss >> header; // reads "=" + ss >> path2ref; + continue; + } + + if (header == "querySize") { + ss >> header; // reads "=" + ss >> querySize; + } + + if (header == "nonMatchCost") { + ss >> header; // reads "=" + ss >> nonMatchCost; + continue; + } + + if (header == "expansionRate") { + ss >> header; // reads "=" + ss >> expansionRate; + continue; + } + if (header == "fanOut") { + ss >> header; // reads "=" + ss >> fanOut; + continue; + } + if (header == "bufferSize") { + ss >> header; // reads "=" + ss >> bufferSize; + continue; + } + + if (header == "path2quImg") { + ss >> header; // reads "=" + ss >> path2quImg; + continue; + } + + if (header == "path2refImg") { + ss >> header; // reads "=" + ss >> 
path2refImg; + continue; + } + if (header == "imgExt") { + ss >> header; // reads "=" + ss >> imgExt; + continue; + } + if (header == "costMatrix") { + ss >> header; // reads "=" + ss >> costMatrix; + continue; + } + if (header == "costOutputName") { + ss >> header; // reads "=" + ss >> costOutputName; + continue; + } + if (header == "simPlaces") { + ss >> header; // reads "=" + ss >> simPlaces; + continue; + } + } // end of line parsing + } // end of file + return true; } void ConfigParser::print() const { - printf("== Read parameters ==\n"); - printf("== Path2query: %s\n", path2qu.c_str()); - printf("== Path2ref: %s\n", path2ref.c_str()); - - printf("== Query size: %d\n", querySize); - printf("== NonMatchCost: %3.4f\n", nonMatchCost); - printf("== Expansion Rate: %3.4f\n", expansionRate); - printf("== FanOut: %d\n", fanOut); - - printf("== Path2query images: %s\n", path2quImg.c_str()); - printf("== Path2reference images: %s\n", path2refImg.c_str()); - printf("== Image extension: %s\n", imgExt.c_str()); - printf("== Buffer size: %d\n", bufferSize); - - printf("== CostMatrix: %s\n", costMatrix.c_str()); - printf("== costOutputName: %s\n", costOutputName.c_str()); - printf("== matchingResult: %s\n", matchingResult.c_str()); - printf("== simPlaces: %s\n", simPlaces.c_str()); + printf("== Read parameters ==\n"); + printf("== Path2query: %s\n", path2qu.c_str()); + printf("== Path2ref: %s\n", path2ref.c_str()); + + printf("== Query size: %d\n", querySize); + printf("== NonMatchCost: %3.4f\n", nonMatchCost); + printf("== Expansion Rate: %3.4f\n", expansionRate); + printf("== FanOut: %d\n", fanOut); + + printf("== Path2query images: %s\n", path2quImg.c_str()); + printf("== Path2reference images: %s\n", path2refImg.c_str()); + printf("== Image extension: %s\n", imgExt.c_str()); + printf("== Buffer size: %d\n", bufferSize); + + printf("== CostMatrix: %s\n", costMatrix.c_str()); + printf("== costOutputName: %s\n", costOutputName.c_str()); + printf("== matchingResult: %s\n", matchingResult.c_str()); + printf("== simPlaces: %s\n", simPlaces.c_str()); } bool ConfigParser::parseYaml(const std::string &yamlFile) { - YAML::Node config; - try { - config = YAML::LoadFile(yamlFile.c_str()); - } catch (...) 
{ - printf("[ERROR][ConfigParser] File %s cannot be opened\n", - yamlFile.c_str()); - return false; - } - if (config["path2ref"]) { - path2ref = config["path2ref"].as(); - } - if (config["path2qu"]) { - path2qu = config["path2qu"].as(); - } - if (config["querySize"]) { - querySize = config["querySize"].as(); - } - if (config["fanOut"]) { - fanOut = config["fanOut"].as(); - } - if (config["nonMatchCost"]) { - nonMatchCost = config["nonMatchCost"].as(); - } - if (config["expansionRate"]) { - expansionRate = config["expansionRate"].as(); - } - if (config["path2quImg"]) { - path2quImg = config["path2quImg"].as(); - } - - if (config["path2refImg"]) { - path2refImg = config["path2refImg"].as(); - } - if (config["imgExt"]) { - imgExt = config["imgExt"].as(); - } - if (config["bufferSize"]) { - bufferSize = config["bufferSize"].as(); - } - if (config["costMatrix"]) { - costMatrix = config["costMatrix"].as(); - } - if (config["costOutputName"]) { - costOutputName = config["costOutputName"].as(); - } - if (config["simPlaces"]) { - simPlaces = config["simPlaces"].as(); - } - - if (config["hashTable"]) { - hashTable = config["hashTable"].as(); - } - if (config["matchingResult"]) { - matchingResult = config["matchingResult"].as(); - } - - return true; + YAML::Node config; + try { + config = YAML::LoadFile(yamlFile.c_str()); + } catch (...) { + printf("[ERROR][ConfigParser] File %s cannot be opened\n", + yamlFile.c_str()); + return false; + } + if (config["path2ref"]) { + path2ref = config["path2ref"].as(); + } + if (config["path2qu"]) { + path2qu = config["path2qu"].as(); + } + if (config["querySize"]) { + querySize = config["querySize"].as(); + } + if (config["fanOut"]) { + fanOut = config["fanOut"].as(); + } + if (config["nonMatchCost"]) { + nonMatchCost = config["nonMatchCost"].as(); + } + if (config["expansionRate"]) { + expansionRate = config["expansionRate"].as(); + } + if (config["path2quImg"]) { + path2quImg = config["path2quImg"].as(); + } + + if (config["path2refImg"]) { + path2refImg = config["path2refImg"].as(); + } + if (config["imgExt"]) { + imgExt = config["imgExt"].as(); + } + if (config["bufferSize"]) { + bufferSize = config["bufferSize"].as(); + } + if (config["costMatrix"]) { + costMatrix = config["costMatrix"].as(); + } + if (config["costOutputName"]) { + costOutputName = config["costOutputName"].as(); + } + if (config["simPlaces"]) { + simPlaces = config["simPlaces"].as(); + } + + if (config["hashTable"]) { + hashTable = config["hashTable"].as(); + } + if (config["matchingResult"]) { + matchingResult = config["matchingResult"].as(); + } + if (config["expandedNodesFile"]) { + expandedNodesFile = config["expandedNodesFile"].as(); + } + + return true; } diff --git a/src/localization/tools/config_parser/config_parser.h b/src/localization/tools/config_parser/config_parser.h index b089ea6..fb77a64 100644 --- a/src/localization/tools/config_parser/config_parser.h +++ b/src/localization/tools/config_parser/config_parser.h @@ -30,28 +30,29 @@ * @brief Class for storing the configuration parameters. 
*/ class ConfigParser { - public: - ConfigParser() {} - bool parse(const std::string &iniFile); - bool parseYaml(const std::string &yamlFile); - void print() const; +public: + ConfigParser() {} + bool parse(const std::string &iniFile); + bool parseYaml(const std::string &yamlFile); + void print() const; - std::string path2qu = ""; - std::string path2ref = ""; - std::string path2quImg = ""; - std::string path2refImg = ""; - std::string imgExt = ""; - std::string costMatrix = ""; - std::string costOutputName = ""; - std::string simPlaces = ""; - std::string hashTable = ""; - std::string matchingResult = "matches.MatchingResult.pb"; + std::string path2qu = ""; + std::string path2ref = ""; + std::string path2quImg = ""; + std::string path2refImg = ""; + std::string imgExt = ""; + std::string costMatrix = ""; + std::string costOutputName = ""; + std::string simPlaces = ""; + std::string hashTable = ""; + std::string matchingResult = "matches.MatchingResult.pb"; + std::string expandedNodesFile = ""; - int querySize = -1; - int fanOut = -1; - int bufferSize = -1; - double nonMatchCost = -1.0; - double expansionRate = -1.0; + int querySize = -1; + int fanOut = -1; + int bufferSize = -1; + double nonMatchCost = -1.0; + double expansionRate = -1.0; }; /*! \var std::string ConfigParser::path2qu @@ -111,4 +112,4 @@ class ConfigParser { typically be selected from 0.5 - 0.7. */ -#endif // SRC_TOOLS_CONFIG_PARSER_CONFIG_PARSER_H_ +#endif // SRC_TOOLS_CONFIG_PARSER_CONFIG_PARSER_H_ diff --git a/src/localization_protos.proto b/src/localization_protos.proto index 0bc3152..fbf4427 100644 --- a/src/localization_protos.proto +++ b/src/localization_protos.proto @@ -26,6 +26,16 @@ message Feature { optional string type = 3; } +message ScanContext{ + message Grid{ + repeated double values = 1; + optional int32 cols = 2; + optional int32 rows = 3; + + } + repeated Grid grids = 1; +} + message Patch { message Element { optional int32 row = 1; diff --git a/src/python/convert_numpy_to_scan_context.py b/src/python/convert_numpy_to_scan_context.py new file mode 100644 index 0000000..fb18939 --- /dev/null +++ b/src/python/convert_numpy_to_scan_context.py @@ -0,0 +1,95 @@ +import numpy as np +from pathlib import Path +import argparse +import protos_io + +import protos.localization_protos_pb2 as loc_protos + + +def convert_to_protos(features): + """Converts the numpy nd array to the features protos. + + Args: + features (numpy.ndarray): feature vectors to be converted. + Expected size NxSxRxC, where N is + number of features and S is number + of shifts, R number of rows and C number of columns. + Returns: + [image_sequence_localizer.ScanContext]: list of feature protos. + For details check + localization_protos.proto file. + """ + protos = [] + print("Number of features", features.shape) + assert ( + len(features.shape) == 4 + ), "Expected a N x shifts x rows x cols dimensional matrix." + for feature in features: + scanContext_proto = loc_protos.ScanContext() + for shift in feature: + grid_proto = loc_protos.ScanContext.Grid() + grid_proto.rows = shift.shape[0] + grid_proto.cols = shift.shape[1] + for r in range(grid_proto.rows): + for c in range(grid_proto.cols): + grid_proto.values.extend([shift[r][c]]) + scanContext_proto.grids.extend([grid_proto]) + protos.append(scanContext_proto) + return protos + + +def save_protos_to_files(folder, protos, type, prefix): + if folder.exists(): + print("WARNING: the folder exists. 
Potentially overwritting the files.") + else: + folder.mkdir() + for idx, proto in enumerate(protos): + feature_idx = "{0:07d}".format(idx) + proto_file = "{prefix}_{feature_idx}.{type}.Feature.pb".format( + prefix=prefix, feature_idx=feature_idx, type=type + ) + protos_io.write_feature(folder / proto_file, proto) + print("Feature", idx, "was written to ", folder / proto_file) + return + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--filename", + required=True, + type=Path, + help="Path to file that contains features that can be loaded with np.load()", + ) + parser.add_argument( + "--feature_type", + required=True, + type=str, + help="Type of the features, e.g. ScanContext", + ) + parser.add_argument( + "--output_folder", required=True, type=Path, help="Path to output directory" + ) + parser.add_argument( + "--output_file_prefix", + required=False, + type=str, + default="scanContext", + help="Prefix for every feature file that will be generated", + ) + + args = parser.parse_args() + try: + features = np.load(args.filename) + except: + print("ERROR: Could not read features from", args.filename) + return + print("There are: ", len(features)) + protos = convert_to_protos(features) + save_protos_to_files( + args.output_folder, protos, args.feature_type, args.output_file_prefix + ) + + +if __name__ == "__main__": + main() diff --git a/src/python/protos_io.py b/src/python/protos_io.py index 9643194..60a5377 100644 --- a/src/python/protos_io.py +++ b/src/python/protos_io.py @@ -50,15 +50,10 @@ def read_matching_result(filename): return result_proto -def read_expanded_mask(expanded_patches_dir): - patch_files = list(expanded_patches_dir.glob("*.Patch.pb")) - patch_files.sort() - - mask = [] - for patch_file in patch_files: - f = open(patch_file, "rb") - patch_proto = loc_protos.Patch() - patch_proto.ParseFromString(f.read()) - f.close() - mask.extend(patch_proto.elements) - return mask +def read_expanded_mask(expanded_patches_file): + + f = open(expanded_patches_file, "rb") + patch_proto = loc_protos.Patch() + patch_proto.ParseFromString(f.read()) + f.close() + return patch_proto diff --git a/src/python/run_scan_context_matching.py b/src/python/run_scan_context_matching.py new file mode 100644 index 0000000..266a62e --- /dev/null +++ b/src/python/run_scan_context_matching.py @@ -0,0 +1,123 @@ +import argparse +from pathlib import Path +import os +import yaml + + +def parseParams(): + parser = argparse.ArgumentParser(description="Run image matching.") + parser.add_argument( + "--query_features", + type=Path, + required=True, + help="Path to the directory with images in .jpg or .png format", + ) + parser.add_argument( + "--reference_features", + type=Path, + required=True, + help="Path to the directory with images in .jpg or .png format", + ) + + parser.add_argument( + "--cost_matrix", type=Path, required=True, help="Path to precompute cost matrix" + ) + parser.add_argument( + "--dataset_name", + type=str, + required=True, + help="The name of the dataset.", + ) + parser.add_argument( + "--output_dir", + type=Path, + required=True, + help="Path to output directory to store results.", + ) + return parser.parse_args() + + +def setDictParam(args, query_features_dir, reference_features_dir, cost_matrix_dir): + params = dict() + params["path2qu"] = str(query_features_dir) + params["path2ref"] = str(reference_features_dir) + params["costMatrix"] = str(cost_matrix_dir) + params["matchingResult"] = str( + args.output_dir / (args.dataset_name + ".MatchingResult.pb") + ) + 
params["matchingResultImage"] = str( + args.output_dir / (args.dataset_name + "_result.png") + ) + params["expandedNodesFile"] = str( + args.output_dir / (args.dataset_name + "_expanded_nodes.Patch.pb") + ) + params["expansionRate"] = 0.7 + params["fanOut"] = 10 + params["nonMatchCost"] = 0.1 + params["bufferSize"] = 100 + + queriesNum = len(list(query_features_dir.glob("*Feature.pb"))) + params["querySize"] = queriesNum + return params + + +def runMatching(config_yaml_file): + binary = "../../build/src/apps/cost_matrix_based_matching/scan_context_matching" + command = binary + " " + str(config_yaml_file) + print("Calling:", command) + os.system(command) + + +def runResultVisualization(config): + params = "--cost_matrix {cost_matrix} ".format(cost_matrix=config["costMatrix"]) + params += "--matching_result {matching_result} ".format( + matching_result=config["matchingResult"] + ) + params += "--image_name {image_name} ".format( + image_name=config["matchingResultImage"] + ) + + # params += "--expanded_patches_file {file} ".format(file=config["expandedNodesFile"]) + command = "python visualize_localization_result.py " + params + print("Calling:", command) + os.system(command) + + +def runMatchingResultVisualization(config, output_dir): + params = "--matching_result {matching_result} ".format( + matching_result=config["matchingResult"] + ) + params += "--query_images {query_images} ".format( + query_images=config["path2query_images"] + ) + params += "--reference_images {ref_images} ".format( + ref_images=config["path2ref_images"] + ) + params += "--output_dir {output_dir}/matched_images ".format(output_dir=output_dir) + + command = "python visualize_matching_result.py " + params + print("Calling:", command) + os.system(command) + + +def main(): + args = parseParams() + + if args.output_dir.exists(): + print("WARNING: output_dir exists. Overwritting the results") + else: + args.output_dir.mkdir() + + yaml_config = setDictParam( + args, args.query_features, args.reference_features, args.cost_matrix + ) + + yaml_config_file = args.output_dir / (args.dataset_name + "_config.yml") + with open(yaml_config_file, "w") as file: + yaml.dump(yaml_config, file) + runMatching(yaml_config_file) + runResultVisualization(yaml_config) + + +if __name__ == "__main__": + main() diff --git a/src/python/visualize_localization_result.py b/src/python/visualize_localization_result.py index 48c9e0f..587fc25 100644 --- a/src/python/visualize_localization_result.py +++ b/src/python/visualize_localization_result.py @@ -14,7 +14,7 @@ def create_combined_image(matching_result, cost_matrix, expanded_mask=None): # Add expanded nodes. if expanded_mask: - for element in expanded_mask: + for element in expanded_mask.elements: rgb_costs[element.row, element.col] = [0, 1, 0] # Add path with color. Red - real nodes, Blue - hidden nodes. 
@@ -45,10 +45,10 @@ def main(): help="Path to the matching result .MatchingResult.pb file", ) parser.add_argument( - "--expanded_patches_dir", + "--expanded_patches_file", required=False, type=Path, - help="Path to directory with expanded nodes files of type .Patch.pb", + help="A file with expanded nodes files of type .Patch.pb", ) parser.add_argument( "--image_name", @@ -76,8 +76,8 @@ def main(): matching_result = protos_io.read_matching_result(args.matching_result) - if args.expanded_patches_dir: - expanded_mask = protos_io.read_expanded_mask(args.expanded_patches_dir) + if args.expanded_patches_file: + expanded_mask = protos_io.read_expanded_mask(args.expanded_patches_file) else: expanded_mask = None diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 6273007..e3e9d4f 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -8,6 +8,17 @@ set(TESTNAME test_${PROJECT_NAME}) # cxx_flags # ) +add_executable(${TESTNAME}_scan_context + scan_context_test.cpp +) +target_link_libraries(${TESTNAME}_scan_context + scan_context + protos + gtest + gtest_main + cxx_flags +) + add_executable(${TESTNAME} cost_matrix_test.cpp database_test.cpp @@ -42,3 +53,4 @@ target_link_libraries(${TESTNAME}_successor_manager gtest_discover_tests(${TESTNAME}) gtest_discover_tests(${TESTNAME}_successor_manager) +gtest_discover_tests(${TESTNAME}_scan_context) diff --git a/src/test/cost_matrix_test.cpp b/src/test/cost_matrix_test.cpp index 303095a..2893b02 100644 --- a/src/test/cost_matrix_test.cpp +++ b/src/test/cost_matrix_test.cpp @@ -60,6 +60,7 @@ TEST_F(CostMatrixTest, FailedToConstruct) { TEST_F(CostMatrixTest, at) { auto costMatrix = localization::database::CostMatrix(costMatrixFile); + costMatrix.inverseCosts(false); localization::database::CostMatrix::Matrix expectedMatrix = this->costMatrixValues; EXPECT_EQ(costMatrix.rows(), expectedMatrix.size()); @@ -88,6 +89,7 @@ TEST(CostMatrixComputation, createCostMatrixFromFeatures) { auto costMatrix = localization::database::CostMatrix( tmp_dir, tmp_dir, localization::features::Cnn_Feature); + costMatrix.inverseCosts(false); for (int r = 0; r < costMatrix.rows(); ++r) { for (int c = 0; c < costMatrix.cols(); ++c) { diff --git a/src/test/scan_context_test.cpp b/src/test/scan_context_test.cpp new file mode 100644 index 0000000..2bc7b66 --- /dev/null +++ b/src/test/scan_context_test.cpp @@ -0,0 +1,57 @@ +/* By O. 
Vysotska in 2023 */ + +#include "features/scan_context.h" + +#include + +#include "gtest/gtest.h" +#include +#include + +namespace test { + +TEST(ComputeDistanceBetweenGrids, validInput) { + Eigen::MatrixXd lhs(2, 3); + lhs << 1, 0, 0, 1, 1, 0; + Eigen::MatrixXd rhs(2, 3); + rhs << 1, 1, 1, 1, 0, 0; + + EXPECT_NEAR(localization::features::computeDistanceBetweenGrids(lhs, rhs), + 1.0 / 3, 1e-02); +} + +TEST(ComputeDistanceBetweenGrids, invalidInput) { + Eigen::MatrixXd rhs(2, 3); + rhs << 1, 1, 1, 1, 0, 0; + + EXPECT_DEATH(localization::features::computeDistanceBetweenGrids({}, rhs), + "Invalid grid size."); + + Eigen::MatrixXd lhs(1, 3); + lhs << 1, 1, 1; + + EXPECT_DEATH(localization::features::computeDistanceBetweenGrids(lhs, rhs), + "The number of rows is not the same."); + + Eigen::MatrixXd lhs_wrong(2, 2); + lhs_wrong << 1, 1, 1, 1; + + EXPECT_DEATH( + localization::features::computeDistanceBetweenGrids(lhs_wrong, rhs), + "The number of columns is not the same."); +} + +TEST(ComputeRingKey, validInput) { + Eigen::MatrixXd grid(3, 4); + grid << 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0; + const Eigen::VectorXd ringKey = localization::features::computeRingKey(grid); + ASSERT_EQ(ringKey.size(), 3); + EXPECT_NEAR(ringKey[0], 0.5, 1e-02); + EXPECT_NEAR(ringKey[1], 0.25, 1e-02); + EXPECT_NEAR(ringKey[2], 0.0, 1e-02); +} + +// TODO(olga): Add tests for +// distanceBetweenRingKeys + +} // namespace test \ No newline at end of file
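+
+// A possible starting point for the TODO above, sketched under the assumption that
+// distanceBetweenRingKeys returns the minimum pairwise cosine distance between the
+// two sets of ring keys, as implemented in scan_context.cpp:
+TEST(DistanceBetweenRingKeys, validInput) {
+  Eigen::VectorXd keyA(3);
+  keyA << 1, 0, 0;
+  Eigen::VectorXd keyB(3);
+  keyB << 0, 1, 0;
+  // keyB is present in both sets, so the minimum pairwise distance is 0.
+  EXPECT_NEAR(localization::features::distanceBetweenRingKeys({keyA, keyB}, {keyB}),
+              0.0, 1e-09);
+  // Orthogonal ring keys have cosine distance 1.
+  EXPECT_NEAR(localization::features::distanceBetweenRingKeys({keyA}, {keyB}), 1.0,
+              1e-09);
+}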