From 4246b8b75f53e9f2bbe392471b710762f972d918 Mon Sep 17 00:00:00 2001 From: RT_Enzyme Date: Sun, 22 Sep 2024 11:50:39 +0000 Subject: [PATCH 01/12] compilation execution framework. --- deps/buildit | 1 + src/BuildCypherLib.cmake | 9 + .../experimental/data_type/field_data.h | 242 ++++++++++++++++++ src/cypher/experimental/data_type/record.h | 80 ++++++ src/cypher/experimental/expressions/cexpr.cpp | 2 + src/cypher/experimental/expressions/cexpr.h | 229 +++++++++++++++++ .../expressions/kernal/binary.cpp | 181 +++++++++++++ toolkits/CMakeLists.txt | 14 + toolkits/lgraph_compilation.cpp | 54 ++++ 9 files changed, 812 insertions(+) create mode 160000 deps/buildit create mode 100644 src/cypher/experimental/data_type/field_data.h create mode 100644 src/cypher/experimental/data_type/record.h create mode 100644 src/cypher/experimental/expressions/cexpr.cpp create mode 100644 src/cypher/experimental/expressions/cexpr.h create mode 100644 src/cypher/experimental/expressions/kernal/binary.cpp create mode 100644 toolkits/lgraph_compilation.cpp diff --git a/deps/buildit b/deps/buildit new file mode 160000 index 0000000000..734725213b --- /dev/null +++ b/deps/buildit @@ -0,0 +1 @@ +Subproject commit 734725213b28694524ae06a8ef958f5444837dae diff --git a/src/BuildCypherLib.cmake b/src/BuildCypherLib.cmake index 9928aed2a1..90e47c1178 100644 --- a/src/BuildCypherLib.cmake +++ b/src/BuildCypherLib.cmake @@ -56,6 +56,9 @@ set(LGRAPH_CYPHER_SRC # find cypher/ -name "*.cpp" | sort cypher/execution_plan/ops/op_traversal.cpp cypher/execution_plan/ops/op_gql_remove.cpp cypher/execution_plan/scheduler.cpp + cypher/experimental/data_type/field_data.h + cypher/experimental/expressions/cexpr.cpp + cypher/experimental/expressions/kernal/binary.cpp cypher/filter/filter.cpp cypher/filter/iterator.cpp cypher/graph/graph.cpp @@ -88,9 +91,15 @@ target_include_directories(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC ${ANTLR4_INCLUDE_DIR} ${CMAKE_CURRENT_LIST_DIR}/cypher) 
+include_directories(${CMAKE_SOURCE_DIR}/deps/buildit/include) + +target_link_directories(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC + ${CMAKE_SOURCE_DIR}/deps/buildit/lib) + target_link_libraries(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC ${ANTRL4_LIBRARY} geax_isogql + ${CMAKE_SOURCE_DIR}/deps/buildit/build/libbuildit.a lgraph) target_link_libraries(${TARGET_LGRAPH_CYPHER_LIB} PRIVATE diff --git a/src/cypher/experimental/data_type/field_data.h b/src/cypher/experimental/data_type/field_data.h new file mode 100644 index 0000000000..ddf4c19589 --- /dev/null +++ b/src/cypher/experimental/data_type/field_data.h @@ -0,0 +1,242 @@ + +#pragma once +#include +#include +#include +#include +#include "core/data_type.h" +#include "cypher/cypher_types.h" +#include "cypher/cypher_exception.h" + +using builder::static_var; +using builder::dyn_var; +using lgraph::FieldType; + +namespace cypher { +namespace compilation { +#define GET_CONSTANT(type) \ + if (is_constant_) { \ + return std::get>(constant_); \ + } else { \ + return std::get>(dyn_); \ + } + +struct CScalarData { + static constexpr const char* type_name = "CScalarData"; + std::variant< + std::monostate, // Represent the null state + static_var, + static_var, + static_var, + static_var, + static_var + > constant_; + + std::variant< + std::monostate, // Represent the null state + dyn_var, + dyn_var, + dyn_var, + dyn_var, + dyn_var + > dyn_; + + bool is_constant_; + // for string type + lgraph::FieldType type_; + + CScalarData() { + type_ = lgraph_api::FieldType::NUL; + is_constant_ = false; + } + + CScalarData(CScalarData &&data) + : constant_(std::move(data.constant_)), dyn_(std::move(data.dyn_)), + is_constant_(data.is_constant_), type_(data.type_) {} + + CScalarData(const CScalarData& other) + : constant_(other.constant_), dyn_(other.dyn_), + is_constant_(other.is_constant_), type_(other.type_) {} + + CScalarData(const lgraph::FieldData& other, bool is_constant = false) { + type_ = other.type; + is_constant_ = is_constant; + if 
(is_constant_) { + switch (other.type) { + case lgraph::FieldType::NUL: + constant_.emplace(); + break; + case lgraph::FieldType::INT64: + constant_.emplace>((long)other.integer()); + break; + default: + CYPHER_TODO(); + } + } else { + switch (other.type) { + case lgraph::FieldType::NUL: + dyn_.emplace(); + break; + case lgraph::FieldType::INT64: + dyn_.emplace>((long)other.integer()); + break; + default: + CYPHER_TODO(); + } + + } + } + + explicit CScalarData(long integer, bool is_constant = true) { + is_constant_ = is_constant; + if (is_constant) { + constant_.emplace>(integer); + } else { + dyn_.emplace>(integer); + } + type_ = lgraph::FieldType::INT64; + } + + explicit CScalarData(const static_var &integer) + : constant_(integer), is_constant_(true), type_(FieldType::INT64) {} + + explicit CScalarData(const dyn_var &integer) + : dyn_(integer), is_constant_(false), type_(FieldType::INT64) {} + + explicit CScalarData(dyn_var&& integer) + : dyn_(std::move(integer)), is_constant_(false), type_(FieldType::INT64) {} + + inline dyn_var integer() const { + switch (type_) { + case FieldType::NUL: + case FieldType::BOOL: + throw std::bad_cast(); + case FieldType::INT8: GET_CONSTANT(short) + case FieldType::INT16: GET_CONSTANT(short) + case FieldType::INT32: GET_CONSTANT(int) + case FieldType::INT64: GET_CONSTANT(long) + case FieldType::FLOAT: + case FieldType::DOUBLE: + case FieldType::DATE: + case FieldType::DATETIME: + case FieldType::STRING: + case FieldType::BLOB: + case FieldType::POINT: + case FieldType::LINESTRING: + case FieldType::POLYGON: + case FieldType::SPATIAL: + case FieldType::FLOAT_VECTOR: + throw std::bad_cast(); + } + return dyn_var(0); + } + + inline dyn_var real() const { + switch (type_) { + case FieldType::NUL: + case FieldType::BOOL: + case FieldType::INT8: + case FieldType::INT16: + case FieldType::INT32: + case FieldType::INT64: + throw std::bad_cast(); + case FieldType::FLOAT: + GET_CONSTANT(float); + case FieldType::DOUBLE: + 
GET_CONSTANT(double); + case FieldType::DATE: + case FieldType::DATETIME: + case FieldType::STRING: + case FieldType::BLOB: + case FieldType::POINT: + case FieldType::LINESTRING: + case FieldType::POLYGON: + case FieldType::SPATIAL: + case FieldType::FLOAT_VECTOR: + throw std::bad_cast(); + } + return dyn_var(0); + } + + dyn_var Int64() const { + GET_CONSTANT(long) + } + + inline bool is_integer() const { + return type_ >= FieldType::INT8 && type_ <= FieldType::INT64; + } + + inline bool is_real() const { + return type_ == FieldType::DOUBLE || type_ == FieldType::FLOAT; + } + + bool is_null() const { return type_ == lgraph::FieldType::NUL; } + + bool is_string() const { return type_ == lgraph::FieldType::STRING; } + + CScalarData& operator=(CScalarData&& other) noexcept { + if (this != &other) { + constant_ = std::move(other.constant_); + dyn_ = std::move(other.dyn_); + type_ = std::move(other.type_); + is_constant_ = std::move(other.is_constant_); + } + return *this; + } + + CScalarData& operator=(const CScalarData& other) { + if (this != &other) { + constant_ = other.constant_; + dyn_ = other.dyn_; + type_ = other.type_; + is_constant_ = other.is_constant_; + } + return *this; + } + + CScalarData operator+(const CScalarData& other) const; +}; + +struct CFieldData { + enum FieldType { SCALAR, ARRAY, MAP} type; + + CScalarData scalar; + std::vector* array = nullptr; + std::unordered_map* map = nullptr; + + CFieldData() : type(SCALAR) {} + + CFieldData(const CFieldData &data) : type(data.type), scalar(data.scalar) {} + + CFieldData(const CScalarData& scalar) : type(SCALAR), scalar(scalar) {} + + CFieldData(CScalarData&& scalar) : type(SCALAR), scalar(std::move(scalar)) {} + + CFieldData& operator=(const CFieldData& data) { + this->type = data.type; + this->scalar = data.scalar; + return *this; + } + + CFieldData& operator=(CFieldData&& data) { + this->type = std::move(data.type); + this->scalar = std::move(data.scalar); + return *this; + } + + explicit 
CFieldData(const static_var& scalar) : type(SCALAR), scalar(scalar) {} + + bool is_null() const { return type == SCALAR && scalar.is_null(); } + + bool is_string() const { return type == SCALAR && scalar.is_string(); } + + bool is_integer() const { return type == SCALAR && scalar.is_integer(); } + + bool is_real() const { return type == SCALAR && scalar.is_real(); } + + CFieldData operator+(const CFieldData& other) const; + + CFieldData operator-(const CFieldData& other) const; +}; +} // namespace compilation +} // namespace cypher \ No newline at end of file diff --git a/src/cypher/experimental/data_type/record.h b/src/cypher/experimental/data_type/record.h new file mode 100644 index 0000000000..bbce67fee4 --- /dev/null +++ b/src/cypher/experimental/data_type/record.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include "core/data_type.h" +#include "cypher/cypher_types.h" +#include "cypher/cypher_exception.h" +#include "parser/data_typedef.h" +#include "graph/node.h" +#include "graph/relationship.h" +#include "cypher/resultset/record.h" + +#include "experimental/data_type/field_data.h" +using builder::dyn_var; + +namespace cypher { + +struct SymbolTable; +class RTContext; + +namespace compilation { +struct CEntry { + compilation::CFieldData constant_; + cypher::Node* node_ = nullptr; + cypher::Relationship* relationship_ = nullptr; + + enum RecordEntryType { + UNKNOWN = 0, + CONSTANT, + NODE, + RELATIONSHIP, + VAR_LEN_RELP, + HEADER, // TODO(anyone) useless? 
+ NODE_SNAPSHOT, + RELP_SNAPSHOT, + } type_; + + CEntry() = default; + + explicit CEntry(const cypher::Entry& entry) { + switch (entry.type) { + case cypher::Entry::CONSTANT: { + constant_ = CScalarData(entry.constant.scalar); + type_ = CONSTANT; + break; + } + case cypher::Entry::NODE: { + node_ = entry.node; + type_ = NODE; + break; + } + case cypher::Entry::RELATIONSHIP: { + relationship_ = entry.relationship; + type_ = RELATIONSHIP; + break; + } + default: + CYPHER_TODO(); + } + } + + explicit CEntry(const CFieldData &data) : constant_(data), type_(CONSTANT) {} + + explicit CEntry(CFieldData&& data) : constant_(std::move(data)), type_(CONSTANT) {} + + explicit CEntry(const CScalarData& scalar) : constant_(scalar), type_(CONSTANT) {} +}; + +struct CRecord { // Should be derived from cypher::Record + std::vector values; + + CRecord() = default; + + CRecord(const cypher::Record &record) { + for (auto& entry : record.values) { + values.emplace_back(entry); + } + } +}; +} // namespace compilaiton +} // namespace cypher \ No newline at end of file diff --git a/src/cypher/experimental/expressions/cexpr.cpp b/src/cypher/experimental/expressions/cexpr.cpp new file mode 100644 index 0000000000..78ed26da91 --- /dev/null +++ b/src/cypher/experimental/expressions/cexpr.cpp @@ -0,0 +1,2 @@ +#include "cexpr.h" + diff --git a/src/cypher/experimental/expressions/cexpr.h b/src/cypher/experimental/expressions/cexpr.h new file mode 100644 index 0000000000..395dea5f08 --- /dev/null +++ b/src/cypher/experimental/expressions/cexpr.h @@ -0,0 +1,229 @@ +#include +#include + +#include +#include +#include +#include + +#include "geax-front-end/ast/Ast.h" +#include "geax-front-end/ast/expr/AggFunc.h" +#include "cypher/arithmetic/agg_ctx.h" +#include "cypher/arithmetic/ast_agg_expr_detector.h" +#include "cypher/execution_plan/visitor/visitor.h" +#include "cypher/resultset/record.h" +#include "cypher/parser/symbol_table.h" +#include "cypher/cypher_types.h" +#include "core/data_type.h" + 
+#include "experimental/data_type/field_data.h" +#include "experimental/data_type/record.h" +#include "cypher/execution_plan/runtime_context.h" + +using builder::static_var; +using builder::dyn_var; +using builder::defer_init; +using builder::with_name; + +namespace cypher { +namespace compilation { + +// struct ArithOperandNode { +// static constexpr const char* type_name = "ArithOperandNode"; +// CScalarData constant; +// struct Variadic { +// static_var alias; +// static_var alias_idx; +// static_var entity_prop; +// } variadic; +// struct Variable { +// bool hasMapFieldName; +// std::string _value_alias; +// std::string _map_field_name; +// } variable; + +// enum ArithOperandType { +// AR_OPERAND_CONSTANT, +// AR_OPERAND_VARIADIC, +// AR_OPERAND_PARAMETER, +// AR_OPERAND_VARIABLE, +// } type; + +// ArithOperandNode() = default; + +// ArithOperandNode(CScalarData &&data) : constant(std::move(data)) { +// std::cout<<"use move constructor"< +void checkedAnyCast(const std::any& s, TargetType& d) { + try { + d = std::any_cast(s); + } catch (...) 
{ + // TODO(lingsu): remove in future + assert(false); + } +} + +class ExprEvaluator : public geax::frontend::AstExprNodeVisitorImpl { + public: + ExprEvaluator() = delete; + + ExprEvaluator(geax::frontend::Expr* expr, const SymbolTable* sym_tab) + : expr_(expr), sym_tab_(sym_tab) {} + + ~ExprEvaluator() = default; + + std::vector agg_exprs_; + std::vector> agg_ctxs_; + size_t agg_pos_; + + enum class VisitMode { + EVALUATE, + AGGREGATE, + } visit_mode_; + + CEntry Evaluate(RTContext* ctx, const CRecord* record) { + ctx_ = ctx; + record_ = record; + agg_pos_ = 0; + visit_mode_ = VisitMode::EVALUATE; + CEntry entry; + checkedAnyCast(expr_->accept(*this), entry); + return entry; + } + + void Aggregate(RTContext* ctx, const CRecord* record) { + ctx_ = ctx; + record_ = record; + visit_mode_ = VisitMode::AGGREGATE; + if (agg_exprs_.empty()) { + agg_exprs_ = AstAggExprDetector::GetAggExprs(expr_); + } + for (size_t i = 0; i < agg_exprs_.size(); i++) { + agg_pos_ = i; + agg_exprs_[i]->accept(*this); + } + } + + void Reduce() { + for (auto agg_ctx : agg_ctxs_) { + agg_ctx->ReduceNext(); + } + } + + geax::frontend::Expr* GetExpression() { + return expr_; + } + + private: + std::any visit(geax::frontend::GetField* node) override; + std::any visit(geax::frontend::TupleGet* node) override; + std::any visit(geax::frontend::Not* node) override; + std::any visit(geax::frontend::Neg* node) override; + std::any visit(geax::frontend::Tilde* node) override; + std::any visit(geax::frontend::VSome* node) override; + std::any visit(geax::frontend::BEqual* node) override; + std::any visit(geax::frontend::BNotEqual* node) override; + std::any visit(geax::frontend::BGreaterThan* node) override; + std::any visit(geax::frontend::BNotSmallerThan* node) override; + std::any visit(geax::frontend::BSmallerThan* node) override; + std::any visit(geax::frontend::BNotGreaterThan* node) override; + std::any visit(geax::frontend::BSafeEqual* node) override; + std::any visit(geax::frontend::BAdd* node) 
override; + std::any visit(geax::frontend::BSub* node) override; + std::any visit(geax::frontend::BDiv* node) override; + std::any visit(geax::frontend::BMul* node) override; + std::any visit(geax::frontend::BMod* node) override; + std::any visit(geax::frontend::BSquare* node) override; + std::any visit(geax::frontend::BAnd* node) override; + std::any visit(geax::frontend::BOr* node) override; + std::any visit(geax::frontend::BXor* node) override; + std::any visit(geax::frontend::BBitAnd* node) override; + std::any visit(geax::frontend::BBitOr* node) override; + std::any visit(geax::frontend::BBitXor* node) override; + std::any visit(geax::frontend::BBitLeftShift* node) override; + std::any visit(geax::frontend::BBitRightShift* node) override; + std::any visit(geax::frontend::BConcat* node) override; + std::any visit(geax::frontend::BIndex* node) override; + std::any visit(geax::frontend::BLike* node) override; + std::any visit(geax::frontend::BIn* node) override; + std::any visit(geax::frontend::If* node) override; + std::any visit(geax::frontend::Function* node) override; + std::any visit(geax::frontend::Case* node) override; + std::any visit(geax::frontend::Cast* node) override; + std::any visit(geax::frontend::MatchCase* node) override; + std::any visit(geax::frontend::AggFunc* node) override; + std::any visit(geax::frontend::BAggFunc* node) override; + std::any visit(geax::frontend::MultiCount* node) override; + std::any visit(geax::frontend::Windowing* node) override; + std::any visit(geax::frontend::MkList* node) override; + std::any visit(geax::frontend::MkMap* node) override; + std::any visit(geax::frontend::MkRecord* node) override; + std::any visit(geax::frontend::MkSet* node) override; + std::any visit(geax::frontend::MkTuple* node) override; + std::any visit(geax::frontend::VBool* node) override; + std::any visit(geax::frontend::VInt* node) override; + std::any visit(geax::frontend::VDouble* node) override; + std::any visit(geax::frontend::VString* 
node) override; + std::any visit(geax::frontend::VDate* node) override; + std::any visit(geax::frontend::VDatetime* node) override; + std::any visit(geax::frontend::VDuration* node) override; + std::any visit(geax::frontend::VTime* node) override; + std::any visit(geax::frontend::VNull* node) override; + std::any visit(geax::frontend::VNone* node) override; + std::any visit(geax::frontend::Ref* node) override; + std::any visit(geax::frontend::Param* node) override; + std::any visit(geax::frontend::SingleLabel* node) override; + std::any visit(geax::frontend::LabelOr* node) override; + std::any visit(geax::frontend::IsLabeled* node) override; + std::any visit(geax::frontend::IsNull* node) override; + std::any visit(geax::frontend::ListComprehension* node) override; + std::any visit(geax::frontend::Exists* node) override; + + std::any reportError() override; + + private: + std::string error_msg_; + geax::frontend::Expr* expr_; + RTContext* ctx_; + const SymbolTable* sym_tab_; + const CRecord* record_; + std::shared_ptr agg_func_; +}; + +struct CExprNode { + static constexpr const char* type_name = "ArithExprNode"; + // ArithOperandNode operand_; + // ArithExprNode* left_; + // ArithExprNode* right_; + std::shared_ptr evaluator_; + // OpType op_; + + CExprNode() = default; + + inline CEntry Eval(cypher::RTContext *ctx, const CRecord &record) { + return evaluator_->Evaluate(ctx, &record); + } +}; +} // namespace compilation +} // namespace cypher \ No newline at end of file diff --git a/src/cypher/experimental/expressions/kernal/binary.cpp b/src/cypher/experimental/expressions/kernal/binary.cpp new file mode 100644 index 0000000000..187559715d --- /dev/null +++ b/src/cypher/experimental/expressions/kernal/binary.cpp @@ -0,0 +1,181 @@ +#include +#include +#include "cypher/cypher_types.h" +#include "core/data_type.h" + +#include "cypher/cypher_exception.h" +#include "cypher/experimental/data_type/field_data.h" +#include "cypher/experimental/data_type/record.h" +#include 
"cypher/utils/geax_util.h" +#include "cypher/experimental/expressions/cexpr.h" + +namespace cypher { +namespace compilation { +CFieldData CFieldData::operator+(const CFieldData &other) const { + if (is_null() || other.is_null()) return CFieldData(); + CFieldData ret; + if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { + CYPHER_TODO(); + } else if (is_string() || other.is_string()) { + CYPHER_TODO(); + } else if ((is_integer() || is_real()) && (other.is_integer() || other.is_real())) { + if (is_integer() && other.is_integer()) { + ret.scalar = CScalarData(scalar.Int64() + other.scalar.Int64()); + std::cout<<"ret"< x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); + dyn_var y_n = is_integer()? (dyn_var)other.scalar.integer() : other.scalar.real(); + ret.scalar = std::move(CScalarData(x_n + y_n)); + } + } + return ret; +} + +CFieldData CFieldData::operator-(const CFieldData &other) const { + if (is_null() || other.is_null()) return CFieldData(); + CFieldData ret; + if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { + CYPHER_TODO(); + } else if (is_string() || other.is_string()) { + CYPHER_TODO(); + } else if ((is_integer() || is_real()) && (other.is_integer() || other.is_real())) { + if (is_integer() && other.is_integer()) { + ret.scalar = CScalarData(scalar.Int64() - other.scalar.Int64()); + } else { + dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); + dyn_var y_n = is_integer()? 
(dyn_var)other.scalar.integer() : other.scalar.real(); + ret.scalar = std::move(CScalarData(x_n - y_n)); + } + } + return ret; +} + +static CFieldData add(const CFieldData& x, const CFieldData& y) { + return x + y; +} + +static CFieldData sub(const CFieldData& x, const CFieldData& y) { + return x - y; +} + +static CFieldData div(const CFieldData& x, const CFieldData y) { + if (is_null() || other.is_null()) return CFieldData(); + if (!(x.is_integer() || x.is_real()) || !(y.is_integer() || y.is_real())) + throw lgraph::CypherException("Type mismatch: expect Integer or Float in div expr"); + CFieldData ret; + if (is_integer() && other.is_integer()) { + dyn_var x_n = x.scalar.integer(); + dyn_var y_n = y.scalar.integer(); + if (y_n == 0) throw lgraph::CypherException("divide by zero"); + ret.scalar = std::move(CScalarData(x_n / y_n)); + } else { + dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); + dyn_var y_n = is_integer()? (dyn_var)other.scalar.integer() : other.scalar.real(); + if (y_n == 0) CYPHER_TODO(); + ret.scalar = std::move(CScalarData(x_n - y_n)); + } + return ret; +} + +#ifndef DO_BINARY_EXPR +#define DO_BINARY_EXPR(func) \ + auto lef = std::any_cast(node->left()->accept(*this)); \ + auto rig = std::any_cast(node->right()->accept(*this)); \ + if (lef.type_ != CEntry::RecordEntryType::CONSTANT || \ + rig.type_ != CEntry::RecordEntryType::CONSTANT) { \ + NOT_SUPPORT_AND_THROW(); \ + } \ + return CEntry(func(lef.constant_, rig.constant_)); +#endif + +std::any ExprEvaluator::visit(geax::frontend::BAdd* node) { DO_BINARY_EXPR(add); } + +std::any ExprEvaluator::visit(geax::frontend::Ref* node) { + auto it = sym_tab_->symbols.find(node->name()); + if (it == sym_tab_->symbols.end()) NOT_SUPPORT_AND_THROW(); + switch (it->second.type) { + case SymbolNode::NODE: + case SymbolNode::RELATIONSHIP: + case SymbolNode::CONSTANT: + case SymbolNode::PARAMETER: + return record_->values[it->second.id]; + case SymbolNode::NAMED_PATH: + { + // auto it = 
sym_tab_->anot_collection.path_elements.find(node->name()); + // if (it == sym_tab_->anot_collection.path_elements.end()) + // throw lgraph::CypherException("path_elements error: " + node->name()) + // const std::vector>& elements = it->second; + // std::vector params; + // for (auto ref: elements) { + // params.emplace_back(ref.get(), *sym_tab_); + // } + CYPHER_TODO(); + } + } + return std::any(); +} + +std::any ExprEvaluator::visit(geax::frontend::GetField* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::TupleGet* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Not* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Neg* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Tilde* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VSome* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BEqual* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BNotEqual* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BGreaterThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BNotSmallerThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSmallerThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BNotGreaterThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSafeEqual* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSub* node) { DO_BINARY_EXPR(sub); } +std::any ExprEvaluator::visit(geax::frontend::BDiv* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BMul* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BMod* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSquare* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BAnd* node) { CYPHER_TODO(); } +std::any 
ExprEvaluator::visit(geax::frontend::BOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BXor* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitAnd* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitXor* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitLeftShift* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitRightShift* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BConcat* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BIndex* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BLike* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BIn* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::If* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Function* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Case* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Cast* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MatchCase* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::AggFunc* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BAggFunc* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MultiCount* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Windowing* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkList* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkMap* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkRecord* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkSet* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkTuple* node) { 
CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VBool* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VInt* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDouble* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VString* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDate* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDatetime* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDuration* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VTime* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VNull* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VNone* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Param* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::SingleLabel* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::LabelOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::IsLabeled* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::IsNull* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::ListComprehension* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Exists* node) { CYPHER_TODO(); } +std::any ExprEvaluator::reportError() { CYPHER_TODO(); } +} // namespace compilation +} // namepsace cypher \ No newline at end of file diff --git a/toolkits/CMakeLists.txt b/toolkits/CMakeLists.txt index f21862c42c..9e6394250a 100644 --- a/toolkits/CMakeLists.txt +++ b/toolkits/CMakeLists.txt @@ -1,6 +1,20 @@ cmake_minimum_required(VERSION 3.13) project(TuGraph C CXX) +############### lgraph_compilation ################ +set(TARGET_LGRAPH_COMPILATION lgraph_compilation) + +add_executable(${TARGET_LGRAPH_COMPILATION} + lgraph_compilation.cpp) + 
+target_include_directories(${TARGET_LGRAPH_COMPILATION} PUBLIC + ${CMAKE_SOURCE_DIR}/deps/buildit/include) + +target_link_libraries(${TARGET_LGRAPH_COMPILATION} + lgraph_cypher_lib + ${CMAKE_SOURCE_DIR}/deps/buildit/build/libbuildit.a + librocksdb.a) + ############### lgraph_import ###################### set(TARGET_LGRAPH_IMPORT lgraph_import) diff --git a/toolkits/lgraph_compilation.cpp b/toolkits/lgraph_compilation.cpp new file mode 100644 index 0000000000..6c9d644423 --- /dev/null +++ b/toolkits/lgraph_compilation.cpp @@ -0,0 +1,54 @@ +#include +#include "cypher/experimental/data_type/field_data.h" +#include "cypher/experimental/data_type/record.h" +#include "cypher/experimental/expressions/cexpr.h" +#include "cypher/parser/symbol_table.h" +#include "cypher/execution_plan/runtime_context.h" +#include "geax-front-end/ast/Ast.h" +#include "blocks/c_code_generator.h" +#include "builder/builder.h" +#include "builder/builder_context.h" +#include "builder/dyn_var.h" +using namespace cypher::compilation; +using builder::static_var; +using builder::dyn_var; + +dyn_var bar(void) { + std::variant, static_var> a; + std::variant, dyn_var> b; + a = (std::variant, static_var>)static_var(10); + b = dyn_var(10); + auto res = std::get>(a) + std::get>(b); + return res; +} + +dyn_var foo(void) { + cypher::SymbolTable sym_tab; + + CFieldData a(std::move(CScalarData(10, false))); + geax::frontend::Ref ref1; + ref1.setName(std::string("a")); + sym_tab.symbols.emplace("a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + + CFieldData b(static_var(10)); + geax::frontend::Ref ref2; + ref2.setName(std::string("b")); + sym_tab.symbols.emplace("b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + + geax::frontend::BAdd add; + add.setLeft((geax::frontend::Expr*)&ref1); + add.setRight((geax::frontend::Expr*)&ref2); + CRecord record; + record.values.push_back(CEntry(a)); + record.values.push_back(CEntry(b)); + + 
ExprEvaluator evaluator(&add, &sym_tab); + cypher::RTContext ctx; + return evaluator.Evaluate(&ctx, &record).constant_.scalar.Int64(); +} + +int main() { + builder::builder_context context; + block::c_code_generator::generate_code(context.extract_function_ast(foo, "foo"), std::cout, 0); + return 0; +} \ No newline at end of file From 29b579bf3cbebf4b6d6822f71d314d53d04584dc Mon Sep 17 00:00:00 2001 From: RTEnzyme Date: Wed, 25 Sep 2024 01:22:40 +0000 Subject: [PATCH 02/12] execution framework. --- src/cypher/experimental/expressions/kernal/binary.cpp | 11 +++++------ toolkits/lgraph_compilation.cpp | 2 ++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/cypher/experimental/expressions/kernal/binary.cpp b/src/cypher/experimental/expressions/kernal/binary.cpp index 187559715d..0f8504cd40 100644 --- a/src/cypher/experimental/expressions/kernal/binary.cpp +++ b/src/cypher/experimental/expressions/kernal/binary.cpp @@ -21,7 +21,6 @@ CFieldData CFieldData::operator+(const CFieldData &other) const { } else if ((is_integer() || is_real()) && (other.is_integer() || other.is_real())) { if (is_integer() && other.is_integer()) { ret.scalar = CScalarData(scalar.Int64() + other.scalar.Int64()); - std::cout<<"ret"< x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); dyn_var y_n = is_integer()? 
(dyn_var)other.scalar.integer() : other.scalar.real(); @@ -59,18 +58,18 @@ static CFieldData sub(const CFieldData& x, const CFieldData& y) { } static CFieldData div(const CFieldData& x, const CFieldData y) { - if (is_null() || other.is_null()) return CFieldData(); + if (x.is_null() || y.is_null()) return CFieldData(); if (!(x.is_integer() || x.is_real()) || !(y.is_integer() || y.is_real())) throw lgraph::CypherException("Type mismatch: expect Integer or Float in div expr"); CFieldData ret; - if (is_integer() && other.is_integer()) { + if (x.is_integer() && y.is_integer()) { dyn_var x_n = x.scalar.integer(); dyn_var y_n = y.scalar.integer(); if (y_n == 0) throw lgraph::CypherException("divide by zero"); ret.scalar = std::move(CScalarData(x_n / y_n)); } else { - dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); - dyn_var y_n = is_integer()? (dyn_var)other.scalar.integer() : other.scalar.real(); + dyn_var x_n = x.is_integer() ? (dyn_var) x.scalar.integer() : x.scalar.real(); + dyn_var y_n = y.is_integer()? 
(dyn_var) y.scalar.integer() : y.scalar.real(); if (y_n == 0) CYPHER_TODO(); ret.scalar = std::move(CScalarData(x_n - y_n)); } @@ -129,7 +128,7 @@ std::any ExprEvaluator::visit(geax::frontend::BSmallerThan* node) { CYPHER_TODO( std::any ExprEvaluator::visit(geax::frontend::BNotGreaterThan* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::BSafeEqual* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::BSub* node) { DO_BINARY_EXPR(sub); } -std::any ExprEvaluator::visit(geax::frontend::BDiv* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BDiv* node) { DO_BINARY_EXPR(div); } std::any ExprEvaluator::visit(geax::frontend::BMul* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::BMod* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::BSquare* node) { CYPHER_TODO(); } diff --git a/toolkits/lgraph_compilation.cpp b/toolkits/lgraph_compilation.cpp index 6c9d644423..268f8e4864 100644 --- a/toolkits/lgraph_compilation.cpp +++ b/toolkits/lgraph_compilation.cpp @@ -49,6 +49,8 @@ dyn_var foo(void) { int main() { builder::builder_context context; + std::cout<<"#include "< Date: Wed, 25 Sep 2024 06:58:52 +0000 Subject: [PATCH 03/12] remove static_var usage in data structures. 
--- .gitmodules | 3 + .../experimental/data_type/field_data.h | 100 ++++++------------ toolkits/lgraph_compilation.cpp | 2 +- 3 files changed, 35 insertions(+), 70 deletions(-) diff --git a/.gitmodules b/.gitmodules index 91a0a8c3ea..85d6bd30f9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "deps/tugraph-db-browser"] path = deps/tugraph-db-browser url = https://github.com/TuGraph-family/tugraph-db-browser.git +[submodule "deps/buildit"] + path = deps/buildit + url = "https://github.com/RTEnzyme/buildit.git" \ No newline at end of file diff --git a/src/cypher/experimental/data_type/field_data.h b/src/cypher/experimental/data_type/field_data.h index ddf4c19589..8ebb22cf0a 100644 --- a/src/cypher/experimental/data_type/field_data.h +++ b/src/cypher/experimental/data_type/field_data.h @@ -14,24 +14,9 @@ using lgraph::FieldType; namespace cypher { namespace compilation { -#define GET_CONSTANT(type) \ - if (is_constant_) { \ - return std::get>(constant_); \ - } else { \ - return std::get>(dyn_); \ - } struct CScalarData { static constexpr const char* type_name = "CScalarData"; - std::variant< - std::monostate, // Represent the null state - static_var, - static_var, - static_var, - static_var, - static_var - > constant_; - std::variant< std::monostate, // Represent the null state dyn_var, @@ -39,82 +24,63 @@ struct CScalarData { dyn_var, dyn_var, dyn_var - > dyn_; + > constant_; - bool is_constant_; - // for string type lgraph::FieldType type_; CScalarData() { type_ = lgraph_api::FieldType::NUL; - is_constant_ = false; } CScalarData(CScalarData &&data) - : constant_(std::move(data.constant_)), dyn_(std::move(data.dyn_)), - is_constant_(data.is_constant_), type_(data.type_) {} + : constant_(std::move(data.constant_)), type_(data.type_) {} CScalarData(const CScalarData& other) - : constant_(other.constant_), dyn_(other.dyn_), - is_constant_(other.is_constant_), type_(other.type_) {} + : constant_(other.constant_), type_(other.type_) {} - 
CScalarData(const lgraph::FieldData& other, bool is_constant = false) { + CScalarData(const lgraph::FieldData& other) { type_ = other.type; - is_constant_ = is_constant; - if (is_constant_) { - switch (other.type) { - case lgraph::FieldType::NUL: - constant_.emplace(); - break; - case lgraph::FieldType::INT64: - constant_.emplace>((long)other.integer()); - break; - default: - CYPHER_TODO(); - } - } else { - switch (other.type) { - case lgraph::FieldType::NUL: - dyn_.emplace(); - break; - case lgraph::FieldType::INT64: - dyn_.emplace>((long)other.integer()); - break; - default: - CYPHER_TODO(); - } - + switch (other.type) { + case lgraph::FieldType::NUL: + constant_.emplace(); + break; + case lgraph::FieldType::INT64: + constant_.emplace>((long)other.integer()); + break; + default: + CYPHER_TODO(); } } - explicit CScalarData(long integer, bool is_constant = true) { - is_constant_ = is_constant; - if (is_constant) { - constant_.emplace>(integer); - } else { - dyn_.emplace>(integer); - } + explicit CScalarData(long integer) { + constant_.emplace>(integer); type_ = lgraph::FieldType::INT64; } explicit CScalarData(const static_var &integer) - : constant_(integer), is_constant_(true), type_(FieldType::INT64) {} + : type_(FieldType::INT64) { + constant_ = (dyn_var) integer; + } explicit CScalarData(const dyn_var &integer) - : dyn_(integer), is_constant_(false), type_(FieldType::INT64) {} + : constant_(integer), type_(FieldType::INT64) {} explicit CScalarData(dyn_var&& integer) - : dyn_(std::move(integer)), is_constant_(false), type_(FieldType::INT64) {} + : constant_(std::move(integer)), type_(FieldType::INT64) {} inline dyn_var integer() const { switch (type_) { case FieldType::NUL: case FieldType::BOOL: throw std::bad_cast(); - case FieldType::INT8: GET_CONSTANT(short) - case FieldType::INT16: GET_CONSTANT(short) - case FieldType::INT32: GET_CONSTANT(int) - case FieldType::INT64: GET_CONSTANT(long) + case FieldType::INT8: + return std::get>(constant_); + case 
FieldType::INT16: + return std::get>(constant_); + case FieldType::INT32: + return std::get>(constant_); + case FieldType::INT64: + return std::get>(constant_); case FieldType::FLOAT: case FieldType::DOUBLE: case FieldType::DATE: @@ -141,9 +107,9 @@ struct CScalarData { case FieldType::INT64: throw std::bad_cast(); case FieldType::FLOAT: - GET_CONSTANT(float); + std::get>(constant_); case FieldType::DOUBLE: - GET_CONSTANT(double); + std::get>(constant_); case FieldType::DATE: case FieldType::DATETIME: case FieldType::STRING: @@ -159,7 +125,7 @@ struct CScalarData { } dyn_var Int64() const { - GET_CONSTANT(long) + return std::get>(constant_); } inline bool is_integer() const { @@ -177,9 +143,7 @@ struct CScalarData { CScalarData& operator=(CScalarData&& other) noexcept { if (this != &other) { constant_ = std::move(other.constant_); - dyn_ = std::move(other.dyn_); type_ = std::move(other.type_); - is_constant_ = std::move(other.is_constant_); } return *this; } @@ -187,9 +151,7 @@ struct CScalarData { CScalarData& operator=(const CScalarData& other) { if (this != &other) { constant_ = other.constant_; - dyn_ = other.dyn_; type_ = other.type_; - is_constant_ = other.is_constant_; } return *this; } diff --git a/toolkits/lgraph_compilation.cpp b/toolkits/lgraph_compilation.cpp index 268f8e4864..9a33388150 100644 --- a/toolkits/lgraph_compilation.cpp +++ b/toolkits/lgraph_compilation.cpp @@ -25,7 +25,7 @@ dyn_var bar(void) { dyn_var foo(void) { cypher::SymbolTable sym_tab; - CFieldData a(std::move(CScalarData(10, false))); + CFieldData a(std::move(CScalarData(10))); geax::frontend::Ref ref1; ref1.setName(std::string("a")); sym_tab.symbols.emplace("a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); From 43b05d232b4b35ad21c5b2599dcbb3dafb9fcafe Mon Sep 17 00:00:00 2001 From: RTEnzyme Date: Wed, 25 Sep 2024 09:00:11 +0000 Subject: [PATCH 04/12] test framework. 
--- test/CMakeLists.txt | 4 ++ test/test_query_compilation.cpp | 109 ++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 test/test_query_compilation.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b723634227..04c9a9317d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -91,6 +91,7 @@ add_executable(unit_test test_proto_convert.cpp test_python_plugin_manager.cpp test_python_plugin_manager_impl.cpp + test_query_compilation.cpp test_query.cpp test_random_delete.cpp test_rest_client.cpp @@ -138,9 +139,12 @@ add_executable(unit_test ${LGRAPH_ROOT_DIR}/src/import/import_client.cpp ) +target_include_directories(unit_test PRIVATE ${CMAKE_SOURCE_DIR}/deps/buildit/include) + target_link_libraries(unit_test lgraph_server_lib geax_isogql + ${CMAKE_SOURCE_DIR}/deps/buildit/build/libbuildit.a bolt librocksdb.a) diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp new file mode 100644 index 0000000000..753b87ce0d --- /dev/null +++ b/test/test_query_compilation.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "cypher/experimental/data_type/field_data.h" +#include "cypher/experimental/data_type/record.h" +#include "cypher/experimental/expressions/cexpr.h" +#include "cypher/parser/symbol_table.h" +#include "cypher/execution_plan/runtime_context.h" +#include "geax-front-end/ast/Ast.h" + +#include "blocks/c_code_generator.h" +#include "builder/builder.h" +#include "builder/builder_context.h" +#include "builder/dyn_var.h" +#include "builder/static_var.h" +using builder::dyn_var; +using builder::static_var; +using cypher::compilation::CFieldData; +using cypher::compilation::CScalarData; +using cypher::compilation::CRecord; +using cypher::compilation::CEntry; + +#include "gtest/gtest.h" + +#include "core/value.h" +#include "./ut_utils.h" + +std::string execute(const std::string& command) { + std::string result; + FILE* pipe = 
popen(command.c_str(), "r"); + if (!pipe) { + std::cerr << "popen() failed!" << std::endl; + return ""; + } + char buf[128]; + while (fgets(buf, sizeof(buf), pipe) != nullptr) { + result += buf; + } + pclose(pipe); + return result; +} + +std::string execute_func(std::string &func_body) { + const std::string file_name = "/tmp/a.cpp"; + std::ofstream out_file(file_name); + if (!out_file) { + std::cerr << "Failed to open file for writing!" << std::endl; + return ""; + } + out_file << func_body; + out_file.close(); + + // define and execute compiler commands + std::string compile_cmd = "g++ " + file_name + " -o /tmp/a"; + int compile_res = system(compile_cmd.c_str()); + if (compile_res != 0) { + std::cerr << "Compilation failed!" << std::endl; + return ""; + } + + // define and execute command + std::string output = execute("/tmp/a"); + + return output; +} +class TestQueryCompilation : public TuGraphTest {}; + +dyn_var add(void) { + cypher::SymbolTable sym_tab; + + CFieldData a(std::move(CScalarData(10))); + geax::frontend::Ref ref1; + ref1.setName(std::string("a")); + sym_tab.symbols.emplace("a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + + CFieldData b(static_var(10)); + geax::frontend::Ref ref2; + ref2.setName(std::string("b")); + sym_tab.symbols.emplace("b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + + geax::frontend::BAdd add; + add.setLeft((geax::frontend::Expr*)&ref1); + add.setRight((geax::frontend::Expr*)&ref2); + CRecord record; + record.values.push_back(CEntry(a)); + record.values.push_back(CEntry(b)); + + cypher::compilation::ExprEvaluator evaluator(&add, &sym_tab); + cypher::RTContext ctx; + return evaluator.Evaluate(&ctx, &record).constant_.scalar.Int64(); +}; + +TEST_F(TestQueryCompilation, Add) { + builder::builder_context context; + auto ast = context.extract_function_ast(add, "add"); + std::ostringstream oss; + oss << "#include \n"; + 
block::c_code_generator::generate_code(ast, oss, 0); + oss << "int main() {\n std::cout << add();\n return 0;\n}"; + std::string body = oss.str(); + std::cout <<"Generated code: \n" << body << std::endl; + std::string res = execute_func(body); + ASSERT_EQ(res, "20"); +} \ No newline at end of file From e6702f2cf3a0ab98524fad88d9d79d97e0a6ac3e Mon Sep 17 00:00:00 2001 From: RTEnzyme Date: Wed, 25 Sep 2024 10:52:26 +0000 Subject: [PATCH 05/12] generate code on current directory --- test/test_query_compilation.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp index 753b87ce0d..f3697a7915 100644 --- a/test/test_query_compilation.cpp +++ b/test/test_query_compilation.cpp @@ -46,7 +46,7 @@ std::string execute(const std::string& command) { } std::string execute_func(std::string &func_body) { - const std::string file_name = "/tmp/a.cpp"; + const std::string file_name = "a.cpp"; std::ofstream out_file(file_name); if (!out_file) { std::cerr << "Failed to open file for writing!" << std::endl; @@ -56,7 +56,7 @@ std::string execute_func(std::string &func_body) { out_file.close(); // define and execute compiler commands - std::string compile_cmd = "g++ " + file_name + " -o /tmp/a"; + std::string compile_cmd = "g++ " + file_name + " -o ./a"; int compile_res = system(compile_cmd.c_str()); if (compile_res != 0) { std::cerr << "Compilation failed!" << std::endl; @@ -64,7 +64,7 @@ std::string execute_func(std::string &func_body) { } // define and execute command - std::string output = execute("/tmp/a"); + std::string output = execute("./a"); return output; } From dca7e274b409258598f7a84f2198c669c702b548 Mon Sep 17 00:00:00 2001 From: RTEnzyme Date: Wed, 25 Sep 2024 10:57:53 +0000 Subject: [PATCH 06/12] delete files after executions. 
--- test/test_query_compilation.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp index f3697a7915..c1c4575e26 100644 --- a/test/test_query_compilation.cpp +++ b/test/test_query_compilation.cpp @@ -46,7 +46,8 @@ std::string execute(const std::string& command) { } std::string execute_func(std::string &func_body) { - const std::string file_name = "a.cpp"; + const std::string file_name = "test_add.cpp"; + const std::string output_name = "test_add" std::ofstream out_file(file_name); if (!out_file) { std::cerr << "Failed to open file for writing!" << std::endl; @@ -54,18 +55,20 @@ std::string execute_func(std::string &func_body) { } out_file << func_body; out_file.close(); - // define and execute compiler commands - std::string compile_cmd = "g++ " + file_name + " -o ./a"; + std::string compile_cmd = "g++ " + file_name + " -o " + output_name; int compile_res = system(compile_cmd.c_str()); if (compile_res != 0) { std::cerr << "Compilation failed!" << std::endl; return ""; } - // define and execute command std::string output = execute("./a"); - + // delete files + if (std::remove(file_name) && std::remove(output_name)) { + std::cerr << "Failed to delete files: " << file_name + << ", " << output_name << std::endl; + } return output; } class TestQueryCompilation : public TuGraphTest {}; From a06bf22b0b7dee60e120dded4b14ff33383d976e Mon Sep 17 00:00:00 2001 From: RTEnzyme Date: Wed, 25 Sep 2024 12:31:46 +0000 Subject: [PATCH 07/12] fix bugs in test framework. 
--- test/test_query_compilation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp index c1c4575e26..f1cb88a265 100644 --- a/test/test_query_compilation.cpp +++ b/test/test_query_compilation.cpp @@ -47,7 +47,7 @@ std::string execute(const std::string& command) { std::string execute_func(std::string &func_body) { const std::string file_name = "test_add.cpp"; - const std::string output_name = "test_add" + const std::string output_name = "test_add"; std::ofstream out_file(file_name); if (!out_file) { std::cerr << "Failed to open file for writing!" << std::endl; @@ -65,7 +65,7 @@ std::string execute_func(std::string &func_body) { // define and execute command std::string output = execute("./a"); // delete files - if (std::remove(file_name) && std::remove(output_name)) { + if (std::remove(file_name.c_str()) && std::remove(output_name.c_str())) { std::cerr << "Failed to delete files: " << file_name << ", " << output_name << std::endl; } From 9747be9f8c6bae5bc2647cdfda049ff9835b8076 Mon Sep 17 00:00:00 2001 From: RTEnzyme <52275903001@stu.ecnu.edu.cn> Date: Sat, 28 Sep 2024 07:04:39 +0000 Subject: [PATCH 08/12] LLVM framework --- src/BuildCypherLib.cmake | 10 +++- src/cypher/experimental/jit/TuJIT.cpp | 2 + src/cypher/experimental/jit/TuJIT.h | 79 +++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 src/cypher/experimental/jit/TuJIT.cpp create mode 100644 src/cypher/experimental/jit/TuJIT.h diff --git a/src/BuildCypherLib.cmake b/src/BuildCypherLib.cmake index 90e47c1178..e8e155163e 100644 --- a/src/BuildCypherLib.cmake +++ b/src/BuildCypherLib.cmake @@ -4,6 +4,7 @@ find_package(PythonInterp 3) find_package(PythonLibs ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR} EXACT REQUIRED) #antlr4-runtime find_package(antlr4-runtime REQUIRED) +find_package(LLVM REQUIRED CONFIG) set(ANTRL4_LIBRARY antlr4-runtime.a) set(TARGET_LGRAPH_CYPHER_LIB 
lgraph_cypher_lib) @@ -59,6 +60,7 @@ set(LGRAPH_CYPHER_SRC # find cypher/ -name "*.cpp" | sort cypher/experimental/data_type/field_data.h cypher/experimental/expressions/cexpr.cpp cypher/experimental/expressions/kernal/binary.cpp + cypher/experimental/jit/TuJIT.cpp cypher/filter/filter.cpp cypher/filter/iterator.cpp cypher/graph/graph.cpp @@ -91,7 +93,10 @@ target_include_directories(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC ${ANTLR4_INCLUDE_DIR} ${CMAKE_CURRENT_LIST_DIR}/cypher) -include_directories(${CMAKE_SOURCE_DIR}/deps/buildit/include) +include_directories( + ${CMAKE_SOURCE_DIR}/deps/buildit/include + ${LLVM_INCLUDE_DIRS}) +add_definitions(${LLVM_DEFINITIONS}) target_link_directories(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC ${CMAKE_SOURCE_DIR}/deps/buildit/lib) @@ -104,3 +109,6 @@ target_link_libraries(${TARGET_LGRAPH_CYPHER_LIB} PUBLIC target_link_libraries(${TARGET_LGRAPH_CYPHER_LIB} PRIVATE lgraph_server_lib) + +llvm_map_components_to_libnames(llvm_libs Core Support) +target_link_libraries(${TARGET_LGRAPH_CYPHER_LIB} PRIVATE ${llvm_libs}) \ No newline at end of file diff --git a/src/cypher/experimental/jit/TuJIT.cpp b/src/cypher/experimental/jit/TuJIT.cpp new file mode 100644 index 0000000000..78ad5f1237 --- /dev/null +++ b/src/cypher/experimental/jit/TuJIT.cpp @@ -0,0 +1,2 @@ +#include "cypher/experimental/jit/TuJIT.h" + diff --git a/src/cypher/experimental/jit/TuJIT.h b/src/cypher/experimental/jit/TuJIT.h new file mode 100644 index 0000000000..b90a4620e3 --- /dev/null +++ b/src/cypher/experimental/jit/TuJIT.h @@ -0,0 +1,79 @@ +#include +#include +#include + +#include +#include +#include + +namespace cypher { +namespace compilation { +class JITCompiler; +class JITSymbolResolver; +class JITModuleMemoryManager; + +/** Custom JIT implementation inspired by CHJIT in clickhouse + * Main use cases: + * 1. Compiled functions in module. + * 2. Release memory for compiled function. 
+ */ +class TuJIT { + public: + TuJIT(); + + ~TuJIT(); + + struct CompileModule { + // Size of compiled module code in bytes + size_t size_; + + // Module identifier. Should not be changed by client + uint64_t identifier_; + + // Vector of compiled functions. Should not be changed by client. + // It is client responsibility to cast result function to right signature. + // After call to deleteCompiledModule compiled functions from module become invalid. + std::unordered_map function_name_to_symbol_; + }; + + // Compile module. In compile function client responsiblity is to fill module with necessary + // IR code, then it will be compiled by TuJIT instance. + // Return compiled module. + CompileModule compileModule(std::function compile_funciton); + + // Delete compiled module. Pointers to functions from module become invalid after this call. + // It is client reponsibility to be sure that there are no pointers to compiled module code. + void deleteCompiledModule(const CompileModule& module_info); + + // Register external symbol for TuJIT instance to use, during linking. + // It can be function, or global constant. + // It is client responsibility to be sure that address of symbol is valid during TuJIT instance lifetime. + void registerExternalSymbol(const std::string& symbol_name, void* address); + + // Total compiled code size for module that are current valid. 
+ size_t getCompiledCodeSize() const { return compiled_code_size_.load(std::memory_order_relaxed); } + + private: + + std::unique_ptr createModulerForCompilation(); + + CompileModule compileModule(std::unique_ptr module); + + std::string getMangleName(const std::string& name_to_mangle) const; + + void runOptimizationPassesOnModule(llvm::Module& module) const; + + static std::unique_ptr getTargetMachine_; + + llvm::LLVMContext context_; + std::unique_ptr machine_; + llvm::DataLayout layout_; + std::unique_ptr compiler_; + + std::unordered_map> module_identifier_to_memory_manager_; + uint64_t current_module_key_ = 0; + std::atomic compiled_code_size_ = 0; + mutable std::mutex jit_lock_; +}; +} // namespace compilation +} // namespace cypher \ No newline at end of file From 162f4fad038c8836016566cb49b3fcfba8dfc07c Mon Sep 17 00:00:00 2001 From: RT_Enzyme <52275903001@stu.ecnu.edu.cn> Date: Mon, 28 Oct 2024 20:11:30 +0800 Subject: [PATCH 09/12] Query plan cache for cypher queries. (#676) * plan cache codebase. * plan cache codebase. * split plan_cache.h and plan_cache_param. * integrate plan_cache into execution process. * add test cases for parameterized query execution. * fail direction * plan cache codebase * add more pattern for fastQueryParam * add more pattern for fastQueryParam * fix lint error * remove buildit deps. 
* fix bug in cypher visitor * remove unused dir --------- Co-authored-by: Ke Huang <569078986@qq.com> --- src/BuildCypherLib.cmake | 2 + .../arithmetic/arithmetic_expression.cpp | 40 +++-- src/cypher/arithmetic/arithmetic_expression.h | 8 + src/cypher/execution_plan/execution_plan.cpp | 7 +- src/cypher/execution_plan/ops/op_argument.cpp | 4 +- .../locate_node_by_indexed_prop.h | 3 +- .../execution_plan/plan_cache/plan_cache.cpp | 14 ++ .../execution_plan/plan_cache/plan_cache.h | 108 ++++++++++++ .../plan_cache/plan_cache_param.cpp | 165 ++++++++++++++++++ .../plan_cache/plan_cache_param.h | 28 +++ src/cypher/execution_plan/runtime_context.h | 3 + src/cypher/execution_plan/scheduler.cpp | 100 ++++++----- src/cypher/execution_plan/scheduler.h | 5 +- src/cypher/graph/common.h | 2 +- src/cypher/parser/clause.h | 2 + src/cypher/parser/cypher_base_visitor.h | 14 ++ src/cypher/parser/symbol_table.h | 1 + test/CMakeLists.txt | 1 + test/test_plan_cache.cpp | 59 +++++++ 19 files changed, 502 insertions(+), 64 deletions(-) create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache.cpp create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache.h create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache_param.cpp create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache_param.h create mode 100644 test/test_plan_cache.cpp diff --git a/src/BuildCypherLib.cmake b/src/BuildCypherLib.cmake index 9928aed2a1..87568c2890 100644 --- a/src/BuildCypherLib.cmake +++ b/src/BuildCypherLib.cmake @@ -55,6 +55,8 @@ set(LGRAPH_CYPHER_SRC # find cypher/ -name "*.cpp" | sort cypher/execution_plan/ops/op_node_by_id_seek.cpp cypher/execution_plan/ops/op_traversal.cpp cypher/execution_plan/ops/op_gql_remove.cpp + cypher/execution_plan/plan_cache/plan_cache_param.cpp + cypher/execution_plan/plan_cache/plan_cache.cpp cypher/execution_plan/scheduler.cpp cypher/filter/filter.cpp cypher/filter/iterator.cpp diff --git a/src/cypher/arithmetic/arithmetic_expression.cpp 
b/src/cypher/arithmetic/arithmetic_expression.cpp index 93e87ea553..f6828a785d 100644 --- a/src/cypher/arithmetic/arithmetic_expression.cpp +++ b/src/cypher/arithmetic/arithmetic_expression.cpp @@ -1425,13 +1425,20 @@ void ArithOperandNode::SetEntity(const std::string &alias, const std::string &pr } void ArithOperandNode::SetParameter(const std::string &param, const SymbolTable &sym_tab) { - type = AR_OPERAND_PARAMETER; - variadic.alias = param; - auto it = sym_tab.symbols.find(param); - if (it == sym_tab.symbols.end()) { - throw lgraph::CypherException("Parameter not defined: " + param); + if (std::isdigit(param[1])) { + // query plan parameter + type = AR_OPERAND_CONSTANT; + constant = sym_tab.param_tab_->at(param); + } else { + // named parameter + type = AR_OPERAND_PARAMETER; + variadic.alias = param; + auto it = sym_tab.symbols.find(param); + if (it == sym_tab.symbols.end()) { + throw lgraph::CypherException("Parameter not defined: " + param); + } + variadic.alias_idx = it->second.id; } - variadic.alias_idx = it->second.id; } void ArithOperandNode::RealignAliasId(const SymbolTable &sym_tab) { @@ -1441,17 +1448,20 @@ void ArithOperandNode::RealignAliasId(const SymbolTable &sym_tab) { variadic.alias_idx = it->second.id; } -cypher::FieldData GenerateCypherFieldData(const parser::Expression &expr) { - if (expr.type != parser::Expression::LIST && expr.type != parser::Expression::MAP) { +cypher::FieldData GenerateCypherFieldData(const parser::Expression &expr, + const SymbolTable &sym_tab) { + if (expr.type == parser::Expression::PARAMETER) { + return cypher::FieldData(sym_tab.param_tab_->at(expr.String())); + } else if (expr.type != parser::Expression::LIST + && expr.type != parser::Expression::MAP) { return cypher::FieldData(parser::MakeFieldData(expr)); - } - if (expr.type == parser::Expression::LIST) { + } else if (expr.type == parser::Expression::LIST) { std::vector list; - for (auto &e : 
expr.List()) list.emplace_back(GenerateCypherFieldData(e, sym_tab)); return cypher::FieldData(std::move(list)); } else { std::unordered_map map; - for (auto &e : expr.Map()) map.emplace(e.first, GenerateCypherFieldData(e.second)); + for (auto &e : expr.Map()) map.emplace(e.first, GenerateCypherFieldData(e.second, sym_tab)); return cypher::FieldData(std::move(map)); } } @@ -1494,13 +1504,13 @@ void ArithOperandNode::Set(const parser::Expression &expr, const SymbolTable &sy { /* e.g. [1,3,5,7], [1,3,5.55,'seven'] */ type = ArithOperandNode::AR_OPERAND_CONSTANT; - constant = GenerateCypherFieldData(expr); + constant = GenerateCypherFieldData(expr, sym_tab); break; } case parser::Expression::MAP: { type = ArithOperandNode::AR_OPERAND_CONSTANT; - constant = GenerateCypherFieldData(expr); + constant = GenerateCypherFieldData(expr, sym_tab); break; } default: @@ -1650,7 +1660,7 @@ void ArithExprNode::Set(const parser::Expression &expr, const SymbolTable &sym_t * [n.name, n.age] as op */ bool is_operand = true; for (auto &e : expr.List()) { - if (!e.IsLiteral()) is_operand = false; + if (!e.IsLiteral() && e.type != parser::Expression::PARAMETER) is_operand = false; if (e.type == parser::Expression::MAP) is_operand = true; } if (!is_operand) { diff --git a/src/cypher/arithmetic/arithmetic_expression.h b/src/cypher/arithmetic/arithmetic_expression.h index ef996ae4ec..e81e291350 100644 --- a/src/cypher/arithmetic/arithmetic_expression.h +++ b/src/cypher/arithmetic/arithmetic_expression.h @@ -718,11 +718,19 @@ struct ArithExprNode { void SetOperand(ArithOperandNode::ArithOperandType operand_type, const cypher::FieldData &data) { + // @todo(anyone) below assertion throws excpetion when set parameter operand. 
CYPHER_THROW_ASSERT(operand_type == ArithOperandNode::AR_OPERAND_CONSTANT); type = AR_EXP_OPERAND; operand.SetConstant(data); } + void SetOperandParameter(ArithOperandNode::ArithOperandType operand_type, + const std::string &param, const SymbolTable &sym_tab) { + CYPHER_THROW_ASSERT(operand_type == ArithOperandNode::AR_OPERAND_PARAMETER); + type = AR_EXP_OPERAND; + operand.SetParameter(param, sym_tab); + } + void SetOperandVariable(ArithOperandNode::ArithOperandType operand_type, const bool &hasMapFieldName = false, const std::string &value_alias = "", const std::string &map_field_name = "") { diff --git a/src/cypher/execution_plan/execution_plan.cpp b/src/cypher/execution_plan/execution_plan.cpp index b58dbfc0b4..53cf9e76b5 100644 --- a/src/cypher/execution_plan/execution_plan.cpp +++ b/src/cypher/execution_plan/execution_plan.cpp @@ -603,8 +603,11 @@ void ExecutionPlan::_BuildExpandOps(const parser::QueryPart &part, PatternGraph pattern_graph.symbol_table); if (pf.type == Property::PARAMETER) { // TODO(anyone) use record - ae2.SetOperand(ArithOperandNode::AR_OPERAND_PARAMETER, - cypher::FieldData(lgraph::FieldData(pf.value_alias))); + // Fix bugs of parameterized execution: + // MATCH (rachel:Person {name: $name1})-[]->(family:Person)-[:ACTED_IN]->(film) + // <-[:ACTED_IN]-(richard:Person {name: $name2}) RETURN family.name; + ae2.SetOperandParameter(ArithOperandNode::AR_OPERAND_PARAMETER, + pf.value_alias, pattern_graph.symbol_table); } else if (pf.type == Property::VARIABLE) { ae2.SetOperandVariable(ArithOperandNode::AR_OPERAND_VARIABLE, pf.hasMapFieldName, pf.value_alias, pf.map_field_name); diff --git a/src/cypher/execution_plan/ops/op_argument.cpp b/src/cypher/execution_plan/ops/op_argument.cpp index 5e622a9224..577cdc859f 100644 --- a/src/cypher/execution_plan/ops/op_argument.cpp +++ b/src/cypher/execution_plan/ops/op_argument.cpp @@ -25,7 +25,9 @@ Argument::Argument(const SymbolTable *sym_tab) : OpBase(OpType::ARGUMENT, "Argument"), sym_tab_(sym_tab) { 
std::map> ordered_alias; for (auto &a : sym_tab->symbols) { - if (a.second.scope == SymbolNode::ARGUMENT) { + // WITH [$1, $2, $3] AS coll RETURN size(coll) + // Should ignore the query params in symbole table + if (a.second.scope == SymbolNode::ARGUMENT && a.second.type != SymbolNode::PARAMETER) { ordered_alias.emplace(a.second.id, std::make_pair(a.first, a.second.type)); } } diff --git a/src/cypher/execution_plan/optimization/locate_node_by_indexed_prop.h b/src/cypher/execution_plan/optimization/locate_node_by_indexed_prop.h index 8f82f10779..329d860e28 100644 --- a/src/cypher/execution_plan/optimization/locate_node_by_indexed_prop.h +++ b/src/cypher/execution_plan/optimization/locate_node_by_indexed_prop.h @@ -158,7 +158,8 @@ class LocateNodeByIndexedProp : public OptPass { if (in_filter->ae_left_.type == ArithExprNode::AR_EXP_OPERAND && in_filter->ae_left_.operand.type == ArithOperandNode::AR_OPERAND_VARIADIC && - !in_filter->ae_left_.operand.variadic.entity_prop.empty()) { + !in_filter->ae_left_.operand.variadic.entity_prop.empty() && + in_filter->ae_right_.type == ArithExprNode::AR_EXP_OPERAND) { auto right_data = Entry(in_filter->ae_right_.operand.constant); if (!right_data.IsArray()) CYPHER_TODO(); field = in_filter->ae_left_.operand.variadic.entity_prop; diff --git a/src/cypher/execution_plan/plan_cache/plan_cache.cpp b/src/cypher/execution_plan/plan_cache/plan_cache.cpp new file mode 100644 index 0000000000..96bcb31e18 --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache.cpp @@ -0,0 +1,14 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "cypher/execution_plan/plan_cache/plan_cache.h" diff --git a/src/cypher/execution_plan/plan_cache/plan_cache.h b/src/cypher/execution_plan/plan_cache/plan_cache.h new file mode 100644 index 0000000000..4e50a99bec --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache.h @@ -0,0 +1,108 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include +#include + +#include "parser/clause.h" +#include "parser/data_typedef.h" + +namespace cypher { +class ASTCacheObj { + public: + std::vector stmts; + parser::CmdType cmd; + + ASTCacheObj() {} + + ASTCacheObj(const std::vector &stmts, parser::CmdType cmd) + : stmts(stmts), cmd(cmd) { + } + + std::vector Stmt() { + return stmts; + } + + parser::CmdType CmdType() { + return cmd; + } +}; + +template +class PlanCacheEntry { + public: + std::string key; + T value; + + PlanCacheEntry(const std::string &key, const T &value) : key(key), value(value) {} +}; + +template +class LRUPlanCache { + typedef PlanCacheEntry Entry; + std::list _item_list; + std::unordered_map _item_map; + size_t _max_size; + mutable std::shared_mutex _mutex; + inline void _KickOut() { + while (_item_map.size() > _max_size) { + auto last_it = _item_list.end(); + last_it--; + _item_map.erase(last_it->key); + _item_list.pop_back(); + } + } + + public: + explicit LRUPlanCache(size_t max_size) : _max_size(max_size) {} + + LRUPlanCache() : _max_size(512) {} + + void add_plan(std::string param_query, const Value &val) { + std::unique_lock lock(_mutex); + auto it = _item_map.find(param_query); + if (it == _item_map.end()) { + _item_list.emplace_front(std::move(param_query), val); + _item_map.emplace(_item_list.begin()->key, _item_list.begin()); + _KickOut(); + } else { + // Overwrite the cached value if the query is already present in the cache. + // And move the entry to the front of the list. + it->second->value = val; + _item_list.splice(_item_list.begin(), _item_list, it->second); + } + } + + // Get the cached value for the given parameterized query. Before calling this function, + // you MUST parameterize the query using the fastQueryParam(). 
+ bool get_plan(const std::string ¶m_query, Value &val) { + // parameterized raw query + std::shared_lock lock(_mutex); + auto it = _item_map.find(param_query); + if (it == _item_map.end()) { + return false; + } + _item_list.splice(_item_list.begin(), _item_list, it->second); + val = it->second->value; + return true; + } + + size_t max_size() const { return _max_size; } + + size_t current_size() const { return _item_map.size(); } +}; + +typedef LRUPlanCache ASTCache; +} // namespace cypher diff --git a/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp b/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp new file mode 100644 index 0000000000..0bd98220a6 --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp @@ -0,0 +1,165 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "cypher/execution_plan/plan_cache/plan_cache_param.h" + +namespace cypher { +std::string fastQueryParam(RTContext *ctx, const std::string query) { + /** + * We don't parameterize the queries or literals in query if: + * 1. The query is a CALL statement. + * 2. limit/skip `n`. + * 3. Range literals: ()->[e*..3]->(m) + * 4. the items in return body: return RETURN a,-2,9.78,'im a string' (@todo) + * 5. match ... 
create: MATCH (c {name:$0}) CREATE (p:Person {name:$1, birthyear:$2})-[r:BORN_IN]->(c) RETURN p,r,c + */ + antlr4::ANTLRInputStream input(query); + parser::LcypherLexer lexer(&input); + antlr4::CommonTokenStream token_stream(&lexer); + token_stream.fill(); + + std::vector tokens = token_stream.getTokens(); + size_t delete_size = 0; + std::string param_query = query; + + bool prev_limit_skip = false; + bool in_return_body = false; + bool prev_double_dots = false; // e*..3 + bool in_rel = false; // -[n]-> + int param_num = 0; + if (tokens[0]->getType() == parser::LcypherParser::CALL) { + // Don't parameterize plugin CALL statements + return query; + } + for (size_t i = 0; i < tokens.size(); i++) { + parser::Expression expr; + bool is_param; + switch (tokens[i]->getType()) { + case parser::LcypherParser::CREATE: { + // We don't parameterize the Create statements + // Remove the parsed parameters. + for (auto it = ctx->param_tab_.begin(); it!= ctx->param_tab_.end(); ) { + if (it->first[0] == '$' && std::isdigit(it->first[1])) { + it = ctx->param_tab_.erase(it); + } else { + ++it; + } + } + return query; + } + case parser::LcypherParser::T__13: { // '-' + size_t j = i; + while (++j < tokens.size() && tokens[j]->getType() == parser::LcypherParser::SP) { + } + if (j < tokens.size() && tokens[j]->getType() == parser::LcypherParser::T__7) { + in_rel = true; + } + i = j; + break; + } + case parser::LcypherParser::T__8: { // ']' + in_rel = false; + break; + } + case parser::LcypherParser::StringLiteral: { + // String literal + auto str = tokens[i]->getText(); + std::string res; + // remove escape character + for (size_t i = 1; i < str.length() - 1; i++) { + if (str[i] == '\\') { + i++; + } + res.push_back(str[i]); + } + expr.type = parser::Expression::STRING; + expr.data = std::make_shared(std::move(res)); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::HexInteger: + case 
parser::LcypherParser::DecimalInteger: + case parser::LcypherParser::OctalInteger: { + if (in_rel) { + // The integer literals in relationships are range literals. + // -[:HAS_CHILD*1..]-> + break; + } + if (prev_limit_skip || prev_double_dots) { + break; + } + // Integer literal + expr.type = parser::Expression::DataType::INT; + expr.data = std::stol(tokens[i]->getText()); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::ExponentDecimalReal: + case parser::LcypherParser::RegularDecimalReal: { + // Double literal + expr.type = parser::Expression::DataType::DOUBLE; + expr.data = std::stod(tokens[i]->getText()); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::TRUE_: { + expr.type = parser::Expression::DataType::BOOL; + expr.data = true; + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::FALSE_: { + expr.type = parser::Expression::DataType::BOOL; + expr.data = false; + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::RETURN: { + in_return_body = true; + break; + } + default: + break; + } + + // Replace the token with placeholder + if (is_param) { + if (!in_return_body) { + size_t start_index = tokens[i]->getStartIndex() - delete_size; + size_t end_index = tokens[i]->getStopIndex() - delete_size; + // Indicate the position in raw parameterized query + std::string count = "$" + std::to_string(param_num); + param_query.replace(start_index, end_index - start_index + 1, count); + delete_size += (end_index - start_index + 1) - count.size(); + param_num++; + } + is_param = false; + } + if (tokens[i]->getType() == parser::LcypherParser::LIMIT || + tokens[i]->getType() == parser::LcypherParser::L_SKIP) { + 
prev_limit_skip = true; + } else if (tokens[i]->getType() == parser::LcypherParser::T__11) { + prev_double_dots = true; + } else if (tokens[i]->getType() < parser::LcypherParser::SP) { + prev_limit_skip = false; + prev_double_dots = false; + } + } + return param_query; +} +} // namespace cypher diff --git a/src/cypher/execution_plan/plan_cache/plan_cache_param.h b/src/cypher/execution_plan/plan_cache/plan_cache_param.h new file mode 100644 index 0000000000..83410f7383 --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache_param.h @@ -0,0 +1,28 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "./antlr4-runtime.h" +#include "parser/generated/LcypherLexer.h" +#include "parser/generated/LcypherParser.h" + +#include "execution_plan/runtime_context.h" +#include "parser/clause.h" +#include "parser/expression.h" + +namespace cypher { + +// Leverage the lexer to parameterize queries +std::string fastQueryParam(RTContext *ctx, const std::string query); +} diff --git a/src/cypher/execution_plan/runtime_context.h b/src/cypher/execution_plan/runtime_context.h index 6beda0ef35..aadd1005f5 100644 --- a/src/cypher/execution_plan/runtime_context.h +++ b/src/cypher/execution_plan/runtime_context.h @@ -67,6 +67,9 @@ class RTContext : public SubmitQueryContext { std::unique_ptr result_info_; std::unique_ptr result_; bolt::BoltConnection* bolt_conn_ = nullptr; + // for plan cache + std::vector query_params_; + std::string param_query_; RTContext() = default; diff --git a/src/cypher/execution_plan/scheduler.cpp b/src/cypher/execution_plan/scheduler.cpp index 6b99007c00..0d959ac300 100644 --- a/src/cypher/execution_plan/scheduler.cpp +++ b/src/cypher/execution_plan/scheduler.cpp @@ -15,7 +15,11 @@ // // Created by wt on 18-8-14. 
// + + #include "./antlr4-runtime.h" +#include "execution_plan/plan_cache/plan_cache_param.h" + #include "geax-front-end/ast/AstNode.h" #include "geax-front-end/ast/AstDumper.h" #include "geax-front-end/isogql/GQLResolveCtx.h" @@ -34,6 +38,7 @@ #include "cypher/execution_plan/execution_plan.h" #include "cypher/execution_plan/scheduler.h" #include "cypher/execution_plan/execution_plan_v2.h" +#include "cypher/execution_plan/lru_cache.h" #include "cypher/rewriter/GenAnonymousAliasRewriter.h" #include "cypher/rewriter/MultiPathPatternRewriter.h" #include "cypher/rewriter/PushDownFilterAstRewriter.h" @@ -69,11 +74,15 @@ void Scheduler::EvalCypher(RTContext *ctx, const std::string &script, ElapsedTim using namespace parser; using namespace antlr4; auto t0 = fma_common::GetTime(); - // - thread_local LRUCacheThreadUnsafe> tls_plan_cache; std::shared_ptr plan; - if (!tls_plan_cache.Get(script, plan)) { - ANTLRInputStream input(script); + plan = std::make_shared(); + std::vector stmt; + parser::CmdType cmd; + ASTCacheObj cache_val; + // parameterize the query + std::string param_query = fastQueryParam(ctx, script); + if (!plan_cache_.get_plan(param_query, cache_val)) { + ANTLRInputStream input(param_query); LcypherLexer lexer(&input); CommonTokenStream tokens(&lexer); LcypherParser parser(&tokens); @@ -86,49 +95,54 @@ void Scheduler::EvalCypher(RTContext *ctx, const std::string &script, ElapsedTim for (const auto &sql_query : visitor.GetQuery()) { LOG_DEBUG() << sql_query.ToString(); } - plan = std::make_shared(); plan->PreValidate(ctx, visitor.GetNodeProperty(), visitor.GetRelProperty()); - plan->Build(visitor.GetQuery(), visitor.CommandType(), ctx); - plan->Validate(ctx); - if (plan->CommandType() != parser::CmdType::QUERY) { - ctx->result_info_ = std::make_unique(); - ctx->result_ = std::make_unique(); - std::string header, data; - if (plan->CommandType() == parser::CmdType::EXPLAIN) { - header = "@plan"; - data = plan->DumpPlan(0, false); - } else { - header = 
"@profile"; - data = plan->DumpGraph(); - } - ctx->result_->ResetHeader({{header, lgraph_api::LGraphType::STRING}}); - auto r = ctx->result_->MutableRecord(); - r->Insert(header, lgraph::FieldData(data)); - if (ctx->bolt_conn_) { - auto session = (bolt::BoltSession *)ctx->bolt_conn_->GetContext(); - ctx->result_->MarkPythonDriver(session->python_driver); - while (!session->streaming_msg) { - session->streaming_msg = session->msgs.Pop(std::chrono::milliseconds(100)); - if (ctx->bolt_conn_->has_closed()) { - LOG_INFO() << "The bolt connection is closed, cancel the op execution."; - return; - } - } - std::unordered_map meta; - meta["fields"] = ctx->result_->BoltHeader(); - bolt::PackStream ps; - ps.AppendSuccess(meta); - if (session->streaming_msg.value().type == bolt::BoltMsg::PullN) { - ps.AppendRecords(ctx->result_->BoltRecords()); - } else if (session->streaming_msg.value().type == bolt::BoltMsg::DiscardN) { - // ... + stmt = visitor.GetQuery(); + cmd = visitor.CommandType(); + plan_cache_.add_plan(param_query, ASTCacheObj(stmt, cmd)); + } else { + ASTCacheObj ast(cache_val); + stmt = ast.Stmt(); + cmd = ast.CmdType(); + } + plan->Build(stmt, cmd, ctx); + plan->Validate(ctx); + if (plan->CommandType() != parser::CmdType::QUERY) { + ctx->result_info_ = std::make_unique(); + ctx->result_ = std::make_unique(); + std::string header, data; + if (plan->CommandType() == parser::CmdType::EXPLAIN) { + header = "@plan"; + data = plan->DumpPlan(0, false); + } else { + header = "@profile"; + data = plan->DumpGraph(); + } + ctx->result_->ResetHeader({{header, lgraph_api::LGraphType::STRING}}); + auto r = ctx->result_->MutableRecord(); + r->Insert(header, lgraph::FieldData(data)); + if (ctx->bolt_conn_) { + auto session = (bolt::BoltSession *)ctx->bolt_conn_->GetContext(); + ctx->result_->MarkPythonDriver(session->python_driver); + while (!session->streaming_msg) { + session->streaming_msg = session->msgs.Pop(std::chrono::milliseconds(100)); + if (ctx->bolt_conn_->has_closed()) 
{ + LOG_INFO() << "The bolt connection is closed, cancel the op execution."; + return; } - ps.AppendSuccess(); - ctx->bolt_conn_->PostResponse(std::move(ps.MutableBuffer())); } - return; + std::unordered_map meta; + meta["fields"] = ctx->result_->BoltHeader(); + bolt::PackStream ps; + ps.AppendSuccess(meta); + if (session->streaming_msg.value().type == bolt::BoltMsg::PullN) { + ps.AppendRecords(ctx->result_->BoltRecords()); + } else if (session->streaming_msg.value().type == bolt::BoltMsg::DiscardN) { + // ... + } + ps.AppendSuccess(); + ctx->bolt_conn_->PostResponse(std::move(ps.MutableBuffer())); } - LOG_DEBUG() << "Plan cache disabled."; + return; } LOG_DEBUG() << plan->DumpPlan(0, false); LOG_DEBUG() << plan->DumpGraph(); diff --git a/src/cypher/execution_plan/scheduler.h b/src/cypher/execution_plan/scheduler.h index bd973f77d2..46ee4d6c1e 100644 --- a/src/cypher/execution_plan/scheduler.h +++ b/src/cypher/execution_plan/scheduler.h @@ -28,7 +28,8 @@ #include "execution_plan/execution_plan.h" #include "execution_plan/runtime_context.h" -#include "cypher/execution_plan/lru_cache.h" +#include "execution_plan/plan_cache/plan_cache.h" + namespace lgraph { class StateMachine; @@ -63,5 +64,7 @@ class Scheduler { static bool DetermineGqlReadOnly(cypher::RTContext *ctx, const std::string &script, std::string &name, std::string &type); + + ASTCache plan_cache_; }; } // namespace cypher diff --git a/src/cypher/graph/common.h b/src/cypher/graph/common.h index 2561a3b12c..681a12a0dc 100644 --- a/src/cypher/graph/common.h +++ b/src/cypher/graph/common.h @@ -33,7 +33,7 @@ struct Property { lgraph::FieldData value; enum { NUL, // empty - PARAMETER, // {name:$name} + PARAMETER, // {name:$name} || $1 (for plan cache) VALUE, // {name:'Tom Hanks'} VARIABLE, // UNWIND [1,2] AS mid MATCH (n {id:mid}) || WITH {a: 1, b: 2} as pair } type; diff --git a/src/cypher/parser/clause.h b/src/cypher/parser/clause.h index fbffd5d3c6..dd1f53a256 100644 --- a/src/cypher/parser/clause.h +++ 
b/src/cypher/parser/clause.h @@ -441,6 +441,8 @@ static lgraph::FieldData MakeFieldData(const Expression &expr) { break; } case Expression::PARAMETER: + // Both plan cache parameters and named parameters are represented as strings. + ld = lgraph::FieldData(expr.String()); break; case Expression::NULL_: break; diff --git a/src/cypher/parser/cypher_base_visitor.h b/src/cypher/parser/cypher_base_visitor.h index 875a5c34d9..7cb8d61193 100644 --- a/src/cypher/parser/cypher_base_visitor.h +++ b/src/cypher/parser/cypher_base_visitor.h @@ -82,6 +82,14 @@ class CypherBaseVisitor : public LcypherVisitor { * MATCH (n) RETURN exists((n)-->()-->()) */ size_t _anonymous_idx = 0; + void FillParam() { + for (auto& part : _query) { + for (auto& p : part.parts) { + p.symbol_table.param_tab_ = &ctx_->param_tab_; + } + } + } + std::string GenAnonymousAlias(bool is_node) { std::string alias(ANONYMOUS); if (is_node) { @@ -95,6 +103,9 @@ class CypherBaseVisitor : public LcypherVisitor { bool AddSymbol(const std::string &symbol_alias, cypher::SymbolNode::Type type, cypher::SymbolNode::Scope scope) { + if (symbol_alias[0] == '$' && std::isdigit(symbol_alias[1])) { + return false; + } if (_InClauseRETURN() || (_InClauseWHERE() && !symbol_alias.empty() && symbol_alias[0] != '$')) { // TODO(anyone): more situations @@ -182,6 +193,7 @@ class CypherBaseVisitor : public LcypherVisitor { _curr_part = 0; _symbol_idx = 0; _anonymous_idx = 0; + FillParam(); visit(ctx->oC_SingleQuery()); for (auto u : ctx->oC_Union()) { // initialize for the next single_query @@ -206,6 +218,7 @@ class CypherBaseVisitor : public LcypherVisitor { int part_num = ctx->oC_MultiPartQuery()->oC_With().size() + 1; _query[_curr_query].parts.resize(part_num); } + FillParam(); return visitChildren(ctx); } @@ -463,6 +476,7 @@ class CypherBaseVisitor : public LcypherVisitor { _curr_part = 0; _symbol_idx = 0; _anonymous_idx = 0; + FillParam(); std::tuple> invocation; if (ctx->oC_ImplicitProcedureInvocation()) { diff --git 
a/src/cypher/parser/symbol_table.h b/src/cypher/parser/symbol_table.h index 1464d7f640..6e3a271ba6 100644 --- a/src/cypher/parser/symbol_table.h +++ b/src/cypher/parser/symbol_table.h @@ -77,6 +77,7 @@ struct AnnotationCollection { struct SymbolTable { std::unordered_map symbols; AnnotationCollection anot_collection; + PARAM_TAB* param_tab_ = nullptr; void DumpTable() const; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a11e748daf..0a6bb818ee 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -88,6 +88,7 @@ add_executable(unit_test test_perf_kv_fatkey.cpp test_perf_multi_writer.cpp test_perf_unaligned.cpp + test_plan_cache.cpp test_proto_convert.cpp test_python_plugin_manager.cpp test_python_plugin_manager_impl.cpp diff --git a/test/test_plan_cache.cpp b/test/test_plan_cache.cpp new file mode 100644 index 0000000000..16db0a2c1e --- /dev/null +++ b/test/test_plan_cache.cpp @@ -0,0 +1,59 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "./antlr4-runtime.h" +#include "cypher/execution_plan/plan_cache/plan_cache_param.h" +#include "cypher/execution_plan/plan_cache/plan_cache.h" + +#include "gtest/gtest.h" +#include "./ut_utils.h" +#include "core/data_type.h" +#include "./test_tools.h" + +class TestPlanCache : public TuGraphTest {}; + +TEST_F(TestPlanCache, basicCaching) { + cypher::LRUPlanCache cache(512); + + cache.add_plan("1", 1); + int value; + cache.get_plan("1", value); + ASSERT_EQ(value, 1); + + cache.add_plan("2", 2); + cache.get_plan("2", value); + ASSERT_EQ(value, 2); + + ASSERT_EQ(cache.current_size(), 2); +} + +TEST_F(TestPlanCache, eviction) { + cypher::LRUPlanCache cache(512); + + for (int i = 0; i < 522; i++) { + cache.add_plan(std::to_string(i), i); + } + + for (int i = 0; i < 10; i++) { + int val; + bool res = cache.get_plan(std::to_string(i), val); + ASSERT_EQ(res, false); + } + + for (int i = 10; i < 522; i++) { + int val; + bool res = cache.get_plan(std::to_string(i), val); + ASSERT_EQ(res, true); + } +} From f1dbabf203e4a937db55854389582f4883098dae Mon Sep 17 00:00:00 2001 From: Myrrolinz Date: Tue, 29 Oct 2024 00:54:50 -0400 Subject: [PATCH 10/12] basic columnar-based data structure (#682) * basic columnar-based data structure * remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType --------- Co-authored-by: Shipeng Qi --- src/cypher/resultset/bit_mask.h | 221 +++++++++++++++++ src/cypher/resultset/column_vector.h | 325 +++++++++++++++++++++++++ src/cypher/resultset/cypher_string_t.h | 66 +++++ 3 files changed, 612 insertions(+) create mode 100644 src/cypher/resultset/bit_mask.h create mode 100644 src/cypher/resultset/column_vector.h create mode 100644 src/cypher/resultset/cypher_string_t.h diff --git a/src/cypher/resultset/bit_mask.h b/src/cypher/resultset/bit_mask.h new file mode 100644 index 0000000000..74f985df7e --- /dev/null +++ b/src/cypher/resultset/bit_mask.h @@ -0,0 +1,221 @@ +/** + * Copyright 2022 AntGroup 
/* CO., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>

namespace cypher {

// BITMASKS_SINGLE_ONE[i]: only bit i is set.
constexpr uint64_t BITMASKS_SINGLE_ONE[64] = {
    0x0000'0000'0000'0001, 0x0000'0000'0000'0002, 0x0000'0000'0000'0004, 0x0000'0000'0000'0008,
    0x0000'0000'0000'0010, 0x0000'0000'0000'0020, 0x0000'0000'0000'0040, 0x0000'0000'0000'0080,
    0x0000'0000'0000'0100, 0x0000'0000'0000'0200, 0x0000'0000'0000'0400, 0x0000'0000'0000'0800,
    0x0000'0000'0000'1000, 0x0000'0000'0000'2000, 0x0000'0000'0000'4000, 0x0000'0000'0000'8000,
    0x0000'0000'0001'0000, 0x0000'0000'0002'0000, 0x0000'0000'0004'0000, 0x0000'0000'0008'0000,
    0x0000'0000'0010'0000, 0x0000'0000'0020'0000, 0x0000'0000'0040'0000, 0x0000'0000'0080'0000,
    0x0000'0000'0100'0000, 0x0000'0000'0200'0000, 0x0000'0000'0400'0000, 0x0000'0000'0800'0000,
    0x0000'0000'1000'0000, 0x0000'0000'2000'0000, 0x0000'0000'4000'0000, 0x0000'0000'8000'0000,
    0x0000'0001'0000'0000, 0x0000'0002'0000'0000, 0x0000'0004'0000'0000, 0x0000'0008'0000'0000,
    0x0000'0010'0000'0000, 0x0000'0020'0000'0000, 0x0000'0040'0000'0000, 0x0000'0080'0000'0000,
    0x0000'0100'0000'0000, 0x0000'0200'0000'0000, 0x0000'0400'0000'0000, 0x0000'0800'0000'0000,
    0x0000'1000'0000'0000, 0x0000'2000'0000'0000, 0x0000'4000'0000'0000, 0x0000'8000'0000'0000,
    0x0001'0000'0000'0000, 0x0002'0000'0000'0000, 0x0004'0000'0000'0000, 0x0008'0000'0000'0000,
    0x0010'0000'0000'0000, 0x0020'0000'0000'0000, 0x0040'0000'0000'0000, 0x0080'0000'0000'0000,
    0x0100'0000'0000'0000, 0x0200'0000'0000'0000, 0x0400'0000'0000'0000, 0x0800'0000'0000'0000,
    0x1000'0000'0000'0000, 0x2000'0000'0000'0000, 0x4000'0000'0000'0000, 0x8000'0000'0000'0000};

// BITMASKS_SINGLE_ZERO[i]: every bit except bit i is set.
constexpr uint64_t BITMASKS_SINGLE_ZERO[64] = {
    0xffff'ffff'ffff'fffe, 0xffff'ffff'ffff'fffd, 0xffff'ffff'ffff'fffb, 0xffff'ffff'ffff'fff7,
    0xffff'ffff'ffff'ffef, 0xffff'ffff'ffff'ffdf, 0xffff'ffff'ffff'ffbf, 0xffff'ffff'ffff'ff7f,
    0xffff'ffff'ffff'feff, 0xffff'ffff'ffff'fdff, 0xffff'ffff'ffff'fbff, 0xffff'ffff'ffff'f7ff,
    0xffff'ffff'ffff'efff, 0xffff'ffff'ffff'dfff, 0xffff'ffff'ffff'bfff, 0xffff'ffff'ffff'7fff,
    0xffff'ffff'fffe'ffff, 0xffff'ffff'fffd'ffff, 0xffff'ffff'fffb'ffff, 0xffff'ffff'fff7'ffff,
    0xffff'ffff'ffef'ffff, 0xffff'ffff'ffdf'ffff, 0xffff'ffff'ffbf'ffff, 0xffff'ffff'ff7f'ffff,
    0xffff'ffff'feff'ffff, 0xffff'ffff'fdff'ffff, 0xffff'ffff'fbff'ffff, 0xffff'ffff'f7ff'ffff,
    0xffff'ffff'efff'ffff, 0xffff'ffff'dfff'ffff, 0xffff'ffff'bfff'ffff, 0xffff'ffff'7fff'ffff,
    0xffff'fffe'ffff'ffff, 0xffff'fffd'ffff'ffff, 0xffff'fffb'ffff'ffff, 0xffff'fff7'ffff'ffff,
    0xffff'ffef'ffff'ffff, 0xffff'ffdf'ffff'ffff, 0xffff'ffbf'ffff'ffff, 0xffff'ff7f'ffff'ffff,
    0xffff'feff'ffff'ffff, 0xffff'fdff'ffff'ffff, 0xffff'fbff'ffff'ffff, 0xffff'f7ff'ffff'ffff,
    0xffff'efff'ffff'ffff, 0xffff'dfff'ffff'ffff, 0xffff'bfff'ffff'ffff, 0xffff'7fff'ffff'ffff,
    0xfffe'ffff'ffff'ffff, 0xfffd'ffff'ffff'ffff, 0xfffb'ffff'ffff'ffff, 0xfff7'ffff'ffff'ffff,
    0xffef'ffff'ffff'ffff, 0xffdf'ffff'ffff'ffff, 0xffbf'ffff'ffff'ffff, 0xff7f'ffff'ffff'ffff,
    0xfeff'ffff'ffff'ffff, 0xfdff'ffff'ffff'ffff, 0xfbff'ffff'ffff'ffff, 0xf7ff'ffff'ffff'ffff,
    0xefff'ffff'ffff'ffff, 0xdfff'ffff'ffff'ffff, 0xbfff'ffff'ffff'ffff, 0x7fff'ffff'ffff'ffff};

// LOWER_BITMASKS[n]: the n lowest bits are set (n = 0..64).
constexpr uint64_t LOWER_BITMASKS[65] = {
    0x0000'0000'0000'0000, 0x0000'0000'0000'0001, 0x0000'0000'0000'0003, 0x0000'0000'0000'0007,
    0x0000'0000'0000'000f, 0x0000'0000'0000'001f, 0x0000'0000'0000'003f, 0x0000'0000'0000'007f,
    0x0000'0000'0000'00ff, 0x0000'0000'0000'01ff, 0x0000'0000'0000'03ff, 0x0000'0000'0000'07ff,
    0x0000'0000'0000'0fff, 0x0000'0000'0000'1fff, 0x0000'0000'0000'3fff, 0x0000'0000'0000'7fff,
    0x0000'0000'0000'ffff, 0x0000'0000'0001'ffff, 0x0000'0000'0003'ffff, 0x0000'0000'0007'ffff,
    0x0000'0000'000f'ffff, 0x0000'0000'001f'ffff, 0x0000'0000'003f'ffff, 0x0000'0000'007f'ffff,
    0x0000'0000'00ff'ffff, 0x0000'0000'01ff'ffff, 0x0000'0000'03ff'ffff, 0x0000'0000'07ff'ffff,
    0x0000'0000'0fff'ffff, 0x0000'0000'1fff'ffff, 0x0000'0000'3fff'ffff, 0x0000'0000'7fff'ffff,
    0x0000'0000'ffff'ffff, 0x0000'0001'ffff'ffff, 0x0000'0003'ffff'ffff, 0x0000'0007'ffff'ffff,
    0x0000'000f'ffff'ffff, 0x0000'001f'ffff'ffff, 0x0000'003f'ffff'ffff, 0x0000'007f'ffff'ffff,
    0x0000'00ff'ffff'ffff, 0x0000'01ff'ffff'ffff, 0x0000'03ff'ffff'ffff, 0x0000'07ff'ffff'ffff,
    0x0000'0fff'ffff'ffff, 0x0000'1fff'ffff'ffff, 0x0000'3fff'ffff'ffff, 0x0000'7fff'ffff'ffff,
    0x0000'ffff'ffff'ffff, 0x0001'ffff'ffff'ffff, 0x0003'ffff'ffff'ffff, 0x0007'ffff'ffff'ffff,
    0x000f'ffff'ffff'ffff, 0x001f'ffff'ffff'ffff, 0x003f'ffff'ffff'ffff, 0x007f'ffff'ffff'ffff,
    0x00ff'ffff'ffff'ffff, 0x01ff'ffff'ffff'ffff, 0x03ff'ffff'ffff'ffff, 0x07ff'ffff'ffff'ffff,
    0x0fff'ffff'ffff'ffff, 0x1fff'ffff'ffff'ffff, 0x3fff'ffff'ffff'ffff, 0x7fff'ffff'ffff'ffff,
    0xffff'ffff'ffff'ffff};

// HIGH_BITMASKS[n]: the n highest bits are set (n = 0..64).
constexpr uint64_t HIGH_BITMASKS[65] = {
    0x0000'0000'0000'0000, 0x8000'0000'0000'0000, 0xc000'0000'0000'0000, 0xe000'0000'0000'0000,
    0xf000'0000'0000'0000, 0xf800'0000'0000'0000, 0xfc00'0000'0000'0000, 0xfe00'0000'0000'0000,
    0xff00'0000'0000'0000, 0xff80'0000'0000'0000, 0xffc0'0000'0000'0000, 0xffe0'0000'0000'0000,
    0xfff0'0000'0000'0000, 0xfff8'0000'0000'0000, 0xfffc'0000'0000'0000, 0xfffe'0000'0000'0000,
    0xffff'0000'0000'0000, 0xffff'8000'0000'0000, 0xffff'c000'0000'0000, 0xffff'e000'0000'0000,
    0xffff'f000'0000'0000, 0xffff'f800'0000'0000, 0xffff'fc00'0000'0000, 0xffff'fe00'0000'0000,
    0xffff'ff00'0000'0000, 0xffff'ff80'0000'0000, 0xffff'ffc0'0000'0000, 0xffff'ffe0'0000'0000,
    0xffff'fff0'0000'0000, 0xffff'fff8'0000'0000, 0xffff'fffc'0000'0000, 0xffff'fffe'0000'0000,
    0xffff'ffff'0000'0000, 0xffff'ffff'8000'0000, 0xffff'ffff'c000'0000, 0xffff'ffff'e000'0000,
    0xffff'ffff'f000'0000, 0xffff'ffff'f800'0000, 0xffff'ffff'fc00'0000, 0xffff'ffff'fe00'0000,
    0xffff'ffff'ff00'0000, 0xffff'ffff'ff80'0000, 0xffff'ffff'ffc0'0000, 0xffff'ffff'ffe0'0000,
    0xffff'ffff'fff0'0000, 0xffff'ffff'fff8'0000, 0xffff'ffff'fffc'0000, 0xffff'ffff'fffe'0000,
    0xffff'ffff'ffff'0000, 0xffff'ffff'ffff'8000, 0xffff'ffff'ffff'c000, 0xffff'ffff'ffff'e000,
    0xffff'ffff'ffff'f000, 0xffff'ffff'ffff'f800, 0xffff'ffff'ffff'fc00, 0xffff'ffff'ffff'fe00,
    0xffff'ffff'ffff'ff00, 0xffff'ffff'ffff'ff80, 0xffff'ffff'ffff'ffc0, 0xffff'ffff'ffff'ffe0,
    0xffff'ffff'ffff'fff0, 0xffff'ffff'ffff'fff8, 0xffff'ffff'ffff'fffc, 0xffff'ffff'ffff'fffe,
    0xffff'ffff'ffff'ffff};

namespace bit_mask_detail {
// Checks every table entry against the formula that defines it, so a mistyped literal
// becomes a compile error instead of a silently wrong null bit.
constexpr bool TablesAreConsistent() {
    for (uint64_t i = 0; i < 64; ++i) {
        const uint64_t one = uint64_t{1} << i;
        if (BITMASKS_SINGLE_ONE[i] != one) return false;
        if (BITMASKS_SINGLE_ZERO[i] != ~one) return false;
        if (LOWER_BITMASKS[i] != one - 1) return false;
        if (HIGH_BITMASKS[i + 1] != (~uint64_t{0} << (63 - i))) return false;
    }
    return LOWER_BITMASKS[64] == ~uint64_t{0} && HIGH_BITMASKS[0] == 0;
}
}  // namespace bit_mask_detail

static_assert(bit_mask_detail::TablesAreConsistent(),
              "bit mask lookup tables do not match their definitions");

/**
 * Bit set with one bit per position, packed into 64-bit entries; a set bit marks the
 * position as null. The mask either owns its storage (capacity constructor, copies,
 * resize) or is a view over caller-provided entries (pointer constructor).
 *
 * None of the accessors checks its position argument against the current size.
 */
class BitMask {
 public:
    static constexpr uint64_t NO_NULL_ENTRY = 0;
    static constexpr uint64_t ALL_NULL_ENTRY = ~uint64_t(NO_NULL_ENTRY);
    static constexpr uint64_t BITS_PER_ENTRY_LOG2 = 6;  // 64 bits per entry
    static constexpr uint64_t BITS_PER_ENTRY = (uint64_t)1 << BITS_PER_ENTRY_LOG2;
    static constexpr uint64_t BYTES_PER_ENTRY = BITS_PER_ENTRY >> 3;  // 8 bytes per entry

    /** Creates an owning mask with room for `capacity` bits, all cleared. */
    explicit BitMask(uint64_t capacity)
        : data_{nullptr},
          size_{static_cast<size_t>((capacity + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY)},
          // make_unique value-initialises the array, i.e. every entry is NO_NULL_ENTRY.
          buffer_{std::make_unique<uint64_t[]>(size_)},
          may_contain_nulls_{false} {
        data_ = buffer_.get();
    }

    /**
     * Wraps `size` caller-owned entries at `mask_data` without copying them; the caller
     * keeps ownership of that memory.
     */
    explicit BitMask(uint64_t* mask_data, size_t size, bool may_contain_nulls)
        : data_{mask_data}, size_{size}, buffer_{nullptr}, may_contain_nulls_{may_contain_nulls} {}

    /** Deep copy: the new mask always owns its storage, even when `other` is a view. */
    BitMask(const BitMask& other)
        : data_{nullptr},
          size_{other.size_},
          buffer_{std::make_unique<uint64_t[]>(other.size_)},
          may_contain_nulls_{other.may_contain_nulls_} {
        std::copy(other.data_, other.data_ + size_, buffer_.get());
        data_ = buffer_.get();
    }

    /** Deep copy assignment; afterwards this mask owns its storage. */
    BitMask& operator=(const BitMask& other) {
        if (this == &other) return *this;
        // Fill the new buffer before releasing the old one, so `other` may be a view
        // over this mask's current storage.
        auto copy = std::make_unique<uint64_t[]>(other.size_);
        std::copy(other.data_, other.data_ + other.size_, copy.get());
        size_ = other.size_;
        may_contain_nulls_ = other.may_contain_nulls_;
        buffer_ = std::move(copy);
        data_ = buffer_.get();
        return *this;
    }

    /** Clears every bit. Does nothing if no bit has been set since the last clear. */
    void SetAllNonNull() {
        if (!may_contain_nulls_) return;
        std::fill(data_, data_ + size_, NO_NULL_ENTRY);
        may_contain_nulls_ = false;
    }

    /** Sets every bit. */
    void SetAllNull() {
        std::fill(data_, data_ + size_, ALL_NULL_ENTRY);
        may_contain_nulls_ = true;
    }

    /**
     * True only if no bit can be set. The flag is conservative: it is raised whenever a
     * bit is set and lowered only by SetAllNonNull().
     */
    bool HasNoNullsGuarantee() const { return !may_contain_nulls_; }

    /** Sets (`is_null`) or clears bit `pos` in the raw entry array `entries`. */
    static void SetBit(uint64_t* entries, uint32_t pos, bool is_null) {
        auto [entry_pos, bit_pos_in_entry] = GetEntryAndBitPos(pos);
        if (is_null) {
            entries[entry_pos] |= BITMASKS_SINGLE_ONE[bit_pos_in_entry];
        } else {
            entries[entry_pos] &= BITMASKS_SINGLE_ZERO[bit_pos_in_entry];
        }
    }

    /** Sets (`is_null`) or clears bit `pos` of this mask. */
    void SetBit(uint32_t pos, bool is_null) {
        SetBit(data_, pos, is_null);
        if (is_null) {
            may_contain_nulls_ = true;
        }
    }

    /** Returns whether bit `pos` is set. */
    bool IsBitSet(uint32_t pos) const {
        auto [entry_pos, bit_pos_in_entry] = GetEntryAndBitPos(pos);
        return (data_[entry_pos] & BITMASKS_SINGLE_ONE[bit_pos_in_entry]) != 0;
    }

    /** Read-only access to the packed entries. */
    const uint64_t* GetData() const { return data_; }

    /**
     * Number of 64-bit entries needed to hold `num_bits` bits, i.e. ceil(num_bits / 64).
     * The previous remainder test shifted left instead of right, which added one entry
     * for every non-zero input (64 bits gave 2 entries).
     */
    static uint64_t GetNumEntries(uint64_t num_bits) {
        return (num_bits >> BITS_PER_ENTRY_LOG2) +
               ((num_bits & (BITS_PER_ENTRY - 1)) == 0 ? 0 : 1);
    }

    /**
     * Re-allocates the mask for `capacity` bits. Existing entries are kept up to the
     * smaller of the old and new size; new entries are cleared. The mask owns its
     * storage afterwards.
     */
    void resize(uint64_t capacity) {
        const uint64_t num_entries = (capacity + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY;
        auto resized_buffer = std::make_unique<uint64_t[]>(num_entries);
        if (data_) {
            // Explicit template argument: size_t and uint64_t are different types on
            // some platforms, and std::min requires both arguments to have one type.
            std::memcpy(resized_buffer.get(), data_,
                        std::min<uint64_t>(size_, num_entries) * sizeof(uint64_t));
        }
        buffer_ = std::move(resized_buffer);
        data_ = buffer_.get();
        size_ = num_entries;
    }

    /** Sets or clears the bits in [offset, offset + num_bits_to_set) of this mask. */
    void SetNullFromRange(uint64_t offset, uint64_t num_bits_to_set, bool is_null) {
        if (is_null) {
            may_contain_nulls_ = true;
        }
        SetNullRange(data_, offset, num_bits_to_set, is_null);
    }

    /**
     * Sets or clears the bits in [offset, offset + num_bits_to_set) of the raw entry
     * array `null_entries`.
     */
    static void SetNullRange(uint64_t* null_entries, uint64_t offset,
                             uint64_t num_bits_to_set, bool is_null) {
        auto [first_entry_pos, first_bit_pos] = GetEntryAndBitPos(offset);
        // One past the last bit of the range.
        auto [last_entry_pos, last_bit_pos] = GetEntryAndBitPos(offset + num_bits_to_set);

        // Entries that lie completely inside the range are written whole.
        if (last_entry_pos > first_entry_pos + 1) {
            std::fill(null_entries + first_entry_pos + 1, null_entries + last_entry_pos,
                      is_null ? ALL_NULL_ENTRY : NO_NULL_ENTRY);
        }

        if (first_entry_pos == last_entry_pos) {
            // The range lies inside one entry: touch bits [first_bit_pos, last_bit_pos).
            if (is_null) {
                null_entries[first_entry_pos] |= (~LOWER_BITMASKS[first_bit_pos]
                    & ~HIGH_BITMASKS[BITS_PER_ENTRY - last_bit_pos]);
            } else {
                null_entries[first_entry_pos] &= (LOWER_BITMASKS[first_bit_pos]
                    | HIGH_BITMASKS[BITS_PER_ENTRY - last_bit_pos]);
            }
        } else {
            // Upper part of the first entry, and lower part of the last entry. When the
            // range ends exactly on an entry boundary (last_bit_pos == 0) the last entry
            // is not part of the range and is not accessed.
            if (is_null) {
                null_entries[first_entry_pos] |= ~LOWER_BITMASKS[first_bit_pos];
                if (last_bit_pos > 0) {
                    null_entries[last_entry_pos] |= LOWER_BITMASKS[last_bit_pos];
                }
            } else {
                null_entries[first_entry_pos] &= LOWER_BITMASKS[first_bit_pos];
                if (last_bit_pos > 0) {
                    null_entries[last_entry_pos] &= ~LOWER_BITMASKS[last_bit_pos];
                }
            }
        }
    }

 private:
    // Splits a bit position into (index of the 64-bit entry, bit index inside it).
    static std::pair<uint64_t, uint64_t> GetEntryAndBitPos(uint64_t pos) {
        auto entry_pos = pos >> BITS_PER_ENTRY_LOG2;
        return {entry_pos, pos - (entry_pos << BITS_PER_ENTRY_LOG2)};
    }

 private:
    uint64_t* data_;                     // entries in use: buffer_.get() or caller memory
    size_t size_;                        // number of 64-bit entries at data_
    std::unique_ptr<uint64_t[]> buffer_;  // owned storage; null for a view
    bool may_contain_nulls_;             // false only while no bit can be set
};

}
// namespace cypher diff --git a/src/cypher/resultset/column_vector.h b/src/cypher/resultset/column_vector.h new file mode 100644 index 0000000000..b94dc914c9 --- /dev/null +++ b/src/cypher/resultset/column_vector.h @@ -0,0 +1,325 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include +#include "cypher/resultset/bit_mask.h" +#include "cypher/resultset/cypher_string_t.h" + +namespace cypher { + +constexpr size_t DEFAULT_VECTOR_CAPACITY = 2048; + +class ColumnVector { + friend class StringColumn; + + public: + explicit ColumnVector(size_t element_size, size_t capacity = DEFAULT_VECTOR_CAPACITY, + lgraph_api::FieldType field_type = lgraph_api::FieldType::NUL) + : element_size_(element_size), + capacity_(capacity), + field_type_(field_type), + data_(new uint8_t[element_size * capacity]()), + bitmask_(capacity) {} + + ColumnVector(const ColumnVector& other) + : element_size_(other.element_size_), + capacity_(other.capacity_), + field_type_(other.field_type_), + data_(new uint8_t[other.element_size_ * other.capacity_]), + bitmask_(other.bitmask_) { + // Check if the ColumnVector contains strings + if (element_size_ == sizeof(cypher_string_t)) { + // Initialize overflow buffer + if (other.overflow_buffer_) { + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; // will update this as we copy strings + } + // Copy each cypher_string_t individually + for (uint32_t i = 0; i 
< capacity_; ++i) { + auto& src_str = reinterpret_cast(other.data_.get())[i]; + auto& dst_str = reinterpret_cast(data_.get())[i]; + dst_str.len = src_str.len; + if (cypher_string_t::IsShortString(src_str.len)) { + // Copy the short string directly + std::memcpy(dst_str.prefix, src_str.prefix, src_str.len); + } else { + // Copy the prefix + std::memcpy(dst_str.prefix, src_str.prefix, cypher_string_t::PREFIX_LENGTH); + // Allocate overflow space in the new overflow buffer + uint64_t overflow_size = src_str.len - cypher_string_t::PREFIX_LENGTH; + if (!overflow_buffer_) { + // Initialize overflow buffer if not already done + overflow_buffer_capacity_ = std::max(overflow_size, + static_cast(1024)); + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; + } else if (overflow_offset_ + overflow_size > overflow_buffer_capacity_) { + // Resize overflow buffer if necessary + ResizeOverflowBuffer(overflow_offset_ + overflow_size); + } + // Copy the overflow data + void* dst_overflow_ptr = overflow_buffer_.get() + overflow_offset_; + std::memcpy(dst_overflow_ptr, reinterpret_cast(src_str.overflowPtr), + overflow_size); + dst_str.overflowPtr = reinterpret_cast(dst_overflow_ptr); + overflow_offset_ += overflow_size; + } + } + } else { + // For non-string data, we can copy the data directly + std::memcpy(data_.get(), other.data_.get(), element_size_ * capacity_); + /* Copy overflow buffer if it exists (though for non-string data, it shouldn't). + Just in case for future use. 
*/ + if (other.overflow_buffer_) { + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_offset_ = other.overflow_offset_; + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + std::memcpy(overflow_buffer_.get(), other.overflow_buffer_.get(), overflow_offset_); + } + } + } + + ColumnVector& operator=(const ColumnVector& other) { + if (this == &other) return *this; + element_size_ = other.element_size_; + capacity_ = other.capacity_; + field_type_ = other.field_type_; + data_ = std::unique_ptr(new uint8_t[other.element_size_ * other.capacity_]); + std::memcpy(data_.get(), other.data_.get(), other.element_size_ * other.capacity_); + bitmask_ = other.bitmask_; + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_offset_ = other.overflow_offset_; + if (other.overflow_buffer_) { + overflow_buffer_ = std::unique_ptr( + new uint8_t[other.overflow_buffer_capacity_]); + std::memcpy(overflow_buffer_.get(), other.overflow_buffer_.get(), overflow_offset_); + } else { + overflow_buffer_.reset(); + } + return *this; + } + + ~ColumnVector() = default; + + void SetAllNull() { bitmask_.SetAllNull(); } + void SetAllNonNull() { bitmask_.SetAllNonNull(); } + bool HasNoNullsGuarantee() const { return bitmask_.HasNoNullsGuarantee(); } + + void SetNullRange(uint32_t start, uint32_t len, bool value) { + bitmask_.SetNullFromRange(start, len, value); + } + + void SetNull(uint32_t pos, bool is_null) { bitmask_.SetBit(pos, is_null); } + + bool IsNull(uint32_t pos) const { return bitmask_.IsBitSet(pos); } + + uint8_t* data() const { return data_.get(); } + + uint32_t GetElementSize() const { return element_size_; } + + uint32_t GetCapacity() const { return capacity_; } + + lgraph_api::FieldType GetFieldType() const { return field_type_; } + + template + const T& GetValue(uint32_t pos) const { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in GetValue"); + } + return reinterpret_cast(data_.get())[pos]; + } + + 
template + T& GetValue(uint32_t pos) { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in GetValue"); + } + return reinterpret_cast(data_.get())[pos]; + } + + template + void SetValue(uint32_t pos, T val) { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in GetValue"); + } + reinterpret_cast(data_.get())[pos] = val; + } + + void* AllocateOverflow(uint64_t size) const { + if (!overflow_buffer_) { + overflow_buffer_capacity_ = std::max(size, static_cast(1024)); + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; + } else if (overflow_offset_ + size > overflow_buffer_capacity_) { + uint64_t new_capacity = overflow_offset_ + size; + new_capacity = ((new_capacity + 1023) / 1024) * 1024; + ResizeOverflowBuffer(new_capacity); + } + void* ptr = overflow_buffer_.get() + overflow_offset_; + overflow_offset_ += size; + return ptr; + } + + // fetch field size + static size_t GetFieldSize(lgraph_api::FieldType type) { + switch (type) { + case lgraph_api::FieldType::BOOL: + return sizeof(bool); + case lgraph_api::FieldType::INT8: + return sizeof(int8_t); + case lgraph_api::FieldType::INT16: + return sizeof(int16_t); + case lgraph_api::FieldType::INT32: + return sizeof(int32_t); + case lgraph_api::FieldType::INT64: + return sizeof(int64_t); + case lgraph_api::FieldType::FLOAT: + return sizeof(float); + case lgraph_api::FieldType::DOUBLE: + return sizeof(double); + default: + throw std::runtime_error("Unsupported field type"); + } + } + + // insert data into column vector + static void InsertIntoColumnVector(ColumnVector* column_vector, + const lgraph_api::FieldData& field, + uint32_t pos) { + switch (field.type) { + case lgraph_api::FieldType::BOOL: + column_vector->SetValue(pos, field.AsBool()); + break; + case lgraph_api::FieldType::INT8: + column_vector->SetValue(pos, field.AsInt8()); + break; + case lgraph_api::FieldType::INT16: + column_vector->SetValue(pos, field.AsInt16()); + break; + 
case lgraph_api::FieldType::INT32: + column_vector->SetValue(pos, field.AsInt32()); + break; + case lgraph_api::FieldType::INT64: + column_vector->SetValue(pos, field.AsInt64()); + break; + case lgraph_api::FieldType::FLOAT: + column_vector->SetValue(pos, field.AsFloat()); + break; + case lgraph_api::FieldType::DOUBLE: + column_vector->SetValue(pos, field.AsDouble()); + break; + default: + throw std::runtime_error("Unsupported field type"); + } + } + + private: + void ResizeOverflowBuffer(uint64_t new_capacity) const { + if (new_capacity <= overflow_buffer_capacity_) return; + auto new_buffer = std::make_unique(new_capacity); + if (overflow_buffer_) { + std::memcpy(new_buffer.get(), overflow_buffer_.get(), overflow_offset_); + } + overflow_buffer_ = std::move(new_buffer); + overflow_buffer_capacity_ = new_capacity; + } + + private: + uint32_t element_size_; // size of each element in bytes + uint32_t capacity_; // number of elements + lgraph_api::FieldType field_type_; + std::unique_ptr data_; + BitMask bitmask_; + mutable uint64_t overflow_buffer_capacity_; + mutable std::unique_ptr overflow_buffer_ = nullptr; + mutable uint64_t overflow_offset_; +}; + + +class StringColumn { + public: + // add string to vector + static void AddString(ColumnVector* vector, uint32_t vectorPos, cypher_string_t& srcStr) { + auto& dstStr = vector->GetValue(vectorPos); + if (cypher_string_t::IsShortString(srcStr.len)) { + dstStr.SetShortString(reinterpret_cast(srcStr.prefix), srcStr.len); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(srcStr.len)); + dstStr.SetLongString(reinterpret_cast(srcStr.prefix), srcStr.len); + } + } + + static void AddString(ColumnVector* vector, uint32_t vectorPos, const char* srcStr, + uint64_t length) { + auto& dstStr = vector->GetValue(vectorPos); + if (cypher_string_t::IsShortString(length)) { + dstStr.SetShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + 
dstStr.SetLongString(srcStr, length); + } + } + + static void AddString(ColumnVector* vector, uint32_t vectorPos, const std::string& srcStr) { + AddString(vector, vectorPos, srcStr.data(), srcStr.length()); + } + + static cypher_string_t& ReserveString(ColumnVector* vector, uint32_t vectorPos, + uint64_t length) { + auto& dstStr = vector->GetValue(vectorPos); + dstStr.len = length; + if (!cypher_string_t::IsShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + } + return dstStr; + } + + static void ReserveString(ColumnVector* vector, cypher_string_t& dstStr, uint64_t length) { + dstStr.len = length; + if (!cypher_string_t::IsShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, cypher_string_t& srcStr) { + if (cypher_string_t::IsShortString(srcStr.len)) { + dstStr.SetShortString(reinterpret_cast(srcStr.prefix), srcStr.len); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(srcStr.len)); + dstStr.SetLongString(reinterpret_cast(srcStr.prefix), srcStr.len); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, const char* srcStr, + uint64_t length) { + if (cypher_string_t::IsShortString(length)) { + dstStr.SetShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + dstStr.SetLongString(srcStr, length); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, + const std::string& srcStr) { + AddString(vector, dstStr, srcStr.data(), srcStr.length()); + } +}; +} // namespace cypher diff --git a/src/cypher/resultset/cypher_string_t.h b/src/cypher/resultset/cypher_string_t.h new file mode 100644 index 0000000000..a07404f547 --- /dev/null +++ b/src/cypher/resultset/cypher_string_t.h @@ -0,0 +1,66 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include + +namespace cypher { + +struct cypher_string_t { + static constexpr uint64_t PREFIX_LENGTH = 4; + static constexpr uint64_t INLINED_SUFFIX_LENGTH = 8; + static constexpr uint64_t SHORT_STR_LENGTH = PREFIX_LENGTH + INLINED_SUFFIX_LENGTH; + + uint32_t len; + uint8_t prefix[PREFIX_LENGTH]; + union { + uint8_t data[INLINED_SUFFIX_LENGTH]; + uint64_t overflowPtr; + }; + + cypher_string_t() : len{0}, overflowPtr{0} {} + + static bool IsShortString(uint32_t len) { return len <= SHORT_STR_LENGTH; } + + void SetShortString(const char* value, uint64_t length) { + len = length; + std::memcpy(prefix, value, length); + } + + void SetLongString(const char* value, uint64_t length) { + len = length; + std::memcpy(prefix, value, PREFIX_LENGTH); + std::memcpy(reinterpret_cast(overflowPtr), value + PREFIX_LENGTH, + length - PREFIX_LENGTH); + } + + std::string GetAsShortString() const { + return std::string(reinterpret_cast(prefix), len); + } + + std::string GetAsString() const { + if (IsShortString(len)) { + return std::string(reinterpret_cast(prefix), len); + } else { + return std::string(reinterpret_cast(prefix), PREFIX_LENGTH) + + std::string(reinterpret_cast(overflowPtr), len - PREFIX_LENGTH); + } + } +}; + +} // namespace cypher From 449770de8a2c48ac1534c6ca904221cc4f38a2d8 Mon Sep 17 00:00:00 2001 From: Myrrolinz Date: Tue, 29 Oct 2024 02:53:31 -0400 Subject: [PATCH 11/12] Column record (#683) * basic columnar-based data structure * columnar data record structure * 
remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType * support FieldType * format fix --------- Co-authored-by: yannan-wyn <129476350+yannan-wyn@users.noreply.github.com> --- src/cypher/resultset/record.h | 366 ++++++++++++++++++++++++++++++++++ 1 file changed, 366 insertions(+) diff --git a/src/cypher/resultset/record.h b/src/cypher/resultset/record.h index 2131160e99..9639083235 100644 --- a/src/cypher/resultset/record.h +++ b/src/cypher/resultset/record.h @@ -18,11 +18,13 @@ #pragma once #include +#include #include "core/data_type.h" // lgraph::FieldData #include "cypher/cypher_types.h" #include "parser/data_typedef.h" #include "graph/node.h" #include "graph/relationship.h" +#include "cypher/resultset/column_vector.h" namespace cypher { @@ -321,4 +323,368 @@ struct Record { return *this; } }; + +struct DataChunk { + std::unordered_map> columnar_data_; + std::unordered_map> string_columns_; + std::unordered_map property_positions_; + std::unordered_map> property_vids_; + + DataChunk() = default; + + void CopyColumn(const std::string& column_name, const DataChunk& source_record, + bool overwrite_existing = true) { + if (source_record.columnar_data_.find(column_name) != source_record.columnar_data_.end()) { + const auto& src_vector = source_record.columnar_data_.at(column_name); + if (overwrite_existing || columnar_data_.find(column_name) == columnar_data_.end()) { + auto dst_vector = std::make_unique(*src_vector); + columnar_data_[column_name] = std::move(dst_vector); + property_positions_[column_name] = + source_record.property_positions_.at(column_name); + property_vids_[column_name] = + source_record.property_vids_.at(column_name); + } + } else if (source_record.string_columns_.find(column_name) != + source_record.string_columns_.end()) { + const auto& src_vector = source_record.string_columns_.at(column_name); + if (overwrite_existing || string_columns_.find(column_name) == + string_columns_.end()) { + auto dst_vector 
= std::make_unique(*src_vector); + string_columns_[column_name] = std::move(dst_vector); + property_positions_[column_name] = + source_record.property_positions_.at(column_name); + property_vids_[column_name] = source_record.property_vids_.at(column_name); + } + } + } + + void MergeColumn(const DataChunk& source_record, bool overwrite_existing = true) { + for (const auto& pair : source_record.columnar_data_) { + CopyColumn(pair.first, source_record, overwrite_existing); + } + for (const auto& pair : source_record.string_columns_) { + CopyColumn(pair.first, source_record, overwrite_existing); + } + } + + void TruncateData(int usable_r) { + std::set sorted_vids; + for (const auto& pair : property_vids_) { + const auto& vids = pair.second; + sorted_vids.insert(vids.begin(), vids.end()); + } + + std::vector selected_vids; + for (auto it = sorted_vids.begin(); it != sorted_vids.end() && + selected_vids.size() < static_cast(usable_r); ++it) { + selected_vids.push_back(*it); + } + for (auto& pair : columnar_data_) { + const std::string& column_name = pair.first; + auto& column_vector = pair.second; + const auto& vids = property_vids_.at(column_name); + auto new_vector = std::make_unique( + column_vector->GetElementSize(), usable_r, + column_vector->GetFieldType()); + std::vector new_vids; + uint32_t new_pos = 0; + for (uint32_t selected_vid : selected_vids) { + auto it = std::find(vids.begin(), vids.end(), selected_vid); + if (it != vids.end()) { + uint32_t original_pos = std::distance(vids.begin(), it); + if (!column_vector->IsNull(original_pos)) { + std::memcpy(new_vector->data() + new_pos * new_vector->GetElementSize(), + column_vector->data() + original_pos * + column_vector->GetElementSize(), + column_vector->GetElementSize()); + new_vids.push_back(selected_vid); + } + new_pos++; + } + } + column_vector = std::move(new_vector); + property_vids_[column_name] = new_vids; + property_positions_[column_name] = new_pos; + } + + for (auto& pair : string_columns_) { + const 
std::string& column_name = pair.first; + auto& column_vector = pair.second; + const auto& vids = property_vids_.at(column_name); + auto new_vector = std::make_unique( + sizeof(cypher_string_t), usable_r, column_vector->GetFieldType()); + std::vector new_vids; + uint32_t new_pos = 0; + for (uint32_t selected_vid : selected_vids) { + auto it = std::find(vids.begin(), vids.end(), selected_vid); + if (it != vids.end()) { + uint32_t original_pos = std::distance(vids.begin(), it); + if (!column_vector->IsNull(original_pos)) { + StringColumn::AddString(new_vector.get(), new_pos, + column_vector->GetValue(original_pos).GetAsString()); + new_vids.push_back(selected_vid); + } + new_pos++; + } + } + column_vector = std::move(new_vector); + property_vids_[column_name] = new_vids; + property_positions_[column_name] = new_pos; + } + } + + void Append(const DataChunk& source_record) { + for (const auto& pair : source_record.string_columns_) { + const std::string& column_name = pair.first; + const auto& src_vector = pair.second; + uint32_t size = source_record.property_positions_.at(column_name); + // PrintStringColumnData(column_name, *src_vector); + if (string_columns_.find(column_name) == string_columns_.end()) { + string_columns_[column_name] = std::make_unique(*src_vector); + property_positions_[column_name] = + source_record.property_positions_.at(column_name); + property_vids_[column_name] = source_record.property_vids_.at(column_name); + } else { + auto& dst_vector = string_columns_[column_name]; + uint32_t old_size = property_vids_[column_name].size(); + uint32_t new_size = old_size + size; + auto new_vector = std::make_unique( + sizeof(cypher_string_t), new_size, dst_vector->GetFieldType()); + for (uint32_t i = 0; i < old_size; ++i) { + const cypher_string_t& dst_string = dst_vector->GetValue(i); + StringColumn::AddString(new_vector.get(), i, dst_string.GetAsString()); + } + for (uint32_t i = 0; i < size; ++i) { + const cypher_string_t& src_string = 
src_vector->GetValue(i); + StringColumn::AddString(new_vector.get(), + old_size + i, src_string.GetAsString()); + } + string_columns_[column_name] = std::move(new_vector); + property_positions_[column_name] += + source_record.property_positions_.at(column_name); + auto& dst_vids = property_vids_[column_name]; + const auto& src_vids = source_record.property_vids_.at(column_name); + dst_vids.insert(dst_vids.end(), src_vids.begin(), src_vids.end()); + } + } + + for (const auto& pair : source_record.columnar_data_) { + const std::string& column_name = pair.first; + const auto& src_vector = pair.second; + auto size = source_record.property_positions_.at(column_name); + if (columnar_data_.find(column_name) == columnar_data_.end()) { + columnar_data_[column_name] = std::make_unique(*src_vector); + property_positions_[column_name] = + source_record.property_positions_.at(column_name); + property_vids_[column_name] = source_record.property_vids_.at(column_name); + } else { + auto& dst_vector = columnar_data_[column_name]; + uint32_t old_size = property_vids_[column_name].size(); + uint32_t new_size = old_size + size; + auto new_vector = std::make_unique( + dst_vector->GetElementSize(), new_size, + dst_vector->GetFieldType()); + std::memcpy(new_vector->data(), dst_vector->data(), + dst_vector->GetElementSize() * old_size); + std::memcpy(new_vector->data() + old_size * new_vector->GetElementSize(), + src_vector->data(), src_vector->GetElementSize() * size); + columnar_data_[column_name] = std::move(new_vector); + property_positions_[column_name] += + source_record.property_positions_.at(column_name); + auto& dst_vids = property_vids_[column_name]; + const auto& src_vids = source_record.property_vids_.at(column_name); + dst_vids.insert(dst_vids.end(), src_vids.begin(), src_vids.end()); + } + } + } + + void Print() const { + std::cout << "DataChunk contents:\n"; + std::cout << "Columnar Data:\n"; + for (const auto& pair : columnar_data_) { + std::cout << " Column Name: " << 
pair.first << "\n"; + std::cout << " Element Size: " << pair.second->GetElementSize() << "\n"; + std::cout << " Capacity: " << pair.second->GetCapacity() << "\n"; + PrintColumnData(pair.first, *pair.second); + } + + std::cout << "String Columns:\n"; + for (const auto& pair : string_columns_) { + std::cout << " Column Name: " << pair.first << "\n"; + std::cout << " Element Size: " << pair.second->GetElementSize() << "\n"; + std::cout << " Capacity: " << pair.second->GetCapacity() << "\n"; + PrintStringColumnData(pair.first, *pair.second); + } + + std::cout << "Property Positions:\n"; + for (const auto& pair : property_positions_) { + std::cout << " Property Name: " << pair.first + << ", End Position: " << pair.second << "\n"; + } + + std::cout << "Property VIDs:\n"; + for (const auto& pair : property_vids_) { + std::cout << " Property Name: " << pair.first << ", VIDs: ["; + const auto& vids = pair.second; + for (size_t i = 0; i < vids.size(); ++i) { + std::cout << vids[i]; + if (i < vids.size() - 1) { + std::cout << ", "; + } + } + std::cout << "]\n"; + } + } + + template + void PrintColumnData(const std::string& column_name, const ColumnVector& column) const { + for (uint32_t i = 0; i < column.GetCapacity(); ++i) { + if (!column.IsNull(i)) { + std::cout << " Data[" << i << "]: " << column.GetValue(i) << "\n"; + } else { + std::cout << " Data[" << i << "]: NULL\n"; + } + } + } + + void PrintColumnData(const std::string& column_name, const ColumnVector& column) const { + lgraph_api::FieldType field_type = column.GetFieldType(); + switch (field_type) { + case lgraph_api::FieldType::BOOL: + PrintColumnData(column_name, column); + break; + case lgraph_api::FieldType::INT8: + PrintColumnData(column_name, column); + break; + case lgraph_api::FieldType::INT16: + PrintColumnData(column_name, column); + break; + case lgraph_api::FieldType::INT32: + PrintColumnData(column_name, column); + break; + case lgraph_api::FieldType::INT64: + PrintColumnData(column_name, column); + 
break; + case lgraph_api::FieldType::FLOAT: + PrintColumnData(column_name, column); + break; + case lgraph_api::FieldType::DOUBLE: + PrintColumnData(column_name, column); + break; + default: + std::cout << "Unsupported field type for column: " << column_name << "\n"; + break; + } + } + + void PrintStringColumnData(const std::string& column_name, const ColumnVector& column) const { + for (uint32_t i = 0; i < column.GetCapacity(); ++i) { + if (!column.IsNull(i)) { + const cypher_string_t& value = column.GetValue(i); + std::cout << " Data[" << i << "]: " << value.GetAsString() << "\n"; + } else { + std::cout << " Data[" << i << "]: NULL\n"; + } + } + } + + std::string Dump(bool is_standard) const { + json arr = json::array(); + std::vector column_names; + for (const auto& pair : columnar_data_) { + column_names.push_back(pair.first); + } + for (const auto& pair : string_columns_) { + if (std::find(column_names.begin(), column_names.end(), + pair.first) == column_names.end()) { + column_names.push_back(pair.first); + } + } + + std::set common_vids; + for (const auto& pair : property_vids_) { + const auto& vids = pair.second; + common_vids.insert(vids.begin(), vids.end()); + } + // for (uint32_t vid : common_vids) { + // std::cout << "Common VID: " << vid << "\n"; + // } + for (uint32_t vid : common_vids) { + json j; + j["vid"] = vid; + for (const auto& column_name : column_names) { + if (property_vids_.find(column_name) != property_vids_.end()) { + const auto& vids = property_vids_.at(column_name); + auto it = std::find(vids.begin(), vids.end(), vid); + if (it != vids.end()) { + uint32_t column_pos = std::distance(vids.begin(), it); + if (columnar_data_.find(column_name) != columnar_data_.end()) { + const auto& column_vector = columnar_data_.at(column_name); + if (!column_vector->IsNull(column_pos)) { + lgraph_api::FieldType field_type = column_vector->GetFieldType(); + switch (field_type) { + case lgraph_api::FieldType::BOOL: + j[column_name] = + 
column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::INT8: + j[column_name] = + column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::INT16: + j[column_name] = + column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::INT32: + j[column_name] = + column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::INT64: + j[column_name] = + column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::FLOAT: + j[column_name] = + column_vector->GetValue(column_pos); + break; + case lgraph_api::FieldType::DOUBLE: + j[column_name] = + column_vector->GetValue(column_pos); + break; + default: + throw std::runtime_error( + "Unsupported data type in columnar_data_"); + } + } + } + if (string_columns_.find(column_name) != string_columns_.end()) { + const auto& column_vector = string_columns_.at(column_name); + if (!column_vector->IsNull(column_pos)) { + j[column_name] = + column_vector->GetValue(column_pos).GetAsString(); + } + } + } + } + } + if (j.is_null()) { + throw std::runtime_error( + "DataChunk has a null row! 
Maybe your new record is not a reference."); + } + arr.emplace_back(j); + } + // std::cout << "Dump Result: " << arr.dump(4) << std::endl; + if (is_standard) { + json output; + output["header"] = column_names; + output["is_standard"] = true; + output["data"] = arr; + return output.dump(); + } else { + return arr.dump(); + } + } +}; + + } // namespace cypher From bc479471fef62a4636b4cdc4e896e0318164f287 Mon Sep 17 00:00:00 2001 From: spasserby <569078986@qq.com> Date: Wed, 30 Oct 2024 07:21:36 +0000 Subject: [PATCH 12/12] cpplint --- .../experimental/data_type/field_data.h | 101 +++++++++--------- src/cypher/experimental/data_type/record.h | 59 ++++++---- src/cypher/experimental/expressions/cexpr.cpp | 15 ++- src/cypher/experimental/expressions/cexpr.h | 31 ++++-- .../expressions/kernal/binary.cpp | 60 +++++++---- src/cypher/experimental/jit/TuJIT.cpp | 15 ++- src/cypher/experimental/jit/TuJIT.h | 39 +++++-- test/test_query_compilation.cpp | 39 ++++--- toolkits/lgraph_compilation.cpp | 44 +++++--- 9 files changed, 261 insertions(+), 142 deletions(-) diff --git a/src/cypher/experimental/data_type/field_data.h b/src/cypher/experimental/data_type/field_data.h index 8ebb22cf0a..f163715c83 100644 --- a/src/cypher/experimental/data_type/field_data.h +++ b/src/cypher/experimental/data_type/field_data.h @@ -1,15 +1,30 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ #pragma once -#include -#include + #include #include +#include +#include + #include "core/data_type.h" #include "cypher/cypher_types.h" #include "cypher/cypher_exception.h" -using builder::static_var; using builder::dyn_var; +using builder::static_var; using lgraph::FieldType; namespace cypher { @@ -17,70 +32,60 @@ namespace compilation { struct CScalarData { static constexpr const char* type_name = "CScalarData"; - std::variant< - std::monostate, // Represent the null state - dyn_var, - dyn_var, - dyn_var, - dyn_var, - dyn_var - > constant_; + std::variant, dyn_var, dyn_var, dyn_var, dyn_var> + constant_; lgraph::FieldType type_; - CScalarData() { - type_ = lgraph_api::FieldType::NUL; - } + CScalarData() { type_ = lgraph_api::FieldType::NUL; } - CScalarData(CScalarData &&data) - : constant_(std::move(data.constant_)), type_(data.type_) {} + CScalarData(CScalarData&& data) : constant_(std::move(data.constant_)), type_(data.type_) {} - CScalarData(const CScalarData& other) - : constant_(other.constant_), type_(other.type_) {} + CScalarData(const CScalarData& other) : constant_(other.constant_), type_(other.type_) {} - CScalarData(const lgraph::FieldData& other) { + explicit CScalarData(const lgraph::FieldData& other) { type_ = other.type; switch (other.type) { case lgraph::FieldType::NUL: constant_.emplace(); break; case lgraph::FieldType::INT64: - constant_.emplace>((long)other.integer()); + constant_.emplace>((int64_t)other.integer()); break; default: CYPHER_TODO(); } } - explicit CScalarData(long integer) { - constant_.emplace>(integer); + explicit CScalarData(int64_t integer) { + constant_.emplace>(integer); type_ = lgraph::FieldType::INT64; } - - explicit CScalarData(const static_var &integer) - : type_(FieldType::INT64) { - constant_ = (dyn_var) integer; + + explicit CScalarData(const static_var& integer) : type_(FieldType::INT64) { + constant_ = (dyn_var)integer; } - explicit CScalarData(const dyn_var &integer) - : constant_(integer), 
type_(FieldType::INT64) {} + explicit CScalarData(const dyn_var& integer) + : constant_(integer), type_(FieldType::INT64) {} - explicit CScalarData(dyn_var&& integer) - : constant_(std::move(integer)), type_(FieldType::INT64) {} + explicit CScalarData(dyn_var&& integer) + : constant_(std::move(integer)), type_(FieldType::INT64) {} - inline dyn_var integer() const { + inline dyn_var integer() const { switch (type_) { case FieldType::NUL: case FieldType::BOOL: throw std::bad_cast(); case FieldType::INT8: - return std::get>(constant_); + return std::get>(constant_); case FieldType::INT16: - return std::get>(constant_); + return std::get>(constant_); case FieldType::INT32: return std::get>(constant_); case FieldType::INT64: - return std::get>(constant_); + return std::get>(constant_); case FieldType::FLOAT: case FieldType::DOUBLE: case FieldType::DATE: @@ -94,7 +99,7 @@ struct CScalarData { case FieldType::FLOAT_VECTOR: throw std::bad_cast(); } - return dyn_var(0); + return dyn_var(0); } inline dyn_var real() const { @@ -124,17 +129,11 @@ struct CScalarData { return dyn_var(0); } - dyn_var Int64() const { - return std::get>(constant_); - } + dyn_var Int64() const { return std::get>(constant_); } - inline bool is_integer() const { - return type_ >= FieldType::INT8 && type_ <= FieldType::INT64; - } + inline bool is_integer() const { return type_ >= FieldType::INT8 && type_ <= FieldType::INT64; } - inline bool is_real() const { - return type_ == FieldType::DOUBLE || type_ == FieldType::FLOAT; - } + inline bool is_real() const { return type_ == FieldType::DOUBLE || type_ == FieldType::FLOAT; } bool is_null() const { return type_ == lgraph::FieldType::NUL; } @@ -160,7 +159,7 @@ struct CScalarData { }; struct CFieldData { - enum FieldType { SCALAR, ARRAY, MAP} type; + enum FieldType { SCALAR, ARRAY, MAP } type; CScalarData scalar; std::vector* array = nullptr; @@ -168,11 +167,11 @@ struct CFieldData { CFieldData() : type(SCALAR) {} - CFieldData(const CFieldData &data) : 
type(data.type), scalar(data.scalar) {} + CFieldData(const CFieldData& data) : type(data.type), scalar(data.scalar) {} - CFieldData(const CScalarData& scalar) : type(SCALAR), scalar(scalar) {} + explicit CFieldData(const CScalarData& scalar) : type(SCALAR), scalar(scalar) {} - CFieldData(CScalarData&& scalar) : type(SCALAR), scalar(std::move(scalar)) {} + explicit CFieldData(CScalarData&& scalar) : type(SCALAR), scalar(std::move(scalar)) {} CFieldData& operator=(const CFieldData& data) { this->type = data.type; @@ -186,7 +185,7 @@ struct CFieldData { return *this; } - explicit CFieldData(const static_var& scalar) : type(SCALAR), scalar(scalar) {} + explicit CFieldData(const static_var& scalar) : type(SCALAR), scalar(scalar) {} bool is_null() const { return type == SCALAR && scalar.is_null(); } @@ -200,5 +199,5 @@ struct CFieldData { CFieldData operator-(const CFieldData& other) const; }; -} // namespace compilation -} // namespace cypher \ No newline at end of file +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/data_type/record.h b/src/cypher/experimental/data_type/record.h index bbce67fee4..352c7c91c6 100644 --- a/src/cypher/experimental/data_type/record.h +++ b/src/cypher/experimental/data_type/record.h @@ -1,3 +1,17 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + #pragma once #include @@ -38,43 +52,46 @@ struct CEntry { explicit CEntry(const cypher::Entry& entry) { switch (entry.type) { - case cypher::Entry::CONSTANT: { - constant_ = CScalarData(entry.constant.scalar); - type_ = CONSTANT; - break; - } - case cypher::Entry::NODE: { - node_ = entry.node; - type_ = NODE; - break; - } - case cypher::Entry::RELATIONSHIP: { - relationship_ = entry.relationship; - type_ = RELATIONSHIP; - break; - } + case cypher::Entry::CONSTANT: + { + constant_ = CScalarData(entry.constant.scalar); + type_ = CONSTANT; + break; + } + case cypher::Entry::NODE: + { + node_ = entry.node; + type_ = NODE; + break; + } + case cypher::Entry::RELATIONSHIP: + { + relationship_ = entry.relationship; + type_ = RELATIONSHIP; + break; + } default: CYPHER_TODO(); } } - explicit CEntry(const CFieldData &data) : constant_(data), type_(CONSTANT) {} + explicit CEntry(const CFieldData& data) : constant_(data), type_(CONSTANT) {} explicit CEntry(CFieldData&& data) : constant_(std::move(data)), type_(CONSTANT) {} explicit CEntry(const CScalarData& scalar) : constant_(scalar), type_(CONSTANT) {} }; -struct CRecord { // Should be derived from cypher::Record +struct CRecord { // Should be derived from cypher::Record std::vector values; - + CRecord() = default; - CRecord(const cypher::Record &record) { + explicit CRecord(const cypher::Record& record) { for (auto& entry : record.values) { values.emplace_back(entry); } } }; -} // namespace compilaiton -} // namespace cypher \ No newline at end of file +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/expressions/cexpr.cpp b/src/cypher/experimental/expressions/cexpr.cpp index 78ed26da91..91e19bd728 100644 --- a/src/cypher/experimental/expressions/cexpr.cpp +++ b/src/cypher/experimental/expressions/cexpr.cpp @@ -1,2 +1,15 @@ -#include "cexpr.h" +/** + * Copyright 2024 AntGroup CO., Ltd. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "cypher/experimental/expressions/cexpr.h" diff --git a/src/cypher/experimental/expressions/cexpr.h b/src/cypher/experimental/expressions/cexpr.h index 395dea5f08..ec97d4a1d8 100644 --- a/src/cypher/experimental/expressions/cexpr.h +++ b/src/cypher/experimental/expressions/cexpr.h @@ -1,3 +1,19 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + #include #include @@ -20,9 +36,9 @@ #include "experimental/data_type/record.h" #include "cypher/execution_plan/runtime_context.h" -using builder::static_var; -using builder::dyn_var; using builder::defer_init; +using builder::dyn_var; +using builder::static_var; using builder::with_name; namespace cypher { @@ -56,7 +72,6 @@ namespace compilation { // type = AR_OPERAND_CONSTANT; // } - // inline CEntry Eval(const CRecord &record) { // if (type == AR_OPERAND_CONSTANT) { // return CEntry(constant); @@ -131,9 +146,7 @@ class ExprEvaluator : public geax::frontend::AstExprNodeVisitorImpl { } } - geax::frontend::Expr* GetExpression() { - return expr_; - } + geax::frontend::Expr* GetExpression() { return expr_; } private: std::any visit(geax::frontend::GetField* node) override; @@ -221,9 +234,9 @@ struct CExprNode { CExprNode() = default; - inline CEntry Eval(cypher::RTContext *ctx, const CRecord &record) { + inline CEntry Eval(cypher::RTContext* ctx, const CRecord& record) { return evaluator_->Evaluate(ctx, &record); } }; -} // namespace compilation -} // namespace cypher \ No newline at end of file +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/expressions/kernal/binary.cpp b/src/cypher/experimental/expressions/kernal/binary.cpp index 0f8504cd40..711809b334 100644 --- a/src/cypher/experimental/expressions/kernal/binary.cpp +++ b/src/cypher/experimental/expressions/kernal/binary.cpp @@ -1,3 +1,17 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + #include #include #include "cypher/cypher_types.h" @@ -11,7 +25,7 @@ namespace cypher { namespace compilation { -CFieldData CFieldData::operator+(const CFieldData &other) const { +CFieldData CFieldData::operator+(const CFieldData& other) const { if (is_null() || other.is_null()) return CFieldData(); CFieldData ret; if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { @@ -23,14 +37,15 @@ CFieldData CFieldData::operator+(const CFieldData &other) const { ret.scalar = CScalarData(scalar.Int64() + other.scalar.Int64()); } else { dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); - dyn_var y_n = is_integer()? (dyn_var)other.scalar.integer() : other.scalar.real(); + dyn_var y_n = + is_integer() ? (dyn_var)other.scalar.integer() : other.scalar.real(); ret.scalar = std::move(CScalarData(x_n + y_n)); } } return ret; } -CFieldData CFieldData::operator-(const CFieldData &other) const { +CFieldData CFieldData::operator-(const CFieldData& other) const { if (is_null() || other.is_null()) return CFieldData(); CFieldData ret; if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { @@ -42,34 +57,33 @@ CFieldData CFieldData::operator-(const CFieldData &other) const { ret.scalar = CScalarData(scalar.Int64() - other.scalar.Int64()); } else { dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() : scalar.real(); - dyn_var y_n = is_integer()? (dyn_var)other.scalar.integer() : other.scalar.real(); + dyn_var y_n = + is_integer() ? 
(dyn_var)other.scalar.integer() : other.scalar.real(); ret.scalar = std::move(CScalarData(x_n - y_n)); } } return ret; } -static CFieldData add(const CFieldData& x, const CFieldData& y) { - return x + y; -} +static CFieldData add(const CFieldData& x, const CFieldData& y) { return x + y; } -static CFieldData sub(const CFieldData& x, const CFieldData& y) { - return x - y; -} +static CFieldData sub(const CFieldData& x, const CFieldData& y) { return x - y; } -static CFieldData div(const CFieldData& x, const CFieldData y) { +static CFieldData div(const CFieldData& x, const CFieldData y) { if (x.is_null() || y.is_null()) return CFieldData(); - if (!(x.is_integer() || x.is_real()) || !(y.is_integer() || y.is_real())) + if (!(x.is_integer() || x.is_real()) || !(y.is_integer() || y.is_real())) throw lgraph::CypherException("Type mismatch: expect Integer or Float in div expr"); CFieldData ret; if (x.is_integer() && y.is_integer()) { - dyn_var x_n = x.scalar.integer(); - dyn_var y_n = y.scalar.integer(); + dyn_var x_n = x.scalar.integer(); + dyn_var y_n = y.scalar.integer(); if (y_n == 0) throw lgraph::CypherException("divide by zero"); ret.scalar = std::move(CScalarData(x_n / y_n)); } else { - dyn_var x_n = x.is_integer() ? (dyn_var) x.scalar.integer() : x.scalar.real(); - dyn_var y_n = y.is_integer()? (dyn_var) y.scalar.integer() : y.scalar.real(); + dyn_var x_n = + x.is_integer() ? (dyn_var)x.scalar.integer() : x.scalar.real(); + dyn_var y_n = + y.is_integer() ? 
(dyn_var)y.scalar.integer() : y.scalar.real(); if (y_n == 0) CYPHER_TODO(); ret.scalar = std::move(CScalarData(x_n - y_n)); } @@ -77,13 +91,13 @@ static CFieldData div(const CFieldData& x, const CFieldData y) { } #ifndef DO_BINARY_EXPR -#define DO_BINARY_EXPR(func) \ +#define DO_BINARY_EXPR(func) \ auto lef = std::any_cast(node->left()->accept(*this)); \ auto rig = std::any_cast(node->right()->accept(*this)); \ - if (lef.type_ != CEntry::RecordEntryType::CONSTANT || \ - rig.type_ != CEntry::RecordEntryType::CONSTANT) { \ - NOT_SUPPORT_AND_THROW(); \ - } \ + if (lef.type_ != CEntry::RecordEntryType::CONSTANT || \ + rig.type_ != CEntry::RecordEntryType::CONSTANT) { \ + NOT_SUPPORT_AND_THROW(); \ + } \ return CEntry(func(lef.constant_, rig.constant_)); #endif @@ -176,5 +190,5 @@ std::any ExprEvaluator::visit(geax::frontend::IsNull* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::ListComprehension* node) { CYPHER_TODO(); } std::any ExprEvaluator::visit(geax::frontend::Exists* node) { CYPHER_TODO(); } std::any ExprEvaluator::reportError() { CYPHER_TODO(); } -} // namespace compilation -} // namepsace cypher \ No newline at end of file +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/jit/TuJIT.cpp b/src/cypher/experimental/jit/TuJIT.cpp index 78ad5f1237..260092e4fd 100644 --- a/src/cypher/experimental/jit/TuJIT.cpp +++ b/src/cypher/experimental/jit/TuJIT.cpp @@ -1,2 +1,15 @@ -#include "cypher/experimental/jit/TuJIT.h" +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "cypher/experimental/jit/TuJIT.h" diff --git a/src/cypher/experimental/jit/TuJIT.h b/src/cypher/experimental/jit/TuJIT.h index b90a4620e3..265af2be30 100644 --- a/src/cypher/experimental/jit/TuJIT.h +++ b/src/cypher/experimental/jit/TuJIT.h @@ -1,10 +1,26 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + #include #include #include -#include -#include -#include +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Target/TargetMachine.h" namespace cypher { namespace compilation { @@ -39,22 +55,24 @@ class TuJIT { // Compile module. In compile function client responsiblity is to fill module with necessary // IR code, then it will be compiled by TuJIT instance. // Return compiled module. - CompileModule compileModule(std::function compile_funciton); + CompileModule compileModule(std::function compile_funciton); - // Delete compiled module. Pointers to functions from module become invalid after this call. + // Delete compiled module. Pointers to functions from module become invalid after this call. // It is client reponsibility to be sure that there are no pointers to compiled module code. void deleteCompiledModule(const CompileModule& module_info); // Register external symbol for TuJIT instance to use, during linking. // It can be function, or global constant. - // It is client responsibility to be sure that address of symbol is valid during TuJIT instance lifetime. 
+ // It is client responsibility to be sure that address of symbol is valid during TuJIT instance + // lifetime. void registerExternalSymbol(const std::string& symbol_name, void* address); // Total compiled code size for module that are current valid. - size_t getCompiledCodeSize() const { return compiled_code_size_.load(std::memory_order_relaxed); } + size_t getCompiledCodeSize() const { + return compiled_code_size_.load(std::memory_order_relaxed); + } private: - std::unique_ptr createModulerForCompilation(); CompileModule compileModule(std::unique_ptr module); @@ -70,10 +88,11 @@ class TuJIT { llvm::DataLayout layout_; std::unique_ptr compiler_; - std::unordered_map> module_identifier_to_memory_manager_; + std::unordered_map> + module_identifier_to_memory_manager_; uint64_t current_module_key_ = 0; std::atomic compiled_code_size_ = 0; mutable std::mutex jit_lock_; }; } // namespace compilation -} // namespace cypher \ No newline at end of file +} // namespace cypher diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp index f1cb88a265..40d1c661f2 100644 --- a/test/test_query_compilation.cpp +++ b/test/test_query_compilation.cpp @@ -1,3 +1,17 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + #include #include #include @@ -20,10 +34,10 @@ #include "builder/static_var.h" using builder::dyn_var; using builder::static_var; +using cypher::compilation::CEntry; using cypher::compilation::CFieldData; -using cypher::compilation::CScalarData; using cypher::compilation::CRecord; -using cypher::compilation::CEntry; +using cypher::compilation::CScalarData; #include "gtest/gtest.h" @@ -45,7 +59,7 @@ std::string execute(const std::string& command) { return result; } -std::string execute_func(std::string &func_body) { +std::string execute_func(std::string& func_body) { const std::string file_name = "test_add.cpp"; const std::string output_name = "test_add"; std::ofstream out_file(file_name); @@ -66,25 +80,26 @@ std::string execute_func(std::string &func_body) { std::string output = execute("./a"); // delete files if (std::remove(file_name.c_str()) && std::remove(output_name.c_str())) { - std::cerr << "Failed to delete files: " << file_name - << ", " << output_name << std::endl; + std::cerr << "Failed to delete files: " << file_name << ", " << output_name << std::endl; } return output; } class TestQueryCompilation : public TuGraphTest {}; -dyn_var add(void) { +dyn_var add(void) { cypher::SymbolTable sym_tab; CFieldData a(std::move(CScalarData(10))); geax::frontend::Ref ref1; ref1.setName(std::string("a")); - sym_tab.symbols.emplace("a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + sym_tab.symbols.emplace( + "a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); - CFieldData b(static_var(10)); + CFieldData b(static_var(10)); geax::frontend::Ref ref2; ref2.setName(std::string("b")); - sym_tab.symbols.emplace("b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + sym_tab.symbols.emplace( + "b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); geax::frontend::BAdd add; add.setLeft((geax::frontend::Expr*)&ref1); @@ -96,7 +111,7 @@ 
dyn_var add(void) { cypher::compilation::ExprEvaluator evaluator(&add, &sym_tab); cypher::RTContext ctx; return evaluator.Evaluate(&ctx, &record).constant_.scalar.Int64(); -}; +} TEST_F(TestQueryCompilation, Add) { builder::builder_context context; @@ -106,7 +121,7 @@ TEST_F(TestQueryCompilation, Add) { block::c_code_generator::generate_code(ast, oss, 0); oss << "int main() {\n std::cout << add();\n return 0;\n}"; std::string body = oss.str(); - std::cout <<"Generated code: \n" << body << std::endl; + std::cout << "Generated code: \n" << body << std::endl; std::string res = execute_func(body); ASSERT_EQ(res, "20"); -} \ No newline at end of file +} diff --git a/toolkits/lgraph_compilation.cpp b/toolkits/lgraph_compilation.cpp index 9a33388150..fae2a948bb 100644 --- a/toolkits/lgraph_compilation.cpp +++ b/toolkits/lgraph_compilation.cpp @@ -1,3 +1,17 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + #include #include "cypher/experimental/data_type/field_data.h" #include "cypher/experimental/data_type/record.h" @@ -10,30 +24,32 @@ #include "builder/builder_context.h" #include "builder/dyn_var.h" using namespace cypher::compilation; -using builder::static_var; using builder::dyn_var; +using builder::static_var; -dyn_var bar(void) { - std::variant, static_var> a; - std::variant, dyn_var> b; - a = (std::variant, static_var>)static_var(10); - b = dyn_var(10); - auto res = std::get>(a) + std::get>(b); +dyn_var bar(void) { + std::variant, static_var> a; + std::variant, dyn_var> b; + a = (std::variant, static_var>)static_var(10); + b = dyn_var(10); + auto res = std::get>(a) + std::get>(b); return res; } -dyn_var foo(void) { +dyn_var foo(void) { cypher::SymbolTable sym_tab; CFieldData a(std::move(CScalarData(10))); geax::frontend::Ref ref1; ref1.setName(std::string("a")); - sym_tab.symbols.emplace("a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + sym_tab.symbols.emplace( + "a", cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); - CFieldData b(static_var(10)); + CFieldData b(static_var(10)); geax::frontend::Ref ref2; ref2.setName(std::string("b")); - sym_tab.symbols.emplace("b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); + sym_tab.symbols.emplace( + "b", cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL)); geax::frontend::BAdd add; add.setLeft((geax::frontend::Expr*)&ref1); @@ -49,8 +65,8 @@ dyn_var foo(void) { int main() { builder::builder_context context; - std::cout<<"#include "<" << std::endl; block::c_code_generator::generate_code(context.extract_function_ast(foo, "foo"), std::cout, 0); - std::cout<< "int main() {\n std::cout << foo() << std::endl;\n return 0;\n}"; + std::cout << "int main() {\n std::cout << foo() << std::endl;\n return 0;\n}"; return 0; -} \ No newline at end of file +}