From f368aa2c0eb60c4fe8431cb314453be23c65cf0d Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 12 Jun 2018 15:38:58 -0400 Subject: [PATCH 01/48] Refactor I/O part out of connection handle and simplify state machine logic. --- src/include/common/notifiable_task.h | 30 +- src/include/network/connection_handle.h | 165 ++--- .../network/connection_handle_factory.h | 84 --- src/include/network/marshal.h | 214 +++++- .../network/network_io_wrapper_factory.h | 64 ++ src/include/network/network_io_wrappers.h | 95 +++ src/include/network/network_state.h | 26 +- .../network/postgres_protocol_handler.h | 12 +- src/include/network/protocol_handler.h | 2 +- src/network/connection_handle.cpp | 684 +++--------------- src/network/connection_handler_task.cpp | 16 +- src/network/marshal.cpp | 9 - src/network/network_io_wrapper_factory.cpp | 88 +++ src/network/network_io_wrappers.cpp | 187 +++++ src/network/postgres_protocol_handler.cpp | 159 ++-- src/network/protocol_handler.cpp | 4 +- test/network/exception_test.cpp | 12 +- test/network/prepare_stmt_test.cpp | 12 +- test/network/select_all_test.cpp | 9 +- test/network/simple_query_test.cpp | 10 +- test/network/ssl_test.cpp | 11 +- 21 files changed, 912 insertions(+), 981 deletions(-) delete mode 100644 src/include/network/connection_handle_factory.h create mode 100644 src/include/network/network_io_wrapper_factory.h create mode 100644 src/include/network/network_io_wrappers.h create mode 100644 src/network/network_io_wrapper_factory.cpp create mode 100644 src/network/network_io_wrappers.cpp diff --git a/src/include/common/notifiable_task.h b/src/include/common/notifiable_task.h index 8ea65efb26b..e1572ab63b9 100644 --- a/src/include/common/notifiable_task.h +++ b/src/include/common/notifiable_task.h @@ -62,7 +62,6 @@ class NotifiableTask { */ inline int Id() const { return task_id_; } - /** * @brief Register an event with the event base associated with this * notifiable task. 
@@ -140,22 +139,19 @@ class NotifiableTask { return RegisterEvent(-1, EV_PERSIST, callback, arg); } - // TODO(tianyu): The original network code seems to do this as an - // optimization. Specifically it avoids new memory allocation by reusing - // an existing event. I am leaving this out until we get numbers. - // void UpdateEvent(struct event *event, int fd, short flags, - // event_callback_fn callback, void *arg, - // const struct timeval *timeout = nullptr) { - // PELOTON_ASSERT(!(events_.find(event) == events_.end())); - // EventUtil::EventDel(event); - // EventUtil::EventAssign(event, base_, fd, flags, callback, arg); - // EventUtil::EventAdd(event, timeout); - // } - // - // void UpdateManualEvent(struct event *event, event_callback_fn callback, - // void *arg) { - // UpdateEvent(event, -1, EV_PERSIST, callback, arg); - // } + void UpdateEvent(struct event *event, int fd, short flags, + event_callback_fn callback, void *arg, + const struct timeval *timeout = nullptr) { + PELOTON_ASSERT(!(events_.find(event) == events_.end())); + EventUtil::EventDel(event); + EventUtil::EventAssign(event, base_, fd, flags, callback, arg); + EventUtil::EventAdd(event, timeout); + } + + void UpdateManualEvent(struct event *event, event_callback_fn callback, + void *arg) { + UpdateEvent(event, -1, EV_PERSIST, callback, arg); + } /** * @brief Unregister the event given. The event is no longer active and its diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index c0ae311021c..c311af9299c 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -34,6 +34,7 @@ #include "network/connection_handler_task.h" #include "network_state.h" #include "protocol_handler.h" +#include "network/network_io_wrappers.h" #include #include @@ -41,56 +42,84 @@ namespace peloton { namespace network { -// TODO(tianyu) This class is not refactored in full as rewriting the logic is -// not cost-effective. 
However, readability -// improvement and other changes may become desirable in the future. Other than -// code clutter, responsibility assignment -// is not well thought-out in this class. Abstracting out some type of socket -// wrapper would be nice. /** * @brief A ConnectionHandle encapsulates all information about a client - * connection for its entire duration. - * One should not use the constructor to construct a new ConnectionHandle - * instance every time as it is expensive - * to allocate buffers. Instead, use the ConnectionHandleFactory. - * - * @see ConnectionHandleFactory + * connection for its entire duration. This includes a state machine and the + * necessary libevent infrastructure for a handler to work on this connection. */ class ConnectionHandle { public: /** - * Update the existing event to listen to the passed flags + * Constructs a new ConnectionHandle + * @param sock_fd Client's connection fd + * @param handler The handler responsible for this handle */ - void UpdateEventFlags(short flags); + ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler); - WriteState WritePackets(); - - std::string WriteBufferToString(); + /** + * @brief Signal to libevent that this ConnectionHandle is ready to handle events + * + * This method needs to be called separately after initialization for the + * connection handle to do anything. The reason why this is not performed in + * the constructor is because it publishes pointers to this object. While the + * object should be fully initialized at that point, it's never a bad idea + * to be careful. + */ + inline void RegisterToReceiveEvents() { + workpool_event_ = conn_handler_->RegisterManualEvent( + METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); + + // TODO(Tianyi): should put the initialization else where.. check correctness + // first. 
+ tcop_.SetTaskCallback([](void *arg) { + struct event *event = static_cast(arg); + event_active(event, EV_WRITE, 0); + }, workpool_event_); + + network_event_ = conn_handler_->RegisterEvent( + io_wrapper_->GetSocketFd(), EV_READ | EV_PERSIST, + METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); + } + /** + * Handles a libevent event. This simply delegates the the state machine. + */ inline void HandleEvent(int, short) { state_machine_.Accept(Transition::WAKEUP, *this); } - // Exposed for testing - const std::unique_ptr &GetProtocolHandler() const { - return protocol_handler_; + /* State Machine Actions */ + // TODO(Tianyu): Write some documentation when feeling like it + inline Transition TryRead() { + return io_wrapper_->FillReadBuffer(); } + Transition TryWrite(); + Transition Process(); + Transition GetResult(); + Transition TrySslHandshake(); + Transition CloseConnection(); - // State Machine actions /** - * refill_read_buffer - Used to repopulate read buffer with a fresh - * batch of data from the socket + * Updates the event flags of the network event. This configures how the handler + * reacts to client activity from this connection. + * @param flags new flags for the event handle. */ - Transition FillReadBuffer(); - Transition Wait(); - Transition Process(); - Transition ProcessWrite(); - Transition GetResult(); - Transition CloseSocket(); + inline void UpdateEventFlags(short flags) { + conn_handler_->UpdateEvent(network_event_, + io_wrapper_->GetSocketFd(), + flags, + METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), + this); + } + /** - * Flush out all the responses and do real SSL handshake + * Stops receiving network events from client connection. This is useful when + * we are waiting on peloton to return the result of a query and not handling + * client query. 
*/ - Transition ProcessWrite_SSLHandshake(); + inline void StopReceivingNetworkEvent() { + EventUtil::EventDel(network_event_); + } private: /** @@ -145,55 +174,7 @@ class ConnectionHandle { }; friend class StateMachine; - friend class ConnectionHandleFactory; - - ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler, - std::shared_ptr rbuf, std::shared_ptr wbuf); - - /** - * Writes a packet's header (type, size) into the write buffer - */ - WriteState BufferWriteBytesHeader(OutputPacket *pkt); - - /** - * Writes a packet's content into the write buffer - */ - WriteState BufferWriteBytesContent(OutputPacket *pkt); - - /** - * Used to invoke a write into the Socket, returns false if the socket is not - * ready for write - */ - WriteState FlushWriteBuffer(); - - /** - * @brief: process SSL handshake to generate valid SSL - * connection context for further communications - * @return FINISH when the SSL handshake failed - * PROCEED when the SSL handshake success - * NEED_DATA when the SSL handshake is partially done due to network - * latency - */ - Transition SSLHandshake(); - - /** - * Set the socket to non-blocking mode - */ - inline void SetNonBlocking(evutil_socket_t fd) { - auto flags = fcntl(fd, F_GETFL); - flags |= O_NONBLOCK; - if (fcntl(fd, F_SETFL, flags) < 0) { - LOG_ERROR("Failed to set non-blocking socket"); - } - } - - /** - * Set TCP No Delay for lower latency - */ - inline void SetTCPNoDelay(evutil_socket_t fd) { - int one = 1; - setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof one); - } + friend class NetworkIoWrapperFactory; /** * @brief: Determine if there is still responses in the buffer @@ -202,27 +183,19 @@ class ConnectionHandle { */ inline bool HasResponse() { return (protocol_handler_->responses_.size() != 0) || - (wbuf_->buf_size != 0); + (io_wrapper_->wbuf_->size_ != 0); } - int sock_fd_; // socket file descriptor - struct event *network_event = nullptr; // something to read from network - struct event *workpool_event = nullptr; 
// worker thread done the job - - SSL *conn_SSL_context = nullptr; // SSL context for the connection - - ConnectionHandlerTask *handler_; // reference to the network thread - std::unique_ptr - protocol_handler_; // Stores state for this socket - tcop::TrafficCop traffic_cop_; - - std::shared_ptr rbuf_; // Socket's read buffer - std::shared_ptr wbuf_; // Socket's write buffer - unsigned int next_response_ = 0; // The next response in the response buffer + ConnectionHandlerTask *conn_handler_; + std::shared_ptr io_wrapper_; StateMachine state_machine_; + struct event *network_event_ = nullptr, *workpool_event_ = nullptr; + std::unique_ptr protocol_handler_ = nullptr; + tcop::TrafficCop tcop_; + // TODO(Tianyu): Put this into protocol handler in a later refactor + unsigned int next_response_ = 0; - short curr_event_flag_; // current libevent event flag }; } // namespace network } // namespace peloton diff --git a/src/include/network/connection_handle_factory.h b/src/include/network/connection_handle_factory.h deleted file mode 100644 index 8f81d1e20dd..00000000000 --- a/src/include/network/connection_handle_factory.h +++ /dev/null @@ -1,84 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// connection_handle_factory.h -// -// Identification: src/include/network/connection_handle_factory.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "network/connection_handle.h" -#include "peloton_server.h" - -namespace peloton { -namespace network { - -/** - * @brief Factory class for constructing ConnectionHandles - * The rationale behind using a factory is that buffers are expensive to - * allocate and allocating new - * ones every time is a bottleneck for throughput. 
- */ -class ConnectionHandleFactory { - public: - /** - * Creates or repurpose a ConnectionHandle to be run on the given handler, - * handling connection from conn_fd - * @param conn_fd Client connection fd. - * @param handler The handler this ConnectionHandle is assigned to - * @return - */ - std::shared_ptr GetConnectionHandle( - int conn_fd, ConnectionHandlerTask *handler) { - // TODO(tianyu): The use of a static variable here for testing purpose is - // VILE. Fix this in a later refactor - // (probably also to-do: beat up the person who wrote this) - PelotonServer::recent_connfd = conn_fd; - auto it = reusable_handles_.find(conn_fd); - if (it == reusable_handles_.end()) { - // We are not using std::make_shared here because we want to keep - // ConnectionHandle constructor - // private to avoid unintentional use. - auto handle = std::shared_ptr( - new ConnectionHandle(conn_fd, handler, std::make_shared(), - std::make_shared())); - reusable_handles_[conn_fd] = handle; - return handle; - } - - it->second->rbuf_->Reset(); - it->second->wbuf_->Reset(); - std::shared_ptr new_handle(new ConnectionHandle( - conn_fd, handler, it->second->rbuf_, it->second->wbuf_)); - reusable_handles_[conn_fd] = new_handle; - return new_handle; - } - - // TODO(tianyu) Again, this is VILE. Fix this in a later refactor. - /** - * Exposed for testing only. DO NOT USE ELSEWHERE IN CODE. - * @param conn_fd client socket fd - * @return ConnetionHandle object representing client connection at conn_fd - */ - std::shared_ptr ConnectionHandleAt(int conn_fd) { - return reusable_handles_[conn_fd]; - } - - // TODO(tianyu): This should removed with the rest of the singletons - // We are keeping this here as fixing singleton is not the focus of this - // refactor and fixing it would be pretty expensive. 
- static ConnectionHandleFactory &GetInstance() { - static ConnectionHandleFactory factory; - return factory; - } - - private: - std::unordered_map> reusable_handles_; -}; -} -} diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 030231ec56d..c1cfab612db 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -18,52 +18,212 @@ #include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" +#include +#include +#include "network/network_state.h" #define BUFFER_INIT_SIZE 100 namespace peloton { namespace network { -// Buffers used to batch messages at the socket +/** + * A plain old buffer with a movable cursor, the meaning of which is dependent + * on the use case. + * + * The buffer has a fix capacity and one can write a variable amount of meaningful + * bytes into it. We call this amount "size" of the buffer. + */ struct Buffer { - size_t buf_ptr; // buffer cursor - size_t buf_size; // buffer size - size_t buf_flush_ptr; // buffer cursor for write - ByteBuf buf; + public: + /** + * Instantiates a new buffer and reserve default many bytes. + */ + inline Buffer() { buf_.reserve(SOCKET_BUFFER_SIZE); } + + /** + * Reset the buffer pointer and clears content + */ + inline void Reset() { + size_ = 0; + offset_ = 0; + } - inline Buffer() : buf_ptr(0), buf_size(0), buf_flush_ptr(0) { - // capacity of the buffer - buf.reserve(SOCKET_BUFFER_SIZE); + /** + * @param bytes The amount of bytes to check between the cursor and the end + * of the buffer (defaults to 1) + * @return Whether there is any more bytes between the cursor and + * the end of the buffer + */ + inline bool HasMore(size_t bytes = 1) { return offset_ + bytes < size_; } + + /** + * @return Whether the buffer is at capacity. 
(All usable space is filled + * with meaningful bytes) + */ + inline bool Full() { return size_ == Capacity(); } + + /** + * @return Iterator to the beginning of the buffer + */ + inline ByteBuf::const_iterator Begin() { return std::begin(buf_); } + + /** + * @return Capacity of the buffer (not actual size) + */ + inline constexpr size_t Capacity() { return SOCKET_BUFFER_SIZE; } + + /** + * Shift contents to align the current cursor with start of the buffer, + * remove all bytes before the cursor. + */ + inline void MoveContentToHead() { + auto unprocessed_len = size_ - offset_; + std::memmove(&buf_[0], &buf_[offset_], unprocessed_len); + size_ = unprocessed_len; + offset_ = 0; } - inline void Reset() { - buf_ptr = 0; - buf_size = 0; - buf_flush_ptr = 0; + // TODO(Tianyu): Make these protected once we refactor protocol handler + size_t size_ = 0, offset_ = 0; + ByteBuf buf_; +}; + +/** + * A buffer specialize for read + */ +class ReadBuffer: public Buffer { + public: + /** + * Read as many bytes as possible using SSL read + * @param context SSL context to read from + * @return the return value of ssl read + */ + inline int FillBufferFrom(SSL *context) { + ERR_clear_error(); + ssize_t bytes_read = + SSL_read(context, &buf_[size_], Capacity() - size_); + int err = SSL_get_error(context, bytes_read); + if (err == SSL_ERROR_NONE) size_ += bytes_read; + return err; + }; + + /** + * Read as many bytes as possible using Posix from an fd + * @param fd the file descriptor to read from + * @return the return value of posix read + */ + inline int FillBufferFrom(int fd) { + ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); + if (bytes_read > 0) size_ += bytes_read; + return (int) bytes_read; } - // single buffer element accessor - inline uchar GetByte(size_t &index) { return buf[index]; } + /** + * The number of bytes available to be consumed (i.e. 
meaningful bytes after + * current read cursor) + * @return The number of bytes available to be consumed + */ + inline size_t BytesAvailable() { return size_ - offset_; } + + /** + * Read the given number of bytes into destination, advancing cursor by that + * number + * @param bytes Number of bytes to read + * @param dest Desired memory location to read into + */ + inline void Read(size_t bytes, void *dest) { + std::copy(buf_.begin() + offset_, + buf_.begin() + offset_ + bytes, + reinterpret_cast(dest)); + offset_ += bytes; + } - // Get pointer to index location - inline uchar *GetPtr(size_t index) { return &buf[index]; } + /** + * Read a value of type T off of the buffer, advancing cursor by appropriate + * amount + * @tparam T type of value to read off. Preferably a primitive type + * @return the value of type T + */ + template + inline T ReadValue() { + T result; + Read(sizeof(result), &result); + return result; + } +}; + +/** + * A buffer specialized for write + */ +class WriteBuffer: public Buffer { + public: + /** + * Write as many bytes as possible using SSL write + * @param context SSL context to write out to + * @return return value of SSL write + */ + inline int WriteOutTo(SSL *context) { + ERR_clear_error(); + ssize_t bytes_written = + SSL_write(context, &buf_[offset_], size_ - offset_); + int err = SSL_get_error(context, bytes_written); + if (err == SSL_ERROR_NONE) offset_ += bytes_written; + return err; + } - inline ByteBuf::const_iterator Begin() { return std::begin(buf); } + /** + * Write as many bytes as possible using Posix write to fd + * @param fd File descriptor to write out to + * @return return value of Posix write + */ + inline int WriteOutTo(int fd) { + ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); + if (bytes_written > 0) offset_ += bytes_written; + return (int) bytes_written; + } - inline ByteBuf::const_iterator End() { return std::end(buf); } + /** + * The remaining capacity of this buffer. 
This value is equal to the + * maximum capacity minus the capacity already in use. + * @return Remaining capacity + */ + inline size_t RemainingCapacity() { + return Capacity() - size_; + } - inline size_t GetMaxSize() { return SOCKET_BUFFER_SIZE; } + /** + * @param bytes Desired number of bytes to write + * @return Whether the buffer can accommodate the number of bytes given + */ + inline bool HasSpaceFor(size_t bytes) { + return RemainingCapacity() >= bytes; + } - // Get the 4 bytes Big endian uint32 and convert it to little endian - size_t GetUInt32BigEndian(); + /** + * Append the desired range into current buffer + * @tparam InputIt iterator type + * @param first beginning of range + * @param len length of range + */ + template + inline void Append(InputIt first, size_t len) { + std::copy(first, first + len, std::begin(buf_) + size_); + size_ += len; + } - // Is the requested amount of data available from the current position in - // the reader buffer? - inline bool IsReadDataAvailable(size_t bytes) { - return ((buf_ptr - 1) + bytes < buf_size); + /** + * Append the given value into the current buffer + * @tparam T input type + * @param val value to write into buffer + */ + template + inline void Append(T val) { + Append(&val, sizeof(T)); } }; + class InputPacket { public: NetworkMessageType msg_type; // header @@ -139,8 +299,8 @@ struct OutputPacket { NetworkMessageType msg_type; // header bool single_type_pkt; // there would be only a pkt type being written to the - // buffer when this flag is true - bool skip_header_write; // whether we should write header to socket wbuf + // buffer when this flag is true + bool skip_header_write; // whether we should write header to soc ket wbuf size_t write_ptr; // cursor used to write packet content to socket wbuf // TODO could packet be reused? 
diff --git a/src/include/network/network_io_wrapper_factory.h b/src/include/network/network_io_wrapper_factory.h new file mode 100644 index 00000000000..314fc11a694 --- /dev/null +++ b/src/include/network/network_io_wrapper_factory.h @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// connection_handle_factory.h +// +// Identification: src/include/network/connection_handle_factory.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "network/peloton_server.h" +#include "network/network_io_wrappers.h" + +namespace peloton { +namespace network { + +/** + * @brief Factory class for constructing NetworkIoWrapper objects + * Each NetworkIoWrapper is associated with read and write buffers that are + * expensive to reallocate on the fly. Thus, instead of destroying these wrapper + * objects when they are out of scope, we save them until we can transfer their + * buffers to other wrappers. + */ +// TODO(Tianyu): Make reuse more fine-grained and adjustable +// Currently there is no limit on the number of wrappers we save. This means that +// we never deallocated wrappers unless we shut down. Obviously this will be a +// memory overhead if we had a lot of connections at one point and dropped down +// after a while. Relying on OS fd values for reuse also can backfire. +// It shouldn't be hard to keep a pool of buffers with a size limit instead of +// a bunch of old wrapper objects. +class NetworkIoWrapperFactory { + public: + static inline NetworkIoWrapperFactory &GetInstance() { + static NetworkIoWrapperFactory factory; + return factory; + } + + /** + * @brief Creates or re-purpose a NetworkIoWrapper object for new use. + * The returned value always uses Posix I/O methods unles explicitly converted. 
+ * @see NetworkIoWrapper for details + * @param conn_fd Client connection fd + * @return A new NetworkIoWrapper object + */ + std::shared_ptr NewNetworkIoWrapper(int conn_fd); + + /** + * @brief: process SSL handshake to generate valid SSL + * connection context for further communications + * @return FINISH when the SSL handshake failed + * PROCEED when the SSL handshake success + * NEED_DATA when the SSL handshake is partially done due to network + * latency + */ + Transition PerformSslHandshake(std::shared_ptr &io_wrapper); + private: + std::unordered_map> reusable_wrappers_; +}; +} +} diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h new file mode 100644 index 00000000000..6edef5ef7b4 --- /dev/null +++ b/src/include/network/network_io_wrappers.h @@ -0,0 +1,95 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// client_socket_wrapper.h +// +// Identification: src/include/network/client_socket_wrapper.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include "network/marshal.h" +#include "common/exception.h" +#include "common/utility.h" + +namespace peloton { +namespace network { + +/** + * A network io wrapper provides an interface for interacting with a client + * connection. + * + * Underneath the hood the wrapper buffers read and write, and can support posix + * and ssl reads and writes to the socket, depending on the concrete type at + * runtime. + * + * Because the buffers are large and expensive to allocate on fly, they are + * reused. Consequently, initialization of this class is handled by a factory + * class. 
@see NetworkIoWrapperFactory + */ +class NetworkIoWrapper { + friend class NetworkIoWrapperFactory; + public: + // TODO(Tianyu): Change and document after we refactor protocol handler + virtual Transition FillReadBuffer() = 0; + virtual Transition FlushWriteBuffer() = 0; + virtual Transition Close() = 0; + + inline int GetSocketFd() { return sock_fd_; } + Transition WritePacket(OutputPacket *pkt); + // TODO(Tianyu): Make these protected when protocol handler refactor is complete + NetworkIoWrapper(int sock_fd, + std::shared_ptr &rbuf, + std::shared_ptr &wbuf) + : sock_fd_(sock_fd), + conn_ssl_context_(nullptr), + rbuf_(std::move(rbuf)), + wbuf_(std::move(wbuf)) {} + // It is worth noting that because of the way we are reinterpret-casting between + // derived types, it is necessary that they share the same members. + int sock_fd_; + std::shared_ptr rbuf_; + std::shared_ptr wbuf_; + SSL *conn_ssl_context_; +}; + +/** + * A Network IoWrapper specialized for dealing with posix sockets. + */ +class PosixSocketIoWrapper : public NetworkIoWrapper { + public: + PosixSocketIoWrapper(int sock_fd, + std::shared_ptr rbuf, + std::shared_ptr wbuf); + + Transition FillReadBuffer() override; + Transition FlushWriteBuffer() override; + inline Transition Close() override { + peloton_close(sock_fd_); + return Transition::NONE; + } +}; + +/** + * NetworkIoWrapper specialized for dealing with ssl sockets. + */ +class SslSocketIoWrapper : public NetworkIoWrapper { + public: + // An SslSocketIoWrapper is always derived from a PosixSocketIoWrapper, + // as the handshake process happens over posix sockets. Use the method + // in NetworkIoWrapperFactory to get an SslSocketWrapper. 
+ SslSocketIoWrapper() = delete; + + Transition FillReadBuffer() override; + Transition FlushWriteBuffer() override; + Transition Close() override; +}; +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/network_state.h b/src/include/network/network_state.h index 9530f84bd7d..34c08aa9adc 100644 --- a/src/include/network/network_state.h +++ b/src/include/network/network_state.h @@ -19,19 +19,11 @@ namespace network { * @see ConnectionHandle::StateMachine */ enum class ConnState { - READ, // State that reads data from the network - WRITE, // State the writes data to the network - PROCESS, // State that runs the network protocol on received data - CLOSING, // State for closing the client connection - GET_RESULT, // State when triggered by worker thread that completes the task. - PROCESS_WRITE_SSL_HANDSHAKE, // State to flush out responses and doing (Real) - // SSL handshake -}; - -// TODO(tianyu): Convert use cases of this to just return Transition -enum class WriteState { - COMPLETE, // Write completed - NOT_READY, // Socket not ready to write + READ, // State that reads data from the network + WRITE, // State the writes data to the network + PROCESS, // State that runs the network protocol on received data + CLOSING, // State for closing the client connection + SSL_INIT, // State to flush out responses and doing (Real) SSL handshake }; /** @@ -43,12 +35,12 @@ enum class Transition { NONE, WAKEUP, PROCEED, - NEED_DATA, + NEED_READ, // TODO(tianyu) generalize this symbol, this is currently only used in process - GET_RESULT, - FINISH, - RETRY, + NEED_RESULT, + TERMINATE, NEED_SSL_HANDSHAKE, + NEED_WRITE }; } // namespace network } // namespace peloton diff --git a/src/include/network/postgres_protocol_handler.h b/src/include/network/postgres_protocol_handler.h index ef75f0a4cb1..960e2fdfd46 100644 --- a/src/include/network/postgres_protocol_handler.h +++ b/src/include/network/postgres_protocol_handler.h @@ -27,7 
+27,7 @@ #include "traffic_cop/traffic_cop.h" // Packet content macros -#define NULL_CONTENT_SIZE -1 +#define NULL_CONTENT_SIZE (-1) namespace peloton { @@ -51,7 +51,7 @@ class PostgresProtocolHandler : public ProtocolHandler { * to generate txn * @return @see ProcessResult */ - ProcessResult Process(Buffer &rbuf, size_t thread_id); + ProcessResult Process(ReadBuffer &rbuf, size_t thread_id); // Deserialize the parame types from packet static size_t ReadParamType(InputPacket *pkt, int num_params, @@ -85,7 +85,7 @@ class PostgresProtocolHandler : public ProtocolHandler { * (i.e. no type byte) * @return true if the parsing is complete */ - static bool ParseInputPacket(Buffer &rbuf, InputPacket &rpkt, + static bool ParseInputPacket(ReadBuffer &rbuf, InputPacket &rpkt, bool startup_format); /** @@ -95,14 +95,14 @@ class PostgresProtocolHandler : public ProtocolHandler { * @param rpkt the postgres rpkt we want to parse to * @return true if the parsing is complete */ - static bool ReadPacket(Buffer &rbuf, InputPacket &rpkt); + static bool ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt); /** * @brief Helper function to extract the header of a Postgres packet from the * read buffer * @see ParseInputPacket from param and return value */ - static bool ReadPacketHeader(Buffer &rbuf, InputPacket &rpkt, + static bool ReadPacketHeader(ReadBuffer &rbuf, InputPacket &rpkt, bool startup_format); //===--------------------------------------------------------------------===// @@ -196,8 +196,6 @@ class PostgresProtocolHandler : public ProtocolHandler { NetworkProtocolType protocol_type_; - // Manage standalone queries - // The result-column format code std::vector result_format_; diff --git a/src/include/network/protocol_handler.h b/src/include/network/protocol_handler.h index 1f8b4d283fe..0a7ccef3898 100644 --- a/src/include/network/protocol_handler.h +++ b/src/include/network/protocol_handler.h @@ -35,7 +35,7 @@ class ProtocolHandler { * Main switch case wrapper to process every 
packet apart from the startup * packet. Avoid flushing the response for extended protocols. */ - virtual ProcessResult Process(Buffer &rbuf, const size_t thread_id); + virtual ProcessResult Process(ReadBuffer &rbuf, size_t thread_id); virtual void Reset(); diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index e79564b5c4d..ffb9a0ac9b7 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -18,6 +18,7 @@ #include "network/peloton_server.h" #include "network/postgres_protocol_handler.h" #include "network/protocol_handler_factory.h" +#include "network/network_io_wrapper_factory.h" #include "settings/settings_manager.h" #include "common/utility.h" @@ -42,7 +43,7 @@ namespace network { * * state ::= DEFINE_STATE(ConnState) * transition list - * END_DEF + * END_STATE_DEF * * transition ::= * ON (Transition) SET_STATE_TO (ConnState) AND_INVOKE (ConnectionHandle @@ -50,6 +51,7 @@ namespace network { * * Note that all the symbols used must be defined in ConnState, Transition and * ClientSocketWrapper, respectively. 
+ * */ namespace { // Underneath the hood these macro is defining the static method @@ -63,11 +65,7 @@ namespace { #define DEFINE_STATE(s) \ case ConnState::s: { \ switch (t) { -#define END_DEF \ - default: \ - throw std::runtime_error("undefined transition"); \ - } \ - } + #define ON(t) \ case Transition::t: \ return @@ -78,51 +76,70 @@ namespace { ([](ConnectionHandle & w) { return w.m(); }) \ } \ ; -#define AND_WAIT \ - ([](ConnectionHandle &) { return Transition::NONE; }) \ - } \ +#define AND_WAIT_ON_READ \ + ([](ConnectionHandle &w) { w.UpdateEventFlags(EV_READ | EV_PERSIST); \ + return Transition::NONE; }) \ + } \ ; +#define AND_WAIT_ON_WRITE \ + ([](ConnectionHandle &w) { w.UpdateEventFlags(EV_WRITE | EV_PERSIST); \ + return Transition::NONE; }) \ + } \ + ; +#define AND_WAIT_ON_PELOTON \ + ([](ConnectionHandle &w) { w.StopReceivingNetworkEvent(); \ + return Transition::NONE; }) \ + } \ + ; +#define END_DEF \ + default: \ + throw std::runtime_error("undefined transition"); \ + } \ + } + +#define END_STATE_DEF \ + ON(TERMINATE) SET_STATE_TO(CLOSING) AND_INVOKE(CloseConnection) \ + END_DEF } // clang-format off DEF_TRANSITION_GRAPH - DEFINE_STATE(READ) - ON(WAKEUP) SET_STATE_TO(READ) AND_INVOKE(FillReadBuffer) - ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) - ON(NEED_DATA) SET_STATE_TO(READ) AND_WAIT - ON(FINISH) SET_STATE_TO(CLOSING) AND_INVOKE(CloseSocket) - END_DEF - - DEFINE_STATE(PROCESS_WRITE_SSL_HANDSHAKE) - ON(WAKEUP) SET_STATE_TO(PROCESS_WRITE_SSL_HANDSHAKE) - AND_INVOKE(ProcessWrite_SSLHandshake) - ON(NEED_DATA) SET_STATE_TO(PROCESS_WRITE_SSL_HANDSHAKE) AND_WAIT - ON(FINISH) SET_STATE_TO(CLOSING) AND_INVOKE(CloseSocket) - ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) - END_DEF - - DEFINE_STATE(PROCESS) - ON(PROCEED) SET_STATE_TO(WRITE) AND_INVOKE(ProcessWrite) - ON(NEED_DATA) SET_STATE_TO(READ) AND_INVOKE(FillReadBuffer) - ON(GET_RESULT) SET_STATE_TO(GET_RESULT) AND_WAIT - ON(FINISH) SET_STATE_TO(CLOSING) AND_INVOKE(CloseSocket) - 
ON(NEED_SSL_HANDSHAKE) SET_STATE_TO(PROCESS_WRITE_SSL_HANDSHAKE) - AND_INVOKE(ProcessWrite_SSLHandshake) - END_DEF - - DEFINE_STATE(WRITE) - ON(WAKEUP) SET_STATE_TO(WRITE) AND_INVOKE(ProcessWrite) - ON(NEED_DATA) SET_STATE_TO(PROCESS) AND_INVOKE(Process) - ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) - END_DEF - - DEFINE_STATE(GET_RESULT) - ON(WAKEUP) SET_STATE_TO(GET_RESULT) AND_INVOKE(GetResult) - ON(PROCEED) SET_STATE_TO(WRITE) AND_INVOKE(ProcessWrite) - END_DEF - + DEFINE_STATE(READ) + ON(WAKEUP) SET_STATE_TO(READ) AND_INVOKE(TryRead) + ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) + ON(NEED_READ) SET_STATE_TO(READ) AND_WAIT_ON_READ + // This case happens only when we use SSL and are blocked on a write + // during handshake. From peloton's perspective we are still waiting + // for reads. + ON(NEED_WRITE) SET_STATE_TO(READ) AND_WAIT_ON_WRITE + END_STATE_DEF + + DEFINE_STATE(SSL_INIT) + ON(WAKEUP) SET_STATE_TO(SSL_INIT) AND_INVOKE(TrySslHandshake) + ON(NEED_READ) SET_STATE_TO(SSL_INIT) AND_WAIT_ON_READ + ON(NEED_WRITE) SET_STATE_TO(SSL_INIT) AND_WAIT_ON_WRITE + ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) + END_STATE_DEF + + DEFINE_STATE(PROCESS) + ON(WAKEUP) SET_STATE_TO(PROCESS) AND_INVOKE(GetResult) + ON(PROCEED) SET_STATE_TO(WRITE) AND_INVOKE(TryWrite) + ON(NEED_READ) SET_STATE_TO(READ) AND_INVOKE(TryRead) + // Client connections are ignored while we wait on peloton + // to execute the query + ON(NEED_RESULT) SET_STATE_TO(PROCESS) AND_WAIT_ON_PELOTON + ON(NEED_SSL_HANDSHAKE) SET_STATE_TO(SSL_INIT) AND_INVOKE(TrySslHandshake) + END_STATE_DEF + + DEFINE_STATE(WRITE) + ON(WAKEUP) SET_STATE_TO(WRITE) AND_INVOKE(TryWrite) + // This happens when doing ssl-rehandshake with client + ON(NEED_READ) SET_STATE_TO(WRITE) AND_WAIT_ON_READ + ON(NEED_WRITE) SET_STATE_TO(WRITE) AND_WAIT_ON_WRITE + ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) + END_STATE_DEF END_DEF - // clang-format on +// clang-format on void 
ConnectionHandle::StateMachine::Accept(Transition action, ConnectionHandle &connection) { @@ -134,566 +151,83 @@ void ConnectionHandle::StateMachine::Accept(Transition action, next = result.second(connection); } catch (NetworkProcessException &e) { LOG_ERROR("%s\n", e.what()); - connection.CloseSocket(); + connection.CloseConnection(); return; } } } -ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler, - std::shared_ptr rbuf, - std::shared_ptr wbuf) - : sock_fd_(sock_fd), - handler_(handler), - protocol_handler_(nullptr), - rbuf_(std::move(rbuf)), - wbuf_(std::move(wbuf)) { - SetNonBlocking(sock_fd_); - SetTCPNoDelay(sock_fd_); - - network_event = handler->RegisterEvent( - sock_fd_, EV_READ | EV_PERSIST, - METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - workpool_event = handler->RegisterManualEvent( - METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - - // TODO(Tianyu): should put the initialization else where.. check correctness - // first. - traffic_cop_.SetTaskCallback([](void *arg) { - struct event *event = static_cast(arg); - event_active(event, EV_WRITE, 0); - }, workpool_event); -} - -void ConnectionHandle::UpdateEventFlags(short flags) { - // TODO(tianyu): The original network code seems to do this as an - // optimization. I am leaving this out until we get numbers - // handler->UpdateEvent(network_event, sock_fd_, flags, - // METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - - if (flags == curr_event_flag_) return; - - handler_->UnregisterEvent(network_event); - network_event = handler_->RegisterEvent( - sock_fd_, flags, METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - - curr_event_flag_ = flags; +ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) + : conn_handler_(handler) { + // We will always handle connections using posix until (potentially) first SSL + // handshake. 
+ io_wrapper_ = std::move(NetworkIoWrapperFactory::GetInstance() + .NewNetworkIoWrapper(sock_fd)); } -WriteState ConnectionHandle::WritePackets() { - // iterate through all the packets +Transition ConnectionHandle::TryWrite() { for (; next_response_ < protocol_handler_->responses_.size(); - next_response_++) { - auto pkt = protocol_handler_->responses_[next_response_].get(); - LOG_TRACE("To send packet with type: %c, len %lu", - static_cast(pkt->msg_type), pkt->len); - // write is not ready during write. transit to WRITE - auto result = BufferWriteBytesHeader(pkt); - if (result == WriteState::NOT_READY) return result; - result = BufferWriteBytesContent(pkt); - if (result == WriteState::NOT_READY) return result; + next_response_++) { + auto result = io_wrapper_->WritePacket( + protocol_handler_->responses_[next_response_].get()); + if (result != Transition::PROCEED) return result; } - - // Done writing all packets. clear packets protocol_handler_->responses_.clear(); next_response_ = 0; - - if (protocol_handler_->GetFlushFlag()) { - return FlushWriteBuffer(); - } - - // we have flushed, disable force flush now + if (protocol_handler_->GetFlushFlag()) + return io_wrapper_->FlushWriteBuffer(); protocol_handler_->SetFlushFlag(false); - - return WriteState::COMPLETE; -} - -Transition ConnectionHandle::FillReadBuffer() { - // This could be changed by SSL_ERROR_WANT_WRITE - // When we reenter, we need to recover. - UpdateEventFlags(EV_READ | EV_PERSIST); - - Transition result = Transition::NEED_DATA; - ssize_t bytes_read = 0; - bool done = false; - - // reset buffer if all the contents have been read - if (rbuf_->buf_ptr == rbuf_->buf_size) rbuf_->Reset(); - - // buf_ptr shouldn't overflow - PELOTON_ASSERT(rbuf_->buf_ptr <= rbuf_->buf_size); - - /* Do we have leftover data and are we at the end of the buffer? 
- * Move the data to the head of the buffer and clear out all the old data - * Note: The assumption here is that all the packets/headers till - * rbuf_.buf_ptr have been fully processed - */ - if (rbuf_->buf_ptr < rbuf_->buf_size && - rbuf_->buf_size == rbuf_->GetMaxSize()) { - auto unprocessed_len = rbuf_->buf_size - rbuf_->buf_ptr; - // Move this data to the head of rbuf_1 - std::memmove(rbuf_->GetPtr(0), rbuf_->GetPtr(rbuf_->buf_ptr), - unprocessed_len); - // update pointers - rbuf_->buf_ptr = 0; - rbuf_->buf_size = unprocessed_len; - } - - // return explicitly - while (!done) { - if (rbuf_->buf_size == rbuf_->GetMaxSize()) { - // we have filled the whole buffer, exit loop - done = true; - } else { - // try to fill the available space in the buffer - // if the connection is a SSL connection, we use SSL_read, otherwise - // we use general read function - if (conn_SSL_context != nullptr) { - ERR_clear_error(); - bytes_read = SSL_read(conn_SSL_context, rbuf_->GetPtr(rbuf_->buf_size), - rbuf_->GetMaxSize() - rbuf_->buf_size); - LOG_TRACE("SSL read successfully"); - int err = SSL_get_error(conn_SSL_context, bytes_read); - unsigned long ecode = - (err != SSL_ERROR_NONE || bytes_read < 0) ? ERR_get_error() : 0; - switch (err) { - case SSL_ERROR_NONE: { - // If successfully received, update buffer ptr and read status - // keep reading till no data is available or the buffer becomes full - rbuf_->buf_size += bytes_read; - result = Transition::PROCEED; - break; - } - - case SSL_ERROR_ZERO_RETURN: { - done = true; - result = Transition::FINISH; - break; - } - // The SSL packet is partially loaded to the SSL buffer only, - // More data is required in order to decode the whole packet. - case SSL_ERROR_WANT_READ: { - LOG_TRACE("SSL packet partially loaded to SSL buffer"); - done = true; - break; - } - // It happens when we're trying to rehandshake and we block on a write - // during the handshake. 
We need to wait on the socket to be writable - case SSL_ERROR_WANT_WRITE: { - LOG_TRACE("Rehandshake during write, block until writable"); - UpdateEventFlags(EV_WRITE | EV_PERSIST); - return Transition::NEED_DATA; - } - case SSL_ERROR_SYSCALL: { - // if interrupted, try again - if (errno == EINTR) { - LOG_INFO("Error SSL Reading: EINTR"); - break; - } - } - default: { - throw NetworkProcessException("SSL read error: %d, error code: " + - std::to_string(err) + " error code:" + - std::to_string(ecode)); - } - } - } else { - bytes_read = read(sock_fd_, rbuf_->GetPtr(rbuf_->buf_size), - rbuf_->GetMaxSize() - rbuf_->buf_size); - LOG_TRACE("When filling read buffer, read %ld bytes", bytes_read); - - if (bytes_read > 0) { - // read succeeded, update buffer size - rbuf_->buf_size += bytes_read; - result = Transition::PROCEED; - } else if (bytes_read == 0) { - return Transition::FINISH; - } else if (bytes_read < 0) { - // Nothing in the network pipe now - if (errno == EAGAIN || errno == EWOULDBLOCK) { - // return whatever results we have - done = true; - } else if (errno == EINTR) { - // interrupts are ok, try again - continue; - } else { - // some other error occured - LOG_ERROR("Error writing: %s", strerror(errno)); - throw NetworkProcessException("Error when filling read buffer " + - std::to_string(errno)); - } - } - } - } - } - return result; -} - -WriteState ConnectionHandle::FlushWriteBuffer() { - // This could be changed by unfinished write - // When we reenter, we need to recover it to read - UpdateEventFlags(EV_READ | EV_PERSIST); - - ssize_t written_bytes = 0; - // while we still have outstanding bytes to write - if (conn_SSL_context != nullptr) { - while (wbuf_->buf_size > 0) { - LOG_TRACE("SSL_write flush"); - ERR_clear_error(); - written_bytes = SSL_write( - conn_SSL_context, &wbuf_->buf[wbuf_->buf_flush_ptr], wbuf_->buf_size); - int err = SSL_get_error(conn_SSL_context, written_bytes); - unsigned long ecode = - (err != SSL_ERROR_NONE || written_bytes < 0) ? 
ERR_get_error() : 0; - switch (err) { - case SSL_ERROR_NONE: { - wbuf_->buf_flush_ptr += written_bytes; - wbuf_->buf_size -= written_bytes; - break; - } - case SSL_ERROR_WANT_WRITE: { - // The kernel will flush the network buffer automatically. What we - // need to do is to call SSL_write() again when the buffer becomes - // availble to write again(notified by Libevent). - UpdateEventFlags(EV_WRITE | EV_PERSIST); - LOG_TRACE("Flush write buffer, want write, not ready"); - return WriteState::NOT_READY; - } - case SSL_ERROR_WANT_READ: { - // It happens when doing rehandshake with client. - LOG_TRACE("Flush write buffer, want read, not ready"); - return WriteState::NOT_READY; - } - case SSL_ERROR_SYSCALL: { - // If interrupted, try again. - if (errno == EINTR) { - LOG_TRACE("Flush write buffer, eintr"); - break; - } - } - default: { - LOG_ERROR("SSL write error: %d, error code: %lu", err, ecode); - throw NetworkProcessException("SSL write error"); - } - } - } - } else { - while (wbuf_->buf_size > 0) { - written_bytes = 0; - while (written_bytes <= 0) { - LOG_TRACE("Normal write flush"); - written_bytes = - write(sock_fd_, &wbuf_->buf[wbuf_->buf_flush_ptr], wbuf_->buf_size); - // Write failed - if (written_bytes < 0) { - if (errno == EINTR) { - // interrupts are ok, try again - written_bytes = 0; - continue; - // Write would have blocked if the socket was - // in blocking mode. Wait till it's readable - } else if (errno == EAGAIN || errno == EWOULDBLOCK) { - // Listen for socket being enabled for write - UpdateEventFlags(EV_WRITE | EV_PERSIST); - // We should go to CONN_WRITE state - LOG_TRACE("WRITE NOT READY"); - return WriteState::NOT_READY; - } else { - // fatal errors - LOG_ERROR("Error writing: %s", strerror(errno)); - throw NetworkProcessException("Fatal error during write"); - } - } - - // weird edge case? 
- if (written_bytes == 0 && wbuf_->buf_size != 0) { - LOG_TRACE("Not all data is written"); - continue; - } - } - - // update book keeping - wbuf_->buf_flush_ptr += written_bytes; - wbuf_->buf_size -= written_bytes; - } - } - // buffer is empty - wbuf_->Reset(); - - // we are ok - return WriteState::COMPLETE; -} - -std::string ConnectionHandle::WriteBufferToString() { -#ifdef LOG_TRACE_ENABLED - LOG_TRACE("Write Buffer:"); - - for (size_t i = 0; i < wbuf_->buf_size; ++i) { - LOG_TRACE("%u", wbuf_->buf[i]); - } -#endif - - return std::string(wbuf_->buf.begin(), wbuf_->buf.end()); -} - -// TODO (Tianyi) Make this to be protocol specific -// Writes a packet's header (type, size) into the write buffer. -// Return false when the socket is not ready for write -WriteState ConnectionHandle::BufferWriteBytesHeader(OutputPacket *pkt) { - // If we should not write - if (pkt->skip_header_write) { - return WriteState::COMPLETE; - } - - size_t len = pkt->len; - unsigned char type = static_cast(pkt->msg_type); - int len_nb; // length in network byte order - - // check if we have enough space in the buffer - if (wbuf_->GetMaxSize() - wbuf_->buf_ptr < 1 + sizeof(int32_t)) { - // buffer needs to be flushed before adding header - auto result = FlushWriteBuffer(); - if (result == WriteState::NOT_READY) { - // Socket is not ready for write - return result; - } - } - - // assuming wbuf is now large enough to fit type and size fields in one go - if (type != 0) { - // type shouldn't be ignored - wbuf_->buf[wbuf_->buf_ptr++] = type; - } - - if (!pkt->single_type_pkt) { - // make len include its field size as well - len_nb = htonl(len + sizeof(int32_t)); - - // append the bytes of this integer in network-byte order - std::copy(reinterpret_cast(&len_nb), - reinterpret_cast(&len_nb) + 4, - std::begin(wbuf_->buf) + wbuf_->buf_ptr); - // move the write buffer pointer and update size of the socket buffer - wbuf_->buf_ptr += sizeof(int32_t); - } - - wbuf_->buf_size = wbuf_->buf_ptr; - - // Header 
is written to socket buf. No need to write it in the future - pkt->skip_header_write = true; - return WriteState::COMPLETE; -} - -// Writes a packet's content into the write buffer -// Return false when the socket is not ready for write -WriteState ConnectionHandle::BufferWriteBytesContent(OutputPacket *pkt) { - // the packet content to write - ByteBuf &pkt_buf = pkt->buf; - // the length of remaining content to write - size_t len = pkt->len; - // window is the size of remaining space in socket's wbuf - size_t window = 0; - - // fill the contents - while (len != 0) { - // calculate the remaining space in wbuf - window = wbuf_->GetMaxSize() - wbuf_->buf_ptr; - if (len <= window) { - // contents fit in the window, range copy "len" bytes - std::copy(std::begin(pkt_buf) + pkt->write_ptr, - std::begin(pkt_buf) + pkt->write_ptr + len, - std::begin(wbuf_->buf) + wbuf_->buf_ptr); - - // Move the cursor and update size of socket buffer - wbuf_->buf_ptr += len; - wbuf_->buf_size = wbuf_->buf_ptr; - LOG_TRACE("Content fit in window. Write content successful"); - return WriteState::COMPLETE; - } else { - // contents longer than socket buffer size, fill up the socket buffer - // with "window" bytes - - std::copy(std::begin(pkt_buf) + pkt->write_ptr, - std::begin(pkt_buf) + pkt->write_ptr + window, - std::begin(wbuf_->buf) + wbuf_->buf_ptr); - - // move the packet's cursor - pkt->write_ptr += window; - len -= window; - // Now the wbuf is full - wbuf_->buf_size = wbuf_->GetMaxSize(); - - LOG_TRACE("Content doesn't fit in window. 
Try flushing"); - auto result = FlushWriteBuffer(); - // flush before write the remaining content - if (result == WriteState::NOT_READY) { - // need to retry or close connection - return result; - } - } - } - return WriteState::COMPLETE; -} - -Transition ConnectionHandle::CloseSocket() { - LOG_DEBUG("Attempt to close the connection %d", sock_fd_); - // Remove listening event - handler_->UnregisterEvent(network_event); - handler_->UnregisterEvent(workpool_event); - - if (conn_SSL_context != nullptr) { - int shutdown_ret = 0; - ERR_clear_error(); - shutdown_ret = SSL_shutdown(conn_SSL_context); - if (shutdown_ret != 0) { - int err = SSL_get_error(conn_SSL_context, shutdown_ret); - if (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ) { - LOG_TRACE("SSL shutdown is not finished yet"); - return Transition::NEED_DATA; - } else { - LOG_ERROR("Error shutting down ssl session, err: %d", err); - } - } - SSL_free(conn_SSL_context); - conn_SSL_context = nullptr; - } - - peloton_close(sock_fd_); - return Transition::NONE; - -} - -Transition ConnectionHandle::ProcessWrite_SSLHandshake() { - // Flush out all the response first - if (HasResponse()) { - auto write_ret = ProcessWrite(); - if (write_ret != Transition::PROCEED) { - return write_ret; - } - } - - return SSLHandshake(); -} - -Transition ConnectionHandle::SSLHandshake() { - if (conn_SSL_context == nullptr) { - conn_SSL_context = SSL_new(PelotonServer::ssl_context); - if (conn_SSL_context == nullptr) { - throw NetworkProcessException("ssl context for conn failed"); - } - SSL_set_session_id_context(conn_SSL_context, nullptr, 0); - if (SSL_set_fd(conn_SSL_context, sock_fd_) == 0) { - LOG_ERROR("Failed to set SSL fd"); - return Transition::FINISH; - } - } - - // TODO(Yuchen): post-connection verification? 
- // clear current thread's error queue before any OpenSSL call - ERR_clear_error(); - int ssl_accept_ret = SSL_accept(conn_SSL_context); - if (ssl_accept_ret > 0) return Transition::PROCEED; - - int err = SSL_get_error(conn_SSL_context, ssl_accept_ret); - int ecode = ERR_get_error(); - char error_string[120]; - ERR_error_string(ecode, error_string); - switch (err) { - case SSL_ERROR_SSL: { - if (ecode < 0) { - LOG_ERROR("Could not accept SSL connection"); - } else { - LOG_ERROR( - "Could not accept SSL connection: EOF detected, " - "ssl_error_ssl, %s", - error_string); - } - return Transition::FINISH; - } - case SSL_ERROR_ZERO_RETURN: { - LOG_ERROR( - "Could not accept SSL connection: EOF detected, " - "ssl_error_zero_return, %s", - error_string); - return Transition::FINISH; - } - case SSL_ERROR_SYSCALL: { - if (ecode < 0) { - LOG_ERROR("Could not accept SSL connection, %s", error_string); - } else { - LOG_ERROR( - "Could not accept SSL connection: EOF detected, " - "ssl_sys_call, %s", - error_string); - } - return Transition::FINISH; - } - case SSL_ERROR_WANT_READ: { - UpdateEventFlags(EV_READ | EV_PERSIST); - return Transition::NEED_DATA; - } - case SSL_ERROR_WANT_WRITE: { - UpdateEventFlags(EV_WRITE | EV_PERSIST); - return Transition::NEED_DATA; - } - default: { - LOG_ERROR("Unrecognized SSL error code: %d", err); - return Transition::FINISH; - } - } + return Transition::PROCEED; } Transition ConnectionHandle::Process() { - if (protocol_handler_ == nullptr) { + // TODO(Tianyu): Just use Transition instead of ProcessResult, this looks + // like a 1 - 1 mapping between the two types. 
+ if (protocol_handler_ == nullptr) // TODO(Tianyi) Check the rbuf here before we create one if we have // another protocol handler protocol_handler_ = ProtocolHandlerFactory::CreateProtocolHandler( - ProtocolHandlerType::Postgres, &traffic_cop_); - } + ProtocolHandlerType::Postgres, &tcop_); ProcessResult status = - protocol_handler_->Process(*rbuf_, (size_t)handler_->Id()); + protocol_handler_->Process(*(io_wrapper_->rbuf_), + (size_t) conn_handler_->Id()); switch (status) { - case ProcessResult::MORE_DATA_REQUIRED: - return Transition::NEED_DATA; - case ProcessResult::COMPLETE: - return Transition::PROCEED; - case ProcessResult::PROCESSING: - EventUtil::EventDel(network_event); - LOG_TRACE("ProcessResult: queueing"); - return Transition::GET_RESULT; - case ProcessResult::TERMINATE: - return Transition::FINISH; - case ProcessResult::NEED_SSL_HANDSHAKE: - return Transition::NEED_SSL_HANDSHAKE; + case ProcessResult::MORE_DATA_REQUIRED: return Transition::NEED_READ; + case ProcessResult::COMPLETE: return Transition::PROCEED; + case ProcessResult::PROCESSING: return Transition::NEED_RESULT; + case ProcessResult::TERMINATE: return Transition::TERMINATE; + case ProcessResult::NEED_SSL_HANDSHAKE: return Transition::NEED_SSL_HANDSHAKE; default: - LOG_ERROR("Unknown process result"); + LOG_ERROR("Unknown process result"); throw NetworkProcessException("Unknown process result"); } } -Transition ConnectionHandle::ProcessWrite() { - // TODO(tianyu): Should convert to use Transition in the top level method - switch (WritePackets()) { - case WriteState::COMPLETE: - UpdateEventFlags(EV_READ | EV_PERSIST); - return Transition::PROCEED; - case WriteState::NOT_READY: - return Transition::NONE; - } - throw NetworkProcessException("Unexpected write state"); -} - Transition ConnectionHandle::GetResult() { - // TODO(tianyu) We probably can collapse this state with some other state. 
- if (event_add(network_event, nullptr) < 0) { - LOG_ERROR("Failed to add event"); - PELOTON_ASSERT(false); - } + EventUtil::EventAdd(network_event_, nullptr); protocol_handler_->GetResult(); - traffic_cop_.SetQueuing(false); + tcop_.SetQueuing(false); return Transition::PROCEED; } + +Transition ConnectionHandle::TrySslHandshake() { + // Flush out all the response first + if (HasResponse()) { + auto write_ret = TryWrite(); + if (write_ret != Transition::PROCEED) return write_ret; + } + return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake(io_wrapper_); +} + +Transition ConnectionHandle::CloseConnection() { + LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); + // Remove listening event + conn_handler_->UnregisterEvent(network_event_); + io_wrapper_->Close(); + return Transition::NONE; + +} } // namespace network } // namespace peloton diff --git a/src/network/connection_handler_task.cpp b/src/network/connection_handler_task.cpp index 267c1e0ffd9..4f1f926928c 100644 --- a/src/network/connection_handler_task.cpp +++ b/src/network/connection_handler_task.cpp @@ -12,7 +12,7 @@ #include "network/connection_handler_task.h" #include "network/connection_handle.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" namespace peloton { namespace network { @@ -41,22 +41,26 @@ void ConnectionHandlerTask::Notify(int conn_fd) { void ConnectionHandlerTask::HandleDispatch(int new_conn_recv_fd, short) { // buffer used to receive messages from the main thread char client_fd[sizeof(int)]; - std::shared_ptr conn; size_t bytes_read = 0; // read fully while (bytes_read < sizeof(int)) { ssize_t result = read(new_conn_recv_fd, - client_fd + bytes_read, - sizeof(int) - bytes_read); + client_fd + bytes_read, + sizeof(int) - bytes_read); if (result < 0) { LOG_ERROR("Error when reading from dispatch"); } bytes_read += (size_t) result; } - conn = ConnectionHandleFactory::GetInstance().GetConnectionHandle( - 
*((int *) client_fd), this); + // Smart pointers are not used here because libevent does not take smart pointers. + // During the life time of this object, the pointer to it will be maintained + // by libevent rather than by our own code. The object will have to be cleaned + // up by one of its methods (i.e. we call a method with "delete this" and have + // the object commit suicide from libevent. ) + (new ConnectionHandle(*reinterpret_cast(client_fd), + this))->RegisterToReceiveEvents(); } } // namespace network diff --git a/src/network/marshal.cpp b/src/network/marshal.cpp index 974d292330d..9df869b3604 100644 --- a/src/network/marshal.cpp +++ b/src/network/marshal.cpp @@ -27,15 +27,6 @@ inline void CheckOverflow(UNUSED_ATTRIBUTE InputPacket *rpkt, PELOTON_ASSERT(rpkt->ptr + size - 1 < rpkt->len); } -size_t Buffer::GetUInt32BigEndian() { - size_t num = 0; - // directly converts from network byte order to little-endian - for (size_t i = buf_ptr; i < buf_ptr + sizeof(uint32_t); i++) { - num = (num << 8) | GetByte(i); - } - return num; -} - int PacketGetInt(InputPacket *rpkt, uchar base) { int value = 0; auto begin = rpkt->Begin() + rpkt->ptr; diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp new file mode 100644 index 00000000000..737e0831e7a --- /dev/null +++ b/src/network/network_io_wrapper_factory.cpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// connection_handle_factory.cpp +// +// Identification: src/network/connection_handle_factory.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "network/network_io_wrapper_factory.h" + +namespace peloton { +namespace network { +std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( + int conn_fd) { + auto it = reusable_wrappers_.find(conn_fd); + if (it == 
reusable_wrappers_.end()) { + // No reusable wrappers + auto wrapper = std::make_shared(conn_fd, + std::make_shared< + ReadBuffer>(), + std::make_shared< + WriteBuffer>()); + reusable_wrappers_[conn_fd] = + std::static_pointer_cast( + wrapper); + return wrapper; + } + + // Construct new wrapper by reusing buffers from the old one. + // The old one will be deallocated as we replace the last reference to it + // in the reusable_wrappers_ map + auto reused_wrapper = it->second; + reused_wrapper->rbuf_->Reset(); + reused_wrapper->wbuf_->Reset(); + reused_wrapper->sock_fd_ = conn_fd; + reused_wrapper->conn_ssl_context_ = nullptr; + // It is not necessary to have an explicit cast here because the reused + // wrapper always use Posix methods, as we never update their type in the + // reusable wrappers map. + auto new_wrapper = new std::shared_ptr( + reinterpret_cast(reused_wrapper.get())); + return reused_wrapper; +} + +Transition NetworkIoWrapperFactory::PerformSslHandshake(std::shared_ptr< + NetworkIoWrapper> &io_wrapper) { + if (io_wrapper->conn_ssl_context_ == nullptr) { + // Initial handshake, the incoming type is a posix socket wrapper + auto *context = io_wrapper->conn_ssl_context_ = + SSL_new(PelotonServer::ssl_context); + // TODO(Tianyu): Is it the right thing here to throw exceptions? + if (context == nullptr) + throw NetworkProcessException("ssl context for conn failed"); + SSL_set_session_id_context(context, nullptr, 0); + if (SSL_set_fd(context, io_wrapper->sock_fd_) == 0) + throw NetworkProcessException("Failed to set ssl fd"); + + // ssl handshake is done, need to use new methods for the original wrappers; + // We do not update the type in the reusable wrappers map because it is not + // relevant. + io_wrapper = new std::shared_ptr( + reinterpret_cast(io_wrapper.get())); + } + + // The wrapper already uses SSL methods. + // Yuchen: "Post-connection verification?" 
+ auto *context = io_wrapper->conn_ssl_context_; + ERR_clear_error(); + int ssl_accept_ret = SSL_accept(context); + if (ssl_accept_ret > 0) return Transition::PROCEED; + + int err = SSL_get_error(context, ssl_accept_ret); + switch (err) { + case SSL_ERROR_WANT_READ:return Transition::NEED_READ; + case SSL_ERROR_WANT_WRITE:return Transition::NEED_WRITE; +// case SSL_ERROR_SSL: +// case SSL_ERROR_ZERO_RETURN: +// case SSL_ERROR_SYSCALL: + default:LOG_ERROR("SSL Error, error code %d", err); + return Transition::TERMINATE; + } +} +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp new file mode 100644 index 00000000000..ba7070449c7 --- /dev/null +++ b/src/network/network_io_wrappers.cpp @@ -0,0 +1,187 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// client_socket_wrapper.cpp +// +// Identification: src/network/client_socket_wrapper.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "network/network_io_wrappers.h" +#include +#include +#include +#include +#include "network/peloton_server.h" + +namespace peloton { +namespace network { +Transition NetworkIoWrapper::WritePacket(OutputPacket *pkt) { + // Write Packet Header + if (pkt->skip_header_write) return Transition::PROCEED; + + if (!wbuf_->HasSpaceFor(1 + sizeof(int32_t))) { + auto result = FlushWriteBuffer(); + if (FlushWriteBuffer() != Transition::PROCEED) + // Unable to flush buffer, socket presumably not ready for write + return result; + } + + wbuf_->Append(static_cast(pkt->msg_type)); + if (!pkt->single_type_pkt) + // Need to convert bytes to network order + wbuf_->Append(htonl(pkt->len + sizeof(int32_t))); + pkt->skip_header_write = true; + + // Write Packet Content + while (pkt->len != 0) { + if 
(wbuf_->HasSpaceFor(pkt->len)) + wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, pkt->len); + else { + auto write_size = wbuf_->RemainingCapacity(); + wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, write_size); + pkt->write_ptr += write_size; + auto result = FlushWriteBuffer(); + if (FlushWriteBuffer() != Transition::PROCEED) + // Unable to flush buffer, socket presumably not ready for write + return result; + } + } + return Transition::PROCEED; +} + +PosixSocketIoWrapper::PosixSocketIoWrapper(int sock_fd, + std::shared_ptr rbuf, + std::shared_ptr wbuf) + : NetworkIoWrapper(sock_fd, rbuf, wbuf) { + // Set Non Blocking + auto flags = fcntl(sock_fd_, F_GETFL); + flags |= O_NONBLOCK; + if (fcntl(sock_fd_, F_SETFL, flags) < 0) { + LOG_ERROR("Failed to set non-blocking socket"); + } + // Set TCP No Delay + int one = 1; + setsockopt(sock_fd_, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); +} + +Transition PosixSocketIoWrapper::FillReadBuffer() { + if (!rbuf_->HasMore()) rbuf_->Reset(); + if (rbuf_->HasMore() && rbuf_->Full()) rbuf_->MoveContentToHead(); + Transition result = Transition::NEED_READ; + // Normal mode + while (!rbuf_->Full()) { + auto bytes_read = rbuf_->FillBufferFrom(sock_fd_); + if (bytes_read > 0) + result = Transition::PROCEED; + else if (bytes_read == 0) + return Transition::TERMINATE; + else + switch (errno) { + case EAGAIN: + // Equal to EWOULDBLOCK + return result; + case EINTR: + continue; + default:LOG_ERROR("Error writing: %s", strerror(errno)); + throw NetworkProcessException("Error when filling read buffer " + + std::to_string(errno)); + } + } + return result; +} + +Transition PosixSocketIoWrapper::FlushWriteBuffer() { + while (wbuf_->HasMore()) { + auto bytes_written = wbuf_->WriteOutTo(sock_fd_); + if (bytes_written < 0) + switch (errno) { + case EINTR: continue; + case EAGAIN: return Transition::NEED_WRITE; + default:LOG_ERROR("Error writing: %s", strerror(errno)); + throw NetworkProcessException("Fatal error during write"); + } + 
} + wbuf_->Reset(); + return Transition::PROCEED; +} + +Transition SslSocketIoWrapper::FillReadBuffer() { + if (!rbuf_->HasMore()) rbuf_->Reset(); + if (rbuf_->HasMore() && rbuf_->Full()) rbuf_->MoveContentToHead(); + Transition result = Transition::NEED_READ; + while (!rbuf_->Full()) { + auto ret = rbuf_->FillBufferFrom(conn_ssl_context_); + switch (ret) { + case SSL_ERROR_NONE:result = Transition::PROCEED; + break; + case SSL_ERROR_ZERO_RETURN: return Transition::TERMINATE; + // The SSL packet is partially loaded to the SSL buffer only, + // More data is required in order to decode the wh`ole packet. + case SSL_ERROR_WANT_READ: return Transition::NEED_READ; + case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; + case SSL_ERROR_SYSCALL: + if (errno == EINTR) { + LOG_INFO("Error SSL Reading: EINTR"); + break; + } + // Intentional fallthrough + default: + throw NetworkProcessException( + "SSL read error: " + std::to_string(ret)); + } + } + return result; +} + +Transition SslSocketIoWrapper::FlushWriteBuffer() { + while (wbuf_->Full()) { + auto ret = wbuf_->WriteOutTo(conn_ssl_context_); + switch (ret) { + case SSL_ERROR_NONE: break; + case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; + case SSL_ERROR_WANT_READ: return Transition::NEED_READ; + case SSL_ERROR_SYSCALL: + // If interrupted, try again. 
+ if (errno == EINTR) { + LOG_TRACE("Flush write buffer, eintr"); + break; + } + // Intentional Fallthrough + default:LOG_ERROR("SSL write error: %d, error code: %lu", + ret, + ERR_get_error()); + throw NetworkProcessException("SSL write error"); + } + + return Transition::PROCEED; +} + +Transition SslSocketIoWrapper::Close() { + ERR_clear_error(); + int ret = SSL_shutdown(conn_ssl_context_); + if (ret != 0) { + int err = SSL_get_error(conn_ssl_context_, ret); + switch (err) { + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + // More work to do before shutdown + return Transition::NEED_READ; + default:LOG_ERROR("Error shutting down ssl session, err: %d", err); + } + } + // SSL context is explicitly deallocated here because socket wrapper + // objects are saved reused for memory efficiency and the reuse might + // not happen immediately, and thus freeing it on reuse time can make this + // live on arbitrarily long. + SSL_free(conn_ssl_context_); + conn_ssl_context_ = nullptr; + peloton_close(sock_fd_); + return Transition::NONE; +} + +} // namespace network +} // namespace peloton diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index 644e8dfef16..8fa94fbd65c 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include "common/cache.h" #include "common/internal_types.h" @@ -32,7 +33,7 @@ #include "util/string_util.h" #define SSL_MESSAGE_VERNO 80877103 -#define PROTO_MAJOR_VERSION(x) (x >> 16) +#define PROTO_MAJOR_VERSION(x) ((x) >> 16) namespace peloton { namespace network { @@ -40,19 +41,19 @@ namespace network { // TODO: Remove hardcoded auth strings // Hardcoded authentication strings used during session startup. 
To be removed const std::unordered_map - // clang-format off +// clang-format off PostgresProtocolHandler::parameter_status_map_ = - boost::assign::map_list_of("application_name", "psql") - ("client_encoding", "UTF8") - ("DateStyle", "ISO, MDY") - ("integer_datetimes", "on") - ("IntervalStyle", "postgres") - ("is_superuser", "on") - ("server_encoding", "UTF8") - ("server_version", "9.5devel") - ("session_authorization", "postgres") - ("standard_conforming_strings", "on") - ("TimeZone", "US/Eastern"); + boost::assign::map_list_of("application_name", "psql") + ("client_encoding", "UTF8") + ("DateStyle", "ISO, MDY") + ("integer_datetimes", "on") + ("IntervalStyle", "postgres") + ("is_superuser", "on") + ("server_encoding", "UTF8") + ("server_version", "9.5devel") + ("session_authorization", "postgres") + ("standard_conforming_strings", "on") + ("TimeZone", "US/Eastern"); // clang-format on PostgresProtocolHandler::PostgresProtocolHandler(tcop::TrafficCop *traffic_cop) @@ -87,22 +88,20 @@ bool PostgresProtocolHandler::HardcodedExecuteFilter(QueryType query_type) { switch (query_type) { // Skip SET case QueryType::QUERY_SET: - case QueryType::QUERY_SHOW: - return false; - // Skip duplicate BEGIN + case QueryType::QUERY_SHOW:return false; + // Skip duplicate BEGIN case QueryType::QUERY_BEGIN: if (txn_state_ == NetworkTransactionStateType::BLOCK) { return false; } break; - // Skip duplicate Commits and Rollbacks + // Skip duplicate Commits and Rollbacks case QueryType::QUERY_COMMIT: case QueryType::QUERY_ROLLBACK: if (txn_state_ == NetworkTransactionStateType::IDLE) { return false; } - default: - break; + default:break; } return true; } @@ -188,7 +187,7 @@ ProcessResult PostgresProtocolHandler::ExecQueryMessage( if (cached_statement.get() != nullptr) { traffic_cop_->SetStatement(cached_statement); } - // Did not find statement with same name + // Did not find statement with same name else { std::string error_message = "The prepared statement does not exist"; 
SendErrorResponse( @@ -340,7 +339,7 @@ void PostgresProtocolHandler::ExecParseMessage(InputPacket *pkt) { // For empty query, we still want to get it constructed // TODO (Tianyi) Consider handle more statement bool empty = (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0); + sql_stmt_list->GetNumStatements() == 0); if (!empty) { parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); @@ -384,7 +383,7 @@ void PostgresProtocolHandler::ExecParseMessage(InputPacket *pkt) { // Stat if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { + settings::SettingId::stats_mode)) != StatsType::INVALID) { // Make a copy of param types for stat collection stats::QueryMetric::QueryParamBuf query_type_buf; query_type_buf.len = type_buf_len; @@ -446,8 +445,8 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { if (statement.get() == nullptr) { std::string error_message = statement_name.empty() - ? "Invalid unnamed statement" - : "The prepared statement does not exist"; + ? 
"Invalid unnamed statement" + : "The prepared statement does not exist"; LOG_ERROR("%s", error_message.c_str()); SendErrorResponse( {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); @@ -530,7 +529,7 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { std::shared_ptr param_stat(nullptr); if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID && + settings::SettingId::stats_mode)) != StatsType::INVALID && num_params > 0) { // Make a copy of format for stat collection stats::QueryMetric::QueryParamBuf param_format_buf; @@ -559,7 +558,7 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { if (itr != portals_.end()) { itr->second = portal_reference; } - // Create a new entry in portal map + // Create a new entry in portal map else { portals_.insert(std::make_pair(portal_name, portal_reference)); } @@ -619,9 +618,9 @@ size_t PostgresProtocolHandler::ReadParamValue( std::string param_str = std::string(std::begin(param), std::end(param)); bind_parameters[param_idx] = std::make_pair(type::TypeId::VARCHAR, param_str); - if ((unsigned int)param_idx >= param_types.size() || + if ((unsigned int) param_idx >= param_types.size() || PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx]) == + (PostgresValueType) param_types[param_idx]) == type::TypeId::VARCHAR) { param_values[param_idx] = type::ValueFactory::GetVarcharValue(param_str); @@ -629,9 +628,10 @@ size_t PostgresProtocolHandler::ReadParamValue( param_values[param_idx] = (type::ValueFactory::GetVarcharValue(param_str)) .CastAs(PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx])); + (PostgresValueType) param_types[param_idx])); } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != type::TypeId::INVALID); + PELOTON_ASSERT( + param_values[param_idx].GetTypeId() != type::TypeId::INVALID); } else { // BINARY mode PostgresValueType pg_value_type = @@ -711,7 +711,8 @@ 
size_t PostgresProtocolHandler::ReadParamValue( break; } } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != type::TypeId::INVALID); + PELOTON_ASSERT( + param_values[param_idx].GetTypeId() != type::TypeId::INVALID); } } } @@ -819,9 +820,8 @@ ProcessResult PostgresProtocolHandler::ExecExecuteMessage( void PostgresProtocolHandler::ExecExecuteMessageGetResult(ResultType status) { const auto &query_type = traffic_cop_->GetStatement()->GetQueryType(); switch (status) { - case ResultType::FAILURE: - LOG_ERROR("Failed to execute: %s", - traffic_cop_->GetErrorMessage().c_str()); + case ResultType::FAILURE:LOG_ERROR("Failed to execute: %s", + traffic_cop_->GetErrorMessage().c_str()); SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, traffic_cop_->GetErrorMessage()}}); return; @@ -861,12 +861,10 @@ void PostgresProtocolHandler::GetResult() { traffic_cop_->ExecuteStatementPlanGetResult(); auto status = traffic_cop_->ExecuteStatementGetResult(); switch (protocol_type_) { - case NetworkProtocolType::POSTGRES_JDBC: - LOG_TRACE("JDBC result"); + case NetworkProtocolType::POSTGRES_JDBC:LOG_TRACE("JDBC result"); ExecExecuteMessageGetResult(status); break; - case NetworkProtocolType::POSTGRES_PSQL: - LOG_TRACE("PSQL result"); + case NetworkProtocolType::POSTGRES_PSQL:LOG_TRACE("PSQL result"); ExecQueryMessageGetResult(status); } } @@ -901,17 +899,11 @@ void PostgresProtocolHandler::ExecCloseMessage(InputPacket *pkt) { responses_.push_back(std::move(response)); } -bool PostgresProtocolHandler::ParseInputPacket(Buffer &rbuf, InputPacket &rpkt, +bool PostgresProtocolHandler::ParseInputPacket(ReadBuffer &rbuf, + InputPacket &rpkt, bool startup_format) { - if (rpkt.header_parsed == false) { - // parse out the header first - if (ReadPacketHeader(rbuf, rpkt, startup_format) == false) { - // need more data - return false; - } - } - - PELOTON_ASSERT(rpkt.header_parsed == true); + if (!rpkt.header_parsed && !ReadPacketHeader(rbuf, rpkt, startup_format)) + return false; if 
(rpkt.is_initialized == false) { // packet needs to be initialized with rest of the contents @@ -926,50 +918,40 @@ bool PostgresProtocolHandler::ParseInputPacket(Buffer &rbuf, InputPacket &rpkt, // The function tries to do a preliminary read to fetch the size value and // then reads the rest of the packet. // Assume: Packet length field is always 32-bit int -bool PostgresProtocolHandler::ReadPacketHeader(Buffer &rbuf, InputPacket &rpkt, +bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, + InputPacket &rpkt, bool startup) { // All packets other than the startup packet have a 5 bytes header - size_t initial_read_size = startup ? sizeof(int32_t) : sizeof(int32_t) + 1; + size_t header_size = startup ? sizeof(int32_t) : sizeof(int32_t) + 1; // check if header bytes are available - if (!rbuf.IsReadDataAvailable(initial_read_size)) { - // nothing more to read - return false; - } - - if (!startup) { - // Header also contains msg type - rpkt.msg_type = static_cast(rbuf.GetByte(rbuf.buf_ptr)); - // Skip the message type byte - rbuf.buf_ptr++; - } + if (!rbuf.HasMore(header_size)) return false; + if (!startup) + rpkt.msg_type = rbuf.ReadValue(); // get packet size from the header // extract packet contents size // content lengths should exclude the length bytes - rpkt.len = rbuf.GetUInt32BigEndian() - sizeof(uint32_t); + rpkt.len = ntohl(rbuf.ReadValue()) - sizeof(uint32_t); // do we need to use the extended buffer for this packet? - rpkt.is_extended = (rpkt.len > rbuf.GetMaxSize()); + rpkt.is_extended = (rpkt.len > rbuf.Capacity()); if (rpkt.is_extended) { LOG_TRACE("Using extended buffer for pkt size:%ld", rpkt.len); // reserve space for the extended buffer rpkt.ReserveExtendedBuffer(); } - // we have processed the data, move buffer pointer - rbuf.buf_ptr += sizeof(int32_t); rpkt.header_parsed = true; - return true; } // Tries to read the contents of a single packet, returns true on success, false // on failure. 
-bool PostgresProtocolHandler::ReadPacket(Buffer &rbuf, InputPacket &rpkt) { +bool PostgresProtocolHandler::ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt) { if (rpkt.is_extended) { // extended packet mode - auto bytes_available = rbuf.buf_size - rbuf.buf_ptr; + auto bytes_available = rbuf.BytesAvailable(); auto bytes_required = rpkt.ExtendedBytesRequired(); // read minimum of the two ranges auto read_size = std::min(bytes_available, bytes_required); @@ -1059,17 +1041,14 @@ ProcessResult PostgresProtocolHandler::ProcessStartupPacket( return ProcessResult::COMPLETE; } -ProcessResult PostgresProtocolHandler::Process(Buffer &rbuf, - const size_t thread_id) { +ProcessResult PostgresProtocolHandler::Process(ReadBuffer &rbuf, + const size_t thread_id) { if (!ParseInputPacket(rbuf, request_, init_stage_)) return ProcessResult::MORE_DATA_REQUIRED; - ProcessResult process_status; - if (init_stage_) { - process_status = ProcessInitialPacket(&request_); - } else { - process_status = ProcessNormalPacket(&request_, thread_id); - } + ProcessResult process_status = init_stage_ + ? 
ProcessInitialPacket(&request_) + : ProcessNormalPacket(&request_, thread_id); request_.Reset(); @@ -1091,11 +1070,13 @@ ProcessResult PostgresProtocolHandler::ProcessNormalPacket( case NetworkMessageType::PARSE_COMMAND: { LOG_TRACE("PARSE_COMMAND"); ExecParseMessage(pkt); - } break; + } + break; case NetworkMessageType::BIND_COMMAND: { LOG_TRACE("BIND_COMMAND"); ExecBindMessage(pkt); - } break; + } + break; case NetworkMessageType::DESCRIBE_COMMAND: { LOG_TRACE("DESCRIBE_COMMAND"); return ExecDescribeMessage(pkt); @@ -1108,11 +1089,13 @@ ProcessResult PostgresProtocolHandler::ProcessNormalPacket( LOG_TRACE("SYNC_COMMAND"); SendReadyForQuery(txn_state_); SetFlushFlag(true); - } break; + } + break; case NetworkMessageType::CLOSE_COMMAND: { LOG_TRACE("CLOSE_COMMAND"); ExecCloseMessage(pkt); - } break; + } + break; case NetworkMessageType::TERMINATE_COMMAND: { LOG_TRACE("TERMINATE_COMMAND"); SetFlushFlag(true); @@ -1202,26 +1185,22 @@ void PostgresProtocolHandler::CompleteCommand(const QueryType &query_type, std::string tag = QueryTypeToString(query_type); switch (query_type) { /* After Begin, we enter a txn block */ - case QueryType::QUERY_BEGIN: - txn_state_ = NetworkTransactionStateType::BLOCK; + case QueryType::QUERY_BEGIN:txn_state_ = NetworkTransactionStateType::BLOCK; break; - /* After commit, we end the txn block */ + /* After commit, we end the txn block */ case QueryType::QUERY_COMMIT: - /* After rollback, the txn block is ended */ + /* After rollback, the txn block is ended */ case QueryType::QUERY_ROLLBACK: txn_state_ = NetworkTransactionStateType::IDLE; break; - case QueryType::QUERY_INSERT: - tag += " 0 " + std::to_string(rows); + case QueryType::QUERY_INSERT:tag += " 0 " + std::to_string(rows); break; case QueryType::QUERY_CREATE_TABLE: case QueryType::QUERY_CREATE_DB: case QueryType::QUERY_CREATE_INDEX: case QueryType::QUERY_CREATE_TRIGGER: - case QueryType::QUERY_PREPARE: - break; - default: - tag += " " + std::to_string(rows); + case 
QueryType::QUERY_PREPARE:break; + default:tag += " " + std::to_string(rows); } PacketPutStringWithTerminator(pkt.get(), tag); responses_.push_back(std::move(pkt)); diff --git a/src/network/protocol_handler.cpp b/src/network/protocol_handler.cpp index a6bbb110c22..20a56351f85 100644 --- a/src/network/protocol_handler.cpp +++ b/src/network/protocol_handler.cpp @@ -23,9 +23,7 @@ ProtocolHandler::ProtocolHandler(tcop::TrafficCop *traffic_cop) { ProtocolHandler::~ProtocolHandler() {} -ProcessResult ProtocolHandler::Process(UNUSED_ATTRIBUTE Buffer &rbuf, - UNUSED_ATTRIBUTE const size_t - thread_id) { +ProcessResult ProtocolHandler::Process(ReadBuffer &, const size_t) { return ProcessResult::TERMINATE; } diff --git a/test/network/exception_test.cpp b/test/network/exception_test.cpp index 73118f0478e..1175de77da0 100644 --- a/test/network/exception_test.cpp +++ b/test/network/exception_test.cpp @@ -16,7 +16,7 @@ #include "common/harness.h" #include "common/logger.h" #include "gtest/gtest.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" #include "network/postgres_protocol_handler.h" #include "network/protocol_handler_factory.h" @@ -74,16 +74,6 @@ void *ParserExceptionTest(int port) { "sslmode=disable application_name=psql", port)); - peloton::network::ConnectionHandle *conn = - peloton::network::ConnectionHandleFactory::GetInstance() - .ConnectionHandleAt(peloton::network::PelotonServer::recent_connfd) - .get(); - - network::PostgresProtocolHandler *handler = - dynamic_cast( - conn->GetProtocolHandler().get()); - EXPECT_NE(handler, nullptr); - // If an exception occurs on one transaction, we can not use this // transaction anymore int exception_count = 0, total = 6; diff --git a/test/network/prepare_stmt_test.cpp b/test/network/prepare_stmt_test.cpp index 3e11472dd54..4c76a37ecbb 100644 --- a/test/network/prepare_stmt_test.cpp +++ b/test/network/prepare_stmt_test.cpp @@ -17,7 +17,7 @@ 
#include "network/peloton_server.h" #include "network/postgres_protocol_handler.h" #include "util/string_util.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" namespace peloton { namespace test { @@ -41,16 +41,6 @@ void *PrepareStatementTest(int port) { LOG_INFO("[PrepareStatementTest] Connected to %s", C.dbname()); pqxx::work txn1(C); - peloton::network::ConnectionHandle *conn = - peloton::network::ConnectionHandleFactory::GetInstance().ConnectionHandleAt( - peloton::network::PelotonServer::recent_connfd).get(); - - //Check type of protocol handler - network::PostgresProtocolHandler* handler = - dynamic_cast(conn->GetProtocolHandler().get()); - - EXPECT_NE(handler, nullptr); - // create table and insert some data txn1.exec("DROP TABLE IF EXISTS employee;"); txn1.exec("CREATE TABLE employee(id INT, name VARCHAR(100));"); diff --git a/test/network/select_all_test.cpp b/test/network/select_all_test.cpp index d601537c585..1f5552b7aa9 100644 --- a/test/network/select_all_test.cpp +++ b/test/network/select_all_test.cpp @@ -15,7 +15,7 @@ #include "common/logger.h" #include "network/peloton_server.h" #include "network/protocol_handler_factory.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ #include "network/postgres_protocol_handler.h" @@ -40,13 +40,6 @@ void *SelectAllTest(int port) { pqxx::connection C(StringUtil::Format( "host=127.0.0.1 port=%d user=default_database sslmode=disable application_name=psql", port)); pqxx::work txn1(C); - peloton::network::ConnectionHandle *conn = - peloton::network::ConnectionHandleFactory::GetInstance().ConnectionHandleAt( - peloton::network::PelotonServer::recent_connfd).get(); - - network::PostgresProtocolHandler *handler = - dynamic_cast(conn->GetProtocolHandler().get()); - EXPECT_NE(handler, nullptr); // create table and insert some data 
txn1.exec("DROP TABLE IF EXISTS template;"); diff --git a/test/network/simple_query_test.cpp b/test/network/simple_query_test.cpp index eb728e3fd6e..8e2409f2621 100644 --- a/test/network/simple_query_test.cpp +++ b/test/network/simple_query_test.cpp @@ -18,7 +18,7 @@ #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ #include "network/postgres_protocol_handler.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" #define NUM_THREADS 1 @@ -41,14 +41,6 @@ void *SimpleQueryTest(int port) { "host=127.0.0.1 port=%d user=default_database sslmode=disable application_name=psql", port)); pqxx::work txn1(C); - peloton::network::ConnectionHandle *conn = - peloton::network::ConnectionHandleFactory::GetInstance().ConnectionHandleAt( - peloton::network::PelotonServer::recent_connfd).get(); - - network::PostgresProtocolHandler *handler = - dynamic_cast(conn->GetProtocolHandler().get()); - EXPECT_NE(handler, nullptr); - // EXPECT_EQ(conn->state, peloton::network::READ); // create table and insert some data txn1.exec("DROP TABLE IF EXISTS employee;"); diff --git a/test/network/ssl_test.cpp b/test/network/ssl_test.cpp index 555f069afbd..00ed3582109 100644 --- a/test/network/ssl_test.cpp +++ b/test/network/ssl_test.cpp @@ -14,7 +14,7 @@ #include "common/harness.h" #include "common/logger.h" #include "gtest/gtest.h" -#include "network/connection_handle_factory.h" +#include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" #include "network/postgres_protocol_handler.h" #include "network/protocol_handler_factory.h" @@ -58,15 +58,6 @@ void *TestRoutine(int port) { pqxx::work txn1(C); - peloton::network::ConnectionHandle *conn = - peloton::network::ConnectionHandleFactory::GetInstance() - .ConnectionHandleAt(peloton::network::PelotonServer::recent_connfd) - .get(); - - network::PostgresProtocolHandler *handler = - dynamic_cast( - conn->GetProtocolHandler().get()); - 
EXPECT_NE(handler, nullptr); // basic test // create table and insert some data From da0e3077694b31781207029db556f6e0132c01ae Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 12 Jun 2018 17:22:44 -0400 Subject: [PATCH 02/48] Fix off-by-one and byte buf bug --- src/include/network/marshal.h | 10 +++++----- src/include/network/network_io_wrappers.h | 4 ++-- src/network/connection_handle.cpp | 4 ++-- src/network/network_io_wrapper_factory.cpp | 5 +---- src/network/network_io_wrappers.cpp | 14 ++++++++------ src/network/postgres_protocol_handler.cpp | 12 ++++++------ 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index c1cfab612db..15b98656053 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -55,7 +55,7 @@ struct Buffer { * @return Whether there is any more bytes between the cursor and * the end of the buffer */ - inline bool HasMore(size_t bytes = 1) { return offset_ + bytes < size_; } + inline bool HasMore(size_t bytes = 1) { return offset_ + bytes <= size_; } /** * @return Whether the buffer is at capacity. (All usable space is filled @@ -71,7 +71,7 @@ struct Buffer { /** * @return Capacity of the buffer (not actual size) */ - inline constexpr size_t Capacity() { return SOCKET_BUFFER_SIZE; } + inline constexpr size_t Capacity() const { return SOCKET_BUFFER_SIZE; } /** * Shift contents to align the current cursor with start of the buffer, @@ -189,7 +189,7 @@ class WriteBuffer: public Buffer { * @return Remaining capacity */ inline size_t RemainingCapacity() { - return Capacity() - size_; + return Capacity() - size_ + 1; } /** @@ -202,7 +202,7 @@ class WriteBuffer: public Buffer { /** * Append the desired range into current buffer - * @tparam InputIt iterator type + * @tparam InputIt iterator type. 
* @param first beginning of range * @param len length of range */ @@ -219,7 +219,7 @@ class WriteBuffer: public Buffer { */ template inline void Append(T val) { - Append(&val, sizeof(T)); + Append(reinterpret_cast(&val), sizeof(T)); } }; diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 6edef5ef7b4..535691e7a0f 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -49,9 +49,9 @@ class NetworkIoWrapper { std::shared_ptr &rbuf, std::shared_ptr &wbuf) : sock_fd_(sock_fd), - conn_ssl_context_(nullptr), rbuf_(std::move(rbuf)), - wbuf_(std::move(wbuf)) {} + wbuf_(std::move(wbuf)), + conn_ssl_context_(nullptr) {} // It is worth noting that because of the way we are reinterpret-casting between // derived types, it is necessary that they share the same members. int sock_fd_; diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index ffb9a0ac9b7..d277c8ef23b 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -161,8 +161,8 @@ ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) : conn_handler_(handler) { // We will always handle connections using posix until (potentially) first SSL // handshake. 
- io_wrapper_ = std::move(NetworkIoWrapperFactory::GetInstance() - .NewNetworkIoWrapper(sock_fd)); + io_wrapper_ = NetworkIoWrapperFactory::GetInstance() + .NewNetworkIoWrapper(sock_fd); } Transition ConnectionHandle::TryWrite() { diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index 737e0831e7a..f3c7d289685 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -41,8 +41,6 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( // It is not necessary to have an explicit cast here because the reused // wrapper always use Posix methods, as we never update their type in the // reusable wrappers map. - auto new_wrapper = new std::shared_ptr( - reinterpret_cast(reused_wrapper.get())); return reused_wrapper; } @@ -62,8 +60,7 @@ Transition NetworkIoWrapperFactory::PerformSslHandshake(std::shared_ptr< // ssl handshake is done, need to use new methods for the original wrappers; // We do not update the type in the reusable wrappers map because it is not // relevant. - io_wrapper = new std::shared_ptr( - reinterpret_cast(io_wrapper.get())); + io_wrapper.reset(reinterpret_cast(io_wrapper.get())); } // The wrapper already uses SSL methods. 
diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index ba7070449c7..2ef1b971b1e 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -37,12 +37,14 @@ Transition NetworkIoWrapper::WritePacket(OutputPacket *pkt) { pkt->skip_header_write = true; // Write Packet Content - while (pkt->len != 0) { - if (wbuf_->HasSpaceFor(pkt->len)) + for (size_t len = pkt->len; len != 0;) { + if (wbuf_->HasSpaceFor(pkt->len)) { wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, pkt->len); - else { + break; + } else { auto write_size = wbuf_->RemainingCapacity(); wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, write_size); + len -= write_size; pkt->write_ptr += write_size; auto result = FlushWriteBuffer(); if (FlushWriteBuffer() != Transition::PROCEED) @@ -84,8 +86,7 @@ Transition PosixSocketIoWrapper::FillReadBuffer() { case EAGAIN: // Equal to EWOULDBLOCK return result; - case EINTR: - continue; + case EINTR:continue; default:LOG_ERROR("Error writing: %s", strerror(errno)); throw NetworkProcessException("Error when filling read buffer " + std::to_string(errno)); @@ -156,8 +157,9 @@ Transition SslSocketIoWrapper::FlushWriteBuffer() { ERR_get_error()); throw NetworkProcessException("SSL write error"); } + } - return Transition::PROCEED; + return Transition::PROCEED; } Transition SslSocketIoWrapper::Close() { diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index 8fa94fbd65c..629dfe1887e 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -955,10 +955,10 @@ bool PostgresProtocolHandler::ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt) { auto bytes_required = rpkt.ExtendedBytesRequired(); // read minimum of the two ranges auto read_size = std::min(bytes_available, bytes_required); - rpkt.AppendToExtendedBuffer(rbuf.Begin() + rbuf.buf_ptr, - rbuf.Begin() + rbuf.buf_ptr + read_size); + 
rpkt.AppendToExtendedBuffer(rbuf.Begin() + rbuf.offset_, + rbuf.Begin() + rbuf.offset_ + read_size); // data has been copied, move ptr - rbuf.buf_ptr += read_size; + rbuf.offset_ += read_size; if (bytes_required > bytes_available) { // more data needs to be read return false; @@ -967,14 +967,14 @@ bool PostgresProtocolHandler::ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt) { rpkt.InitializePacket(); return true; } else { - if (rbuf.IsReadDataAvailable(rpkt.len) == false) { + if (rbuf.HasMore(rpkt.len) == false) { // data not available yet, return return false; } // Initialize the packet's "contents" - rpkt.InitializePacket(rbuf.buf_ptr, rbuf.Begin()); + rpkt.InitializePacket(rbuf.offset_, rbuf.Begin()); // We have processed the data, move buffer pointer - rbuf.buf_ptr += rpkt.len; + rbuf.offset_ += rpkt.len; } return true; From 5e18e1d165345963910818fedfafda7c494f79bc Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 13 Jun 2018 13:43:43 -0400 Subject: [PATCH 03/48] Run formatter on all code in network layer --- .../network/connection_dispatcher_task.h | 2 +- src/include/network/connection_handle.h | 39 ++++--- src/include/network/connection_handler_task.h | 2 +- src/include/network/marshal.h | 34 +++--- .../network/network_io_wrapper_factory.h | 40 +++---- src/include/network/network_io_wrappers.h | 24 ++--- src/include/network/network_state.h | 10 +- .../network/peloton_rpc_handler_task.h | 1 - src/include/network/peloton_server.h | 4 +- .../network/protocol_handler_factory.h | 4 +- .../network/service/connection_manager.h | 2 +- src/include/network/service/message_queue.h | 4 +- src/include/network/service/peloton_client.h | 4 +- .../network/service/peloton_endpoint.h | 2 +- src/include/network/service/peloton_service.h | 2 +- src/include/network/service/rpc_channel.h | 6 +- src/include/network/service/rpc_client.h | 10 +- src/include/network/service/rpc_method.h | 4 +- src/include/network/service/rpc_server.h | 2 +- src/include/network/service/rpc_type.h | 2 
+- src/include/network/service/rpc_utils.h | 2 +- src/include/network/service/tcp_connection.h | 10 +- src/include/network/service/tcp_listener.h | 4 +- src/network/README | 19 +++- src/network/connection_dispatcher_task.cpp | 2 +- src/network/connection_handle.cpp | 101 ++++++++++-------- src/network/connection_handler_task.cpp | 21 ++-- src/network/marshal.cpp | 4 +- src/network/network_io_wrapper_factory.cpp | 34 +++--- src/network/network_io_wrappers.cpp | 62 ++++++----- src/network/peloton_server.cpp | 28 ++--- src/network/postgres_protocol_handler.cpp | 89 +++++++-------- src/network/protocol_handler_factory.cpp | 6 +- src/network/service/connection_manager.cpp | 2 +- src/network/service/peloton_service.cpp | 10 +- src/network/service/rpc_channel.cpp | 14 +-- src/network/service/rpc_client.cpp | 2 +- src/network/service/rpc_server.cpp | 10 +- src/network/service/rpc_utils.cpp | 4 +- src/network/service/tcp_connection.cpp | 11 +- src/network/service/tcp_listener.cpp | 9 +- 41 files changed, 331 insertions(+), 311 deletions(-) diff --git a/src/include/network/connection_dispatcher_task.h b/src/include/network/connection_dispatcher_task.h index f7d2b62e6a3..0b97147622a 100644 --- a/src/include/network/connection_dispatcher_task.h +++ b/src/include/network/connection_dispatcher_task.h @@ -13,9 +13,9 @@ #pragma once #include "common/notifiable_task.h" -#include "network_state.h" #include "concurrency/epoch_manager_factory.h" #include "connection_handler_task.h" +#include "network_state.h" namespace peloton { namespace network { diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index c311af9299c..959d9af970e 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -32,9 +32,9 @@ #include "marshal.h" #include "network/connection_handler_task.h" +#include "network/network_io_wrappers.h" #include "network_state.h" #include "protocol_handler.h" -#include 
"network/network_io_wrappers.h" #include #include @@ -57,7 +57,8 @@ class ConnectionHandle { ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler); /** - * @brief Signal to libevent that this ConnectionHandle is ready to handle events + * @brief Signal to libevent that this ConnectionHandle is ready to handle + * events * * This method needs to be called separately after initialization for the * connection handle to do anything. The reason why this is not performed in @@ -69,12 +70,14 @@ class ConnectionHandle { workpool_event_ = conn_handler_->RegisterManualEvent( METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - // TODO(Tianyi): should put the initialization else where.. check correctness - // first. - tcop_.SetTaskCallback([](void *arg) { - struct event *event = static_cast(arg); - event_active(event, EV_WRITE, 0); - }, workpool_event_); + // TODO(Tianyi): should put the initialization else where.. check + // correctness first. + tcop_.SetTaskCallback( + [](void *arg) { + struct event *event = static_cast(arg); + event_active(event, EV_WRITE, 0); + }, + workpool_event_); network_event_ = conn_handler_->RegisterEvent( io_wrapper_->GetSocketFd(), EV_READ | EV_PERSIST, @@ -90,9 +93,7 @@ class ConnectionHandle { /* State Machine Actions */ // TODO(Tianyu): Write some documentation when feeling like it - inline Transition TryRead() { - return io_wrapper_->FillReadBuffer(); - } + inline Transition TryRead() { return io_wrapper_->FillReadBuffer(); } Transition TryWrite(); Transition Process(); Transition GetResult(); @@ -100,16 +101,14 @@ class ConnectionHandle { Transition CloseConnection(); /** - * Updates the event flags of the network event. This configures how the handler - * reacts to client activity from this connection. + * Updates the event flags of the network event. This configures how the + * handler reacts to client activity from this connection. * @param flags new flags for the event handle. 
*/ inline void UpdateEventFlags(short flags) { - conn_handler_->UpdateEvent(network_event_, - io_wrapper_->GetSocketFd(), - flags, - METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), - this); + conn_handler_->UpdateEvent( + network_event_, io_wrapper_->GetSocketFd(), flags, + METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); } /** @@ -183,10 +182,9 @@ class ConnectionHandle { */ inline bool HasResponse() { return (protocol_handler_->responses_.size() != 0) || - (io_wrapper_->wbuf_->size_ != 0); + (io_wrapper_->wbuf_->size_ != 0); } - ConnectionHandlerTask *conn_handler_; std::shared_ptr io_wrapper_; StateMachine state_machine_; @@ -195,7 +193,6 @@ class ConnectionHandle { tcop::TrafficCop tcop_; // TODO(Tianyu): Put this into protocol handler in a later refactor unsigned int next_response_ = 0; - }; } // namespace network } // namespace peloton diff --git a/src/include/network/connection_handler_task.h b/src/include/network/connection_handler_task.h index c86ea1a24ce..44a34884e2e 100644 --- a/src/include/network/connection_handler_task.h +++ b/src/include/network/connection_handler_task.h @@ -19,9 +19,9 @@ #include +#include "common/container/lock_free_queue.h" #include "common/exception.h" #include "common/logger.h" -#include "common/container/lock_free_queue.h" #include "common/notifiable_task.h" namespace peloton { diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 15b98656053..c40031b1bac 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -15,11 +15,11 @@ #include #include +#include +#include #include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" -#include -#include #include "network/network_state.h" #define BUFFER_INIT_SIZE 100 @@ -31,8 +31,8 @@ namespace network { * A plain old buffer with a movable cursor, the meaning of which is dependent * on the use case. 
* - * The buffer has a fix capacity and one can write a variable amount of meaningful - * bytes into it. We call this amount "size" of the buffer. + * The buffer has a fix capacity and one can write a variable amount of + * meaningful bytes into it. We call this amount "size" of the buffer. */ struct Buffer { public: @@ -92,7 +92,7 @@ struct Buffer { /** * A buffer specialize for read */ -class ReadBuffer: public Buffer { +class ReadBuffer : public Buffer { public: /** * Read as many bytes as possible using SSL read @@ -101,8 +101,7 @@ class ReadBuffer: public Buffer { */ inline int FillBufferFrom(SSL *context) { ERR_clear_error(); - ssize_t bytes_read = - SSL_read(context, &buf_[size_], Capacity() - size_); + ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); int err = SSL_get_error(context, bytes_read); if (err == SSL_ERROR_NONE) size_ += bytes_read; return err; @@ -116,7 +115,7 @@ class ReadBuffer: public Buffer { inline int FillBufferFrom(int fd) { ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); if (bytes_read > 0) size_ += bytes_read; - return (int) bytes_read; + return (int)bytes_read; } /** @@ -133,8 +132,7 @@ class ReadBuffer: public Buffer { * @param dest Desired memory location to read into */ inline void Read(size_t bytes, void *dest) { - std::copy(buf_.begin() + offset_, - buf_.begin() + offset_ + bytes, + std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, reinterpret_cast(dest)); offset_ += bytes; } @@ -156,7 +154,7 @@ class ReadBuffer: public Buffer { /** * A buffer specialized for write */ -class WriteBuffer: public Buffer { +class WriteBuffer : public Buffer { public: /** * Write as many bytes as possible using SSL write @@ -165,8 +163,7 @@ class WriteBuffer: public Buffer { */ inline int WriteOutTo(SSL *context) { ERR_clear_error(); - ssize_t bytes_written = - SSL_write(context, &buf_[offset_], size_ - offset_); + ssize_t bytes_written = SSL_write(context, &buf_[offset_], size_ - offset_); int 
err = SSL_get_error(context, bytes_written); if (err == SSL_ERROR_NONE) offset_ += bytes_written; return err; @@ -180,7 +177,7 @@ class WriteBuffer: public Buffer { inline int WriteOutTo(int fd) { ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); if (bytes_written > 0) offset_ += bytes_written; - return (int) bytes_written; + return (int)bytes_written; } /** @@ -188,17 +185,13 @@ class WriteBuffer: public Buffer { * maximum capacity minus the capacity already in use. * @return Remaining capacity */ - inline size_t RemainingCapacity() { - return Capacity() - size_ + 1; - } + inline size_t RemainingCapacity() { return Capacity() - size_ + 1; } /** * @param bytes Desired number of bytes to write * @return Whether the buffer can accommodate the number of bytes given */ - inline bool HasSpaceFor(size_t bytes) { - return RemainingCapacity() >= bytes; - } + inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } /** * Append the desired range into current buffer @@ -223,7 +216,6 @@ class WriteBuffer: public Buffer { } }; - class InputPacket { public: NetworkMessageType msg_type; // header diff --git a/src/include/network/network_io_wrapper_factory.h b/src/include/network/network_io_wrapper_factory.h index 314fc11a694..979e6a18afd 100644 --- a/src/include/network/network_io_wrapper_factory.h +++ b/src/include/network/network_io_wrapper_factory.h @@ -2,9 +2,9 @@ // // Peloton // -// connection_handle_factory.h +// network_io_wrapper_factory.h // -// Identification: src/include/network/connection_handle_factory.h +// Identification: src/include/network/network_io_wrapper_factory.h // // Copyright (c) 2015-2018, Carnegie Mellon University Database Group // @@ -12,8 +12,8 @@ #pragma once -#include "network/peloton_server.h" #include "network/network_io_wrappers.h" +#include "network/peloton_server.h" namespace peloton { namespace network { @@ -26,12 +26,12 @@ namespace network { * buffers to other wrappers. 
*/ // TODO(Tianyu): Make reuse more fine-grained and adjustable -// Currently there is no limit on the number of wrappers we save. This means that -// we never deallocated wrappers unless we shut down. Obviously this will be a -// memory overhead if we had a lot of connections at one point and dropped down -// after a while. Relying on OS fd values for reuse also can backfire. -// It shouldn't be hard to keep a pool of buffers with a size limit instead of -// a bunch of old wrapper objects. +// Currently there is no limit on the number of wrappers we save. This means +// that we never deallocated wrappers unless we shut down. Obviously this will +// be a memory overhead if we had a lot of connections at one point and dropped +// down after a while. Relying on OS fd values for reuse also can backfire. It +// shouldn't be hard to keep a pool of buffers with a size limit instead of a +// bunch of old wrapper objects. class NetworkIoWrapperFactory { public: static inline NetworkIoWrapperFactory &GetInstance() { @@ -41,7 +41,8 @@ class NetworkIoWrapperFactory { /** * @brief Creates or re-purpose a NetworkIoWrapper object for new use. - * The returned value always uses Posix I/O methods unles explicitly converted. + * The returned value always uses Posix I/O methods unles explicitly + * converted. 
* @see NetworkIoWrapper for details * @param conn_fd Client connection fd * @return A new NetworkIoWrapper object @@ -49,16 +50,17 @@ class NetworkIoWrapperFactory { std::shared_ptr NewNetworkIoWrapper(int conn_fd); /** - * @brief: process SSL handshake to generate valid SSL - * connection context for further communications - * @return FINISH when the SSL handshake failed - * PROCEED when the SSL handshake success - * NEED_DATA when the SSL handshake is partially done due to network - * latency - */ + * @brief: process SSL handshake to generate valid SSL + * connection context for further communications + * @return FINISH when the SSL handshake failed + * PROCEED when the SSL handshake success + * NEED_DATA when the SSL handshake is partially done due to network + * latency + */ Transition PerformSslHandshake(std::shared_ptr &io_wrapper); + private: std::unordered_map> reusable_wrappers_; }; -} -} +} // namespace network +} // namespace peloton diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 535691e7a0f..66c453626d8 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -2,9 +2,9 @@ // // Peloton // -// client_socket_wrapper.h +// network_io_wrappers.h // -// Identification: src/include/network/client_socket_wrapper.h +// Identification: src/include/network/network_io_wrappers.h // // Copyright (c) 2015-2018, Carnegie Mellon University Database Group // @@ -15,9 +15,9 @@ #include #include #include -#include "network/marshal.h" #include "common/exception.h" #include "common/utility.h" +#include "network/marshal.h" namespace peloton { namespace network { @@ -36,6 +36,7 @@ namespace network { */ class NetworkIoWrapper { friend class NetworkIoWrapperFactory; + public: // TODO(Tianyu): Change and document after we refactor protocol handler virtual Transition FillReadBuffer() = 0; @@ -44,16 +45,16 @@ class NetworkIoWrapper { inline int GetSocketFd() { return 
sock_fd_; } Transition WritePacket(OutputPacket *pkt); - // TODO(Tianyu): Make these protected when protocol handler refactor is complete - NetworkIoWrapper(int sock_fd, - std::shared_ptr &rbuf, + // TODO(Tianyu): Make these protected when protocol handler refactor is + // complete + NetworkIoWrapper(int sock_fd, std::shared_ptr &rbuf, std::shared_ptr &wbuf) : sock_fd_(sock_fd), rbuf_(std::move(rbuf)), wbuf_(std::move(wbuf)), conn_ssl_context_(nullptr) {} - // It is worth noting that because of the way we are reinterpret-casting between - // derived types, it is necessary that they share the same members. + // It is worth noting that because of the way we are reinterpret-casting + // between derived types, it is necessary that they share the same members. int sock_fd_; std::shared_ptr rbuf_; std::shared_ptr wbuf_; @@ -65,8 +66,7 @@ class NetworkIoWrapper { */ class PosixSocketIoWrapper : public NetworkIoWrapper { public: - PosixSocketIoWrapper(int sock_fd, - std::shared_ptr rbuf, + PosixSocketIoWrapper(int sock_fd, std::shared_ptr rbuf, std::shared_ptr wbuf); Transition FillReadBuffer() override; @@ -91,5 +91,5 @@ class SslSocketIoWrapper : public NetworkIoWrapper { Transition FlushWriteBuffer() override; Transition Close() override; }; -} // namespace network -} // namespace peloton \ No newline at end of file +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/network_state.h b/src/include/network/network_state.h index 34c08aa9adc..b580e0ff457 100644 --- a/src/include/network/network_state.h +++ b/src/include/network/network_state.h @@ -19,11 +19,11 @@ namespace network { * @see ConnectionHandle::StateMachine */ enum class ConnState { - READ, // State that reads data from the network - WRITE, // State the writes data to the network - PROCESS, // State that runs the network protocol on received data - CLOSING, // State for closing the client connection - SSL_INIT, // State to flush out responses and doing 
(Real) SSL handshake + READ, // State that reads data from the network + WRITE, // State the writes data to the network + PROCESS, // State that runs the network protocol on received data + CLOSING, // State for closing the client connection + SSL_INIT, // State to flush out responses and doing (Real) SSL handshake }; /** diff --git a/src/include/network/peloton_rpc_handler_task.h b/src/include/network/peloton_rpc_handler_task.h index 8abfa510af4..d32d236ffc3 100644 --- a/src/include/network/peloton_rpc_handler_task.h +++ b/src/include/network/peloton_rpc_handler_task.h @@ -29,7 +29,6 @@ class PelotonRpcServerImpl final : public PelotonService::Server { } }; - class PelotonRpcHandlerTask : public DedicatedThreadTask { public: explicit PelotonRpcHandlerTask(const char *address) : address_(address) {} diff --git a/src/include/network/peloton_server.h b/src/include/network/peloton_server.h index 6f592bd6166..e0baed54ef1 100644 --- a/src/include/network/peloton_server.h +++ b/src/include/network/peloton_server.h @@ -29,12 +29,12 @@ #include #include "common/container/lock_free_queue.h" +#include "common/dedicated_thread_owner.h" #include "common/exception.h" #include "common/logger.h" -#include "common/dedicated_thread_owner.h" +#include "common/notifiable_task.h" #include "connection_dispatcher_task.h" #include "network_state.h" -#include "common/notifiable_task.h" #include "protocol_handler.h" #include diff --git a/src/include/network/protocol_handler_factory.h b/src/include/network/protocol_handler_factory.h index de52613c551..c13cca250b2 100644 --- a/src/include/network/protocol_handler_factory.h +++ b/src/include/network/protocol_handler_factory.h @@ -32,5 +32,5 @@ class ProtocolHandlerFactory { static std::unique_ptr CreateProtocolHandler( ProtocolHandlerType type, tcop::TrafficCop *trafficCop); }; -} -} +} // namespace network +} // namespace peloton diff --git a/src/include/network/service/connection_manager.h b/src/include/network/service/connection_manager.h 
index 45e2c8a6dab..c833d54b399 100644 --- a/src/include/network/service/connection_manager.h +++ b/src/include/network/service/connection_manager.h @@ -15,8 +15,8 @@ #include #include -#include "common/synchronization/mutex_latch.h" #include "common/synchronization/condition.h" +#include "common/synchronization/mutex_latch.h" #include "network/service/tcp_connection.h" namespace peloton { diff --git a/src/include/network/service/message_queue.h b/src/include/network/service/message_queue.h index a6e00af8ccf..3893beda017 100644 --- a/src/include/network/service/message_queue.h +++ b/src/include/network/service/message_queue.h @@ -12,10 +12,10 @@ #pragma once +#include +#include #include #include -#include -#include namespace peloton { namespace network { diff --git a/src/include/network/service/peloton_client.h b/src/include/network/service/peloton_client.h index 23ffa9f48d6..d41c405205e 100644 --- a/src/include/network/service/peloton_client.h +++ b/src/include/network/service/peloton_client.h @@ -10,11 +10,11 @@ // //===----------------------------------------------------------------------===// +#include "abstract_service.pb.h" #include "common/logger.h" +#include "peloton_endpoint.h" #include "rpc_channel.h" #include "rpc_controller.h" -#include "abstract_service.pb.h" -#include "peloton_endpoint.h" #include diff --git a/src/include/network/service/peloton_endpoint.h b/src/include/network/service/peloton_endpoint.h index 6d831baad9d..9a92800175c 100644 --- a/src/include/network/service/peloton_endpoint.h +++ b/src/include/network/service/peloton_endpoint.h @@ -12,8 +12,8 @@ #pragma once -#include "rpc_server.h" #include "peloton_service.h" +#include "rpc_server.h" namespace peloton { namespace network { diff --git a/src/include/network/service/peloton_service.h b/src/include/network/service/peloton_service.h index 61fd78536bc..8f45bc546bd 100644 --- a/src/include/network/service/peloton_service.h +++ b/src/include/network/service/peloton_service.h @@ -13,8 +13,8 
@@ #pragma once #include "peloton/proto/abstract_service.pb.h" -#include "rpc_server.h" #include "peloton_endpoint.h" +#include "rpc_server.h" //===--------------------------------------------------------------------===// // Implements AbstractPelotonService diff --git a/src/include/network/service/rpc_channel.h b/src/include/network/service/rpc_channel.h index 306c556d95d..c4d4f1e1710 100644 --- a/src/include/network/service/rpc_channel.h +++ b/src/include/network/service/rpc_channel.h @@ -12,15 +12,15 @@ #pragma once -#include #include +#include -#include "network_address.h" #include "common/logger.h" +#include "network_address.h" #include "peloton/proto/abstract_service.pb.h" -#include #include +#include namespace peloton { namespace network { diff --git a/src/include/network/service/rpc_client.h b/src/include/network/service/rpc_client.h index f5f945d73a6..d58b30b442e 100644 --- a/src/include/network/service/rpc_client.h +++ b/src/include/network/service/rpc_client.h @@ -12,12 +12,12 @@ #pragma once -#include "rpc_type.h" -#include "rpc_controller.h" -#include "rpc_channel.h" -#include "peloton_endpoint.h" -#include "peloton/proto/abstract_service.pb.h" #include "common/logger.h" +#include "peloton/proto/abstract_service.pb.h" +#include "peloton_endpoint.h" +#include "rpc_channel.h" +#include "rpc_controller.h" +#include "rpc_type.h" #include diff --git a/src/include/network/service/rpc_method.h b/src/include/network/service/rpc_method.h index 9d76bc86bde..79b68fbade1 100644 --- a/src/include/network/service/rpc_method.h +++ b/src/include/network/service/rpc_method.h @@ -12,9 +12,9 @@ #pragma once -#include -#include #include +#include +#include #include namespace peloton { diff --git a/src/include/network/service/rpc_server.h b/src/include/network/service/rpc_server.h index db9f5b73af8..310de4b8dbe 100644 --- a/src/include/network/service/rpc_server.h +++ b/src/include/network/service/rpc_server.h @@ -12,9 +12,9 @@ #pragma once +#include #include #include 
-#include #include "common/logger.h" #include "rpc_method.h" diff --git a/src/include/network/service/rpc_type.h b/src/include/network/service/rpc_type.h index 8a15cc98a02..18117638ab5 100644 --- a/src/include/network/service/rpc_type.h +++ b/src/include/network/service/rpc_type.h @@ -19,5 +19,5 @@ namespace service { enum MessageType { MSG_TYPE_INVALID = 0, MSG_TYPE_REQ, MSG_TYPE_REP }; } // namespace service -} // namespace message +} // namespace network } // namespace peloton diff --git a/src/include/network/service/rpc_utils.h b/src/include/network/service/rpc_utils.h index e663e4ca011..3b997460c55 100644 --- a/src/include/network/service/rpc_utils.h +++ b/src/include/network/service/rpc_utils.h @@ -25,5 +25,5 @@ namespace service { //===----------------------------------------------------------------------===// } // namespace service -} // namespace message +} // namespace network } // namespace peloton diff --git a/src/include/network/service/tcp_connection.h b/src/include/network/service/tcp_connection.h index ad1837878b3..20f2ef8e002 100644 --- a/src/include/network/service/tcp_connection.h +++ b/src/include/network/service/tcp_connection.h @@ -12,17 +12,17 @@ #pragma once -#include "rpc_server.h" +#include "common/logger.h" +#include "network_address.h" #include "rpc_channel.h" #include "rpc_controller.h" -#include "network_address.h" -#include "common/logger.h" +#include "rpc_server.h" -#include #include +#include +#include #include #include -#include #include diff --git a/src/include/network/service/tcp_listener.h b/src/include/network/service/tcp_listener.h index 61befb114cf..7f24f28eab9 100644 --- a/src/include/network/service/tcp_listener.h +++ b/src/include/network/service/tcp_listener.h @@ -12,9 +12,9 @@ #pragma once -#include -#include #include +#include +#include namespace peloton { namespace network { diff --git a/src/network/README b/src/network/README index e2a8b3c6030..c3f8d1318b9 100644 --- a/src/network/README +++ b/src/network/README @@ 
-1,7 +1,19 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// README +// +// Identification: src/network/README +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + ##################################################### -# # -# PELOTON WIRE # -# # +# # +#PELOTON WIRE # +# # ##################################################### Description @@ -55,4 +67,3 @@ Packets supported * RowDescription (T) * DataRow (D) * CommandComplete (C) - diff --git a/src/network/connection_dispatcher_task.cpp b/src/network/connection_dispatcher_task.cpp index ce5ce18ffdd..4c800bf1440 100644 --- a/src/network/connection_dispatcher_task.cpp +++ b/src/network/connection_dispatcher_task.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/connection_dispatcher_task.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index d277c8ef23b..d14837770b6 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/connection_handle.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -15,13 +15,13 @@ #include "network/connection_dispatcher_task.h" #include "network/connection_handle.h" +#include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" #include "network/postgres_protocol_handler.h" #include "network/protocol_handler_factory.h" -#include "network/network_io_wrapper_factory.h" -#include 
"settings/settings_manager.h" #include "common/utility.h" +#include "settings/settings_manager.h" namespace peloton { namespace network { @@ -65,42 +65,46 @@ namespace { #define DEFINE_STATE(s) \ case ConnState::s: { \ switch (t) { - #define ON(t) \ case Transition::t: \ return #define SET_STATE_TO(s) \ { \ - ConnState::s, -#define AND_INVOKE(m) \ - ([](ConnectionHandle & w) { return w.m(); }) \ - } \ + ConnState::s, +#define AND_INVOKE(m) \ + ([](ConnectionHandle &w) { return w.m(); }) \ + } \ ; -#define AND_WAIT_ON_READ \ - ([](ConnectionHandle &w) { w.UpdateEventFlags(EV_READ | EV_PERSIST); \ - return Transition::NONE; }) \ - } \ +#define AND_WAIT_ON_READ \ + ([](ConnectionHandle &w) { \ + w.UpdateEventFlags(EV_READ | EV_PERSIST); \ + return Transition::NONE; \ + }) \ + } \ ; -#define AND_WAIT_ON_WRITE \ - ([](ConnectionHandle &w) { w.UpdateEventFlags(EV_WRITE | EV_PERSIST); \ - return Transition::NONE; }) \ - } \ +#define AND_WAIT_ON_WRITE \ + ([](ConnectionHandle &w) { \ + w.UpdateEventFlags(EV_WRITE | EV_PERSIST); \ + return Transition::NONE; \ + }) \ + } \ ; -#define AND_WAIT_ON_PELOTON \ - ([](ConnectionHandle &w) { w.StopReceivingNetworkEvent(); \ - return Transition::NONE; }) \ - } \ +#define AND_WAIT_ON_PELOTON \ + ([](ConnectionHandle &w) { \ + w.StopReceivingNetworkEvent(); \ + return Transition::NONE; \ + }) \ + } \ ; -#define END_DEF \ - default: \ - throw std::runtime_error("undefined transition"); \ - } \ +#define END_DEF \ + default: \ + throw std::runtime_error("undefined transition"); \ + } \ } -#define END_STATE_DEF \ - ON(TERMINATE) SET_STATE_TO(CLOSING) AND_INVOKE(CloseConnection) \ - END_DEF -} +#define END_STATE_DEF \ + ON(TERMINATE) SET_STATE_TO(CLOSING) AND_INVOKE(CloseConnection) END_DEF +} // namespace // clang-format off DEF_TRANSITION_GRAPH @@ -139,10 +143,10 @@ DEF_TRANSITION_GRAPH ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) END_STATE_DEF END_DEF -// clang-format on + // clang-format on -void 
ConnectionHandle::StateMachine::Accept(Transition action, - ConnectionHandle &connection) { + void ConnectionHandle::StateMachine::Accept(Transition action, + ConnectionHandle &connection) { Transition next = action; while (next != Transition::NONE) { transition_result result = Delta_(current_state_, next); @@ -161,21 +165,20 @@ ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) : conn_handler_(handler) { // We will always handle connections using posix until (potentially) first SSL // handshake. - io_wrapper_ = NetworkIoWrapperFactory::GetInstance() - .NewNetworkIoWrapper(sock_fd); + io_wrapper_ = + NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd); } Transition ConnectionHandle::TryWrite() { for (; next_response_ < protocol_handler_->responses_.size(); - next_response_++) { + next_response_++) { auto result = io_wrapper_->WritePacket( protocol_handler_->responses_[next_response_].get()); if (result != Transition::PROCEED) return result; } protocol_handler_->responses_.clear(); next_response_ = 0; - if (protocol_handler_->GetFlushFlag()) - return io_wrapper_->FlushWriteBuffer(); + if (protocol_handler_->GetFlushFlag()) return io_wrapper_->FlushWriteBuffer(); protocol_handler_->SetFlushFlag(false); return Transition::PROCEED; } @@ -189,18 +192,22 @@ Transition ConnectionHandle::Process() { protocol_handler_ = ProtocolHandlerFactory::CreateProtocolHandler( ProtocolHandlerType::Postgres, &tcop_); - ProcessResult status = - protocol_handler_->Process(*(io_wrapper_->rbuf_), - (size_t) conn_handler_->Id()); + ProcessResult status = protocol_handler_->Process( + *(io_wrapper_->rbuf_), (size_t)conn_handler_->Id()); switch (status) { - case ProcessResult::MORE_DATA_REQUIRED: return Transition::NEED_READ; - case ProcessResult::COMPLETE: return Transition::PROCEED; - case ProcessResult::PROCESSING: return Transition::NEED_RESULT; - case ProcessResult::TERMINATE: return Transition::TERMINATE; - case 
ProcessResult::NEED_SSL_HANDSHAKE: return Transition::NEED_SSL_HANDSHAKE; + case ProcessResult::MORE_DATA_REQUIRED: + return Transition::NEED_READ; + case ProcessResult::COMPLETE: + return Transition::PROCEED; + case ProcessResult::PROCESSING: + return Transition::NEED_RESULT; + case ProcessResult::TERMINATE: + return Transition::TERMINATE; + case ProcessResult::NEED_SSL_HANDSHAKE: + return Transition::NEED_SSL_HANDSHAKE; default: - LOG_ERROR("Unknown process result"); + LOG_ERROR("Unknown process result"); throw NetworkProcessException("Unknown process result"); } } @@ -218,7 +225,8 @@ Transition ConnectionHandle::TrySslHandshake() { auto write_ret = TryWrite(); if (write_ret != Transition::PROCEED) return write_ret; } - return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake(io_wrapper_); + return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake( + io_wrapper_); } Transition ConnectionHandle::CloseConnection() { @@ -227,7 +235,6 @@ Transition ConnectionHandle::CloseConnection() { conn_handler_->UnregisterEvent(network_event_); io_wrapper_->Close(); return Transition::NONE; - } } // namespace network } // namespace peloton diff --git a/src/network/connection_handler_task.cpp b/src/network/connection_handler_task.cpp index 4f1f926928c..7d5a5114c78 100644 --- a/src/network/connection_handler_task.cpp +++ b/src/network/connection_handler_task.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/connection_handler_task.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -45,22 +45,21 @@ void ConnectionHandlerTask::HandleDispatch(int new_conn_recv_fd, short) { // read fully while (bytes_read < sizeof(int)) { - ssize_t result = read(new_conn_recv_fd, - client_fd + bytes_read, + ssize_t result = read(new_conn_recv_fd, client_fd + bytes_read, sizeof(int) - bytes_read); 
if (result < 0) { LOG_ERROR("Error when reading from dispatch"); } - bytes_read += (size_t) result; + bytes_read += (size_t)result; } - // Smart pointers are not used here because libevent does not take smart pointers. - // During the life time of this object, the pointer to it will be maintained - // by libevent rather than by our own code. The object will have to be cleaned - // up by one of its methods (i.e. we call a method with "delete this" and have - // the object commit suicide from libevent. ) - (new ConnectionHandle(*reinterpret_cast(client_fd), - this))->RegisterToReceiveEvents(); + // Smart pointers are not used here because libevent does not take smart + // pointers. During the life time of this object, the pointer to it will be + // maintained by libevent rather than by our own code. The object will have to + // be cleaned up by one of its methods (i.e. we call a method with "delete + // this" and have the object commit suicide from libevent. ) + (new ConnectionHandle(*reinterpret_cast(client_fd), this)) + ->RegisterToReceiveEvents(); } } // namespace network diff --git a/src/network/marshal.cpp b/src/network/marshal.cpp index 9df869b3604..314dca1d5ea 100644 --- a/src/network/marshal.cpp +++ b/src/network/marshal.cpp @@ -6,14 +6,14 @@ // // Identification: src/network/marshal.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// +#include "network/marshal.h" #include #include #include -#include "network/marshal.h" #include diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index f3c7d289685..383dc48dae9 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -2,9 +2,9 @@ // // Peloton // -// connection_handle_factory.cpp +// network_io_wrapper_factory.cpp // -// 
Identification: src/network/connection_handle_factory.cpp +// Identification: src/network/network_io_wrapper_factory.cpp // // Copyright (c) 2015-2018, Carnegie Mellon University Database Group // @@ -19,11 +19,9 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( auto it = reusable_wrappers_.find(conn_fd); if (it == reusable_wrappers_.end()) { // No reusable wrappers - auto wrapper = std::make_shared(conn_fd, - std::make_shared< - ReadBuffer>(), - std::make_shared< - WriteBuffer>()); + auto wrapper = std::make_shared( + conn_fd, std::make_shared(), + std::make_shared()); reusable_wrappers_[conn_fd] = std::static_pointer_cast( wrapper); @@ -44,12 +42,12 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( return reused_wrapper; } -Transition NetworkIoWrapperFactory::PerformSslHandshake(std::shared_ptr< - NetworkIoWrapper> &io_wrapper) { +Transition NetworkIoWrapperFactory::PerformSslHandshake( + std::shared_ptr &io_wrapper) { if (io_wrapper->conn_ssl_context_ == nullptr) { // Initial handshake, the incoming type is a posix socket wrapper auto *context = io_wrapper->conn_ssl_context_ = - SSL_new(PelotonServer::ssl_context); + SSL_new(PelotonServer::ssl_context); // TODO(Tianyu): Is it the right thing here to throw exceptions? 
if (context == nullptr) throw NetworkProcessException("ssl context for conn failed"); @@ -72,14 +70,14 @@ Transition NetworkIoWrapperFactory::PerformSslHandshake(std::shared_ptr< int err = SSL_get_error(context, ssl_accept_ret); switch (err) { - case SSL_ERROR_WANT_READ:return Transition::NEED_READ; - case SSL_ERROR_WANT_WRITE:return Transition::NEED_WRITE; -// case SSL_ERROR_SSL: -// case SSL_ERROR_ZERO_RETURN: -// case SSL_ERROR_SYSCALL: - default:LOG_ERROR("SSL Error, error code %d", err); + case SSL_ERROR_WANT_READ: + return Transition::NEED_READ; + case SSL_ERROR_WANT_WRITE: + return Transition::NEED_WRITE; + default: + LOG_ERROR("SSL Error, error code %d", err); return Transition::TERMINATE; } } -} // namespace network -} // namespace peloton \ No newline at end of file +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index 2ef1b971b1e..b90fd1a05dd 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -2,9 +2,9 @@ // // Peloton // -// client_socket_wrapper.cpp +// network_io_wrappers.cpp // -// Identification: src/network/client_socket_wrapper.cpp +// Identification: src/network/network_io_wrappers.cpp // // Copyright (c) 2015-2018, Carnegie Mellon University Database Group // @@ -13,8 +13,8 @@ #include "network/network_io_wrappers.h" #include #include -#include #include +#include #include "network/peloton_server.h" namespace peloton { @@ -86,10 +86,12 @@ Transition PosixSocketIoWrapper::FillReadBuffer() { case EAGAIN: // Equal to EWOULDBLOCK return result; - case EINTR:continue; - default:LOG_ERROR("Error writing: %s", strerror(errno)); + case EINTR: + continue; + default: + LOG_ERROR("Error writing: %s", strerror(errno)); throw NetworkProcessException("Error when filling read buffer " + - std::to_string(errno)); + std::to_string(errno)); } } return result; @@ -98,11 +100,13 @@ Transition 
PosixSocketIoWrapper::FillReadBuffer() { Transition PosixSocketIoWrapper::FlushWriteBuffer() { while (wbuf_->HasMore()) { auto bytes_written = wbuf_->WriteOutTo(sock_fd_); - if (bytes_written < 0) - switch (errno) { - case EINTR: continue; - case EAGAIN: return Transition::NEED_WRITE; - default:LOG_ERROR("Error writing: %s", strerror(errno)); + if (bytes_written < 0) switch (errno) { + case EINTR: + continue; + case EAGAIN: + return Transition::NEED_WRITE; + default: + LOG_ERROR("Error writing: %s", strerror(errno)); throw NetworkProcessException("Fatal error during write"); } } @@ -117,13 +121,17 @@ Transition SslSocketIoWrapper::FillReadBuffer() { while (!rbuf_->Full()) { auto ret = rbuf_->FillBufferFrom(conn_ssl_context_); switch (ret) { - case SSL_ERROR_NONE:result = Transition::PROCEED; + case SSL_ERROR_NONE: + result = Transition::PROCEED; break; - case SSL_ERROR_ZERO_RETURN: return Transition::TERMINATE; + case SSL_ERROR_ZERO_RETURN: + return Transition::TERMINATE; // The SSL packet is partially loaded to the SSL buffer only, // More data is required in order to decode the wh`ole packet. 
- case SSL_ERROR_WANT_READ: return Transition::NEED_READ; - case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; + case SSL_ERROR_WANT_READ: + return Transition::NEED_READ; + case SSL_ERROR_WANT_WRITE: + return Transition::NEED_WRITE; case SSL_ERROR_SYSCALL: if (errno == EINTR) { LOG_INFO("Error SSL Reading: EINTR"); @@ -131,8 +139,7 @@ Transition SslSocketIoWrapper::FillReadBuffer() { } // Intentional fallthrough default: - throw NetworkProcessException( - "SSL read error: " + std::to_string(ret)); + throw NetworkProcessException("SSL read error: " + std::to_string(ret)); } } return result; @@ -142,9 +149,12 @@ Transition SslSocketIoWrapper::FlushWriteBuffer() { while (wbuf_->Full()) { auto ret = wbuf_->WriteOutTo(conn_ssl_context_); switch (ret) { - case SSL_ERROR_NONE: break; - case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; - case SSL_ERROR_WANT_READ: return Transition::NEED_READ; + case SSL_ERROR_NONE: + break; + case SSL_ERROR_WANT_WRITE: + return Transition::NEED_WRITE; + case SSL_ERROR_WANT_READ: + return Transition::NEED_READ; case SSL_ERROR_SYSCALL: // If interrupted, try again. 
if (errno == EINTR) { @@ -152,9 +162,8 @@ Transition SslSocketIoWrapper::FlushWriteBuffer() { break; } // Intentional Fallthrough - default:LOG_ERROR("SSL write error: %d, error code: %lu", - ret, - ERR_get_error()); + default: + LOG_ERROR("SSL write error: %d, error code: %lu", ret, ERR_get_error()); throw NetworkProcessException("SSL write error"); } } @@ -172,7 +181,8 @@ Transition SslSocketIoWrapper::Close() { case SSL_ERROR_WANT_READ: // More work to do before shutdown return Transition::NEED_READ; - default:LOG_ERROR("Error shutting down ssl session, err: %d", err); + default: + LOG_ERROR("Error shutting down ssl session, err: %d", err); } } // SSL context is explicitly deallocated here because socket wrapper @@ -185,5 +195,5 @@ Transition SslSocketIoWrapper::Close() { return Transition::NONE; } -} // namespace network -} // namespace peloton +} // namespace network +} // namespace peloton diff --git a/src/network/peloton_server.cpp b/src/network/peloton_server.cpp index 6e72caefbbc..b667a42e932 100644 --- a/src/network/peloton_server.cpp +++ b/src/network/peloton_server.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// - #include #include #include "common/utility.h" @@ -80,16 +79,16 @@ void PelotonServer::SSLLockingFunction(int mode, int n, } unsigned long PelotonServer::SSLIdFunction(void) { - return ((unsigned long) THREAD_ID); + return ((unsigned long)THREAD_ID); } void PelotonServer::LoadSSLFileSettings() { private_key_file_ = DATA_DIR + settings::SettingsManager::GetString( - settings::SettingId::private_key_file); + settings::SettingId::private_key_file); certificate_file_ = DATA_DIR + settings::SettingsManager::GetString( - settings::SettingId::certificate_file); + settings::SettingId::certificate_file); root_cert_file_ = DATA_DIR + settings::SettingsManager::GetString( - settings::SettingId::root_cert_file); + settings::SettingId::root_cert_file); } void PelotonServer::SSLInit() { @@ -111,7 
+110,8 @@ void PelotonServer::SSLInit() { // TODO(Yuchen): deal with returned error 0? SSLMutexSetup(); // set general-purpose version, actual protocol will be negotiated to the - // highest version mutually support between client and server during handshake + // highest version mutually support between client and server during + // handshake ssl_context = SSL_CTX_new(SSLv23_method()); if (ssl_context == nullptr) { SetSSLLevel(SSLLevel::SSL_DISABLE); @@ -162,10 +162,10 @@ void PelotonServer::SSLInit() { // automatically. set routine to filter the return status of the default // verification and returns new verification status. SSL_VERIFY_PEER: send // certificate request to client. Client may ignore the request. If the - // client sends back the certificate, it will be verified. Handshake will be - // terminated if the verification fails. SSL_VERIFY_FAIL_IF_NO_PEER_CERT: use - // with SSL_VERIFY_PEER, if client does not send back the certificate, - // terminate the handshake. + // client sends back the certificate, it will be verified. Handshake will + // be terminated if the verification fails. + // SSL_VERIFY_FAIL_IF_NO_PEER_CERT: use with SSL_VERIFY_PEER, if client does + // not send back the certificate, terminate the handshake. SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, VerifyCallback); SSL_CTX_set_verify_depth(ssl_context, 4); } else { @@ -223,7 +223,7 @@ int PelotonServer::VerifyCallback(int ok, X509_STORE_CTX *store) { return ok; } -template +template void PelotonServer::TrySslOperation(int (*func)(Ts...), Ts... arg) { if (func(arg...) 
< 0) { auto error_message = peloton_error_message(); @@ -238,7 +238,7 @@ PelotonServer &PelotonServer::SetupServer() { // This line is critical to performance for some reason evthread_use_pthreads(); if (settings::SettingsManager::GetString( - settings::SettingId::socket_family) != "AF_INET") + settings::SettingId::socket_family) != "AF_INET") throw ConnectionException("Unsupported socket family"); struct sockaddr_in sin; @@ -258,13 +258,13 @@ PelotonServer &PelotonServer::SetupServer() { setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); TrySslOperation( - bind, listen_fd_, (struct sockaddr *) &sin, sizeof(sin)); + bind, listen_fd_, (struct sockaddr *)&sin, sizeof(sin)); TrySslOperation(listen, listen_fd_, conn_backlog); dispatcher_task_ = std::make_shared( CONNECTION_THREAD_COUNT, listen_fd_); - LOG_INFO("Listening on port %llu", (unsigned long long) port_); + LOG_INFO("Listening on port %llu", (unsigned long long)port_); return *this; } diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index 629dfe1887e..6f03a617667 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include "common/cache.h" #include "common/internal_types.h" @@ -21,8 +20,8 @@ #include "common/portal.h" #include "expression/expression_util.h" #include "network/marshal.h" -#include "network/postgres_protocol_handler.h" #include "network/peloton_server.h" +#include "network/postgres_protocol_handler.h" #include "parser/postgresparser.h" #include "parser/statements.h" #include "planner/plan_util.h" @@ -41,7 +40,7 @@ namespace network { // TODO: Remove hardcoded auth strings // Hardcoded authentication strings used during session startup. 
To be removed const std::unordered_map -// clang-format off + // clang-format off PostgresProtocolHandler::parameter_status_map_ = boost::assign::map_list_of("application_name", "psql") ("client_encoding", "UTF8") @@ -88,7 +87,8 @@ bool PostgresProtocolHandler::HardcodedExecuteFilter(QueryType query_type) { switch (query_type) { // Skip SET case QueryType::QUERY_SET: - case QueryType::QUERY_SHOW:return false; + case QueryType::QUERY_SHOW: + return false; // Skip duplicate BEGIN case QueryType::QUERY_BEGIN: if (txn_state_ == NetworkTransactionStateType::BLOCK) { @@ -101,7 +101,8 @@ bool PostgresProtocolHandler::HardcodedExecuteFilter(QueryType query_type) { if (txn_state_ == NetworkTransactionStateType::IDLE) { return false; } - default:break; + default: + break; } return true; } @@ -187,7 +188,7 @@ ProcessResult PostgresProtocolHandler::ExecQueryMessage( if (cached_statement.get() != nullptr) { traffic_cop_->SetStatement(cached_statement); } - // Did not find statement with same name + // Did not find statement with same name else { std::string error_message = "The prepared statement does not exist"; SendErrorResponse( @@ -257,8 +258,8 @@ ResultType PostgresProtocolHandler::ExecQueryExplain( std::unique_ptr unnamed_sql_stmt_list( new parser::SQLStatementList()); unnamed_sql_stmt_list->PassInStatement(std::move(explain_stmt.real_sql_stmt)); - auto stmt = traffic_cop_->PrepareStatement( - "explain", query, std::move(unnamed_sql_stmt_list)); + auto stmt = traffic_cop_->PrepareStatement("explain", query, + std::move(unnamed_sql_stmt_list)); ResultType status = ResultType::UNKNOWN; if (stmt != nullptr) { traffic_cop_->SetStatement(stmt); @@ -339,7 +340,7 @@ void PostgresProtocolHandler::ExecParseMessage(InputPacket *pkt) { // For empty query, we still want to get it constructed // TODO (Tianyi) Consider handle more statement bool empty = (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0); + sql_stmt_list->GetNumStatements() == 0); if (!empty) { 
parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); @@ -383,7 +384,7 @@ void PostgresProtocolHandler::ExecParseMessage(InputPacket *pkt) { // Stat if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { + settings::SettingId::stats_mode)) != StatsType::INVALID) { // Make a copy of param types for stat collection stats::QueryMetric::QueryParamBuf query_type_buf; query_type_buf.len = type_buf_len; @@ -445,8 +446,8 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { if (statement.get() == nullptr) { std::string error_message = statement_name.empty() - ? "Invalid unnamed statement" - : "The prepared statement does not exist"; + ? "Invalid unnamed statement" + : "The prepared statement does not exist"; LOG_ERROR("%s", error_message.c_str()); SendErrorResponse( {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); @@ -529,7 +530,7 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { std::shared_ptr param_stat(nullptr); if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID && + settings::SettingId::stats_mode)) != StatsType::INVALID && num_params > 0) { // Make a copy of format for stat collection stats::QueryMetric::QueryParamBuf param_format_buf; @@ -558,7 +559,7 @@ void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { if (itr != portals_.end()) { itr->second = portal_reference; } - // Create a new entry in portal map + // Create a new entry in portal map else { portals_.insert(std::make_pair(portal_name, portal_reference)); } @@ -618,9 +619,9 @@ size_t PostgresProtocolHandler::ReadParamValue( std::string param_str = std::string(std::begin(param), std::end(param)); bind_parameters[param_idx] = std::make_pair(type::TypeId::VARCHAR, param_str); - if ((unsigned int) param_idx >= param_types.size() || + if ((unsigned int)param_idx 
>= param_types.size() || PostgresValueTypeToPelotonValueType( - (PostgresValueType) param_types[param_idx]) == + (PostgresValueType)param_types[param_idx]) == type::TypeId::VARCHAR) { param_values[param_idx] = type::ValueFactory::GetVarcharValue(param_str); @@ -628,10 +629,10 @@ size_t PostgresProtocolHandler::ReadParamValue( param_values[param_idx] = (type::ValueFactory::GetVarcharValue(param_str)) .CastAs(PostgresValueTypeToPelotonValueType( - (PostgresValueType) param_types[param_idx])); + (PostgresValueType)param_types[param_idx])); } - PELOTON_ASSERT( - param_values[param_idx].GetTypeId() != type::TypeId::INVALID); + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); } else { // BINARY mode PostgresValueType pg_value_type = @@ -711,8 +712,8 @@ size_t PostgresProtocolHandler::ReadParamValue( break; } } - PELOTON_ASSERT( - param_values[param_idx].GetTypeId() != type::TypeId::INVALID); + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); } } } @@ -820,8 +821,9 @@ ProcessResult PostgresProtocolHandler::ExecExecuteMessage( void PostgresProtocolHandler::ExecExecuteMessageGetResult(ResultType status) { const auto &query_type = traffic_cop_->GetStatement()->GetQueryType(); switch (status) { - case ResultType::FAILURE:LOG_ERROR("Failed to execute: %s", - traffic_cop_->GetErrorMessage().c_str()); + case ResultType::FAILURE: + LOG_ERROR("Failed to execute: %s", + traffic_cop_->GetErrorMessage().c_str()); SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, traffic_cop_->GetErrorMessage()}}); return; @@ -861,10 +863,12 @@ void PostgresProtocolHandler::GetResult() { traffic_cop_->ExecuteStatementPlanGetResult(); auto status = traffic_cop_->ExecuteStatementGetResult(); switch (protocol_type_) { - case NetworkProtocolType::POSTGRES_JDBC:LOG_TRACE("JDBC result"); + case NetworkProtocolType::POSTGRES_JDBC: + LOG_TRACE("JDBC result"); ExecExecuteMessageGetResult(status); break; - case 
NetworkProtocolType::POSTGRES_PSQL:LOG_TRACE("PSQL result"); + case NetworkProtocolType::POSTGRES_PSQL: + LOG_TRACE("PSQL result"); ExecQueryMessageGetResult(status); } } @@ -925,8 +929,7 @@ bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, size_t header_size = startup ? sizeof(int32_t) : sizeof(int32_t) + 1; // check if header bytes are available if (!rbuf.HasMore(header_size)) return false; - if (!startup) - rpkt.msg_type = rbuf.ReadValue(); + if (!startup) rpkt.msg_type = rbuf.ReadValue(); // get packet size from the header // extract packet contents size @@ -1042,13 +1045,13 @@ ProcessResult PostgresProtocolHandler::ProcessStartupPacket( } ProcessResult PostgresProtocolHandler::Process(ReadBuffer &rbuf, - const size_t thread_id) { + const size_t thread_id) { if (!ParseInputPacket(rbuf, request_, init_stage_)) return ProcessResult::MORE_DATA_REQUIRED; - ProcessResult process_status = init_stage_ - ? ProcessInitialPacket(&request_) - : ProcessNormalPacket(&request_, thread_id); + ProcessResult process_status = + init_stage_ ? 
ProcessInitialPacket(&request_) + : ProcessNormalPacket(&request_, thread_id); request_.Reset(); @@ -1070,13 +1073,11 @@ ProcessResult PostgresProtocolHandler::ProcessNormalPacket( case NetworkMessageType::PARSE_COMMAND: { LOG_TRACE("PARSE_COMMAND"); ExecParseMessage(pkt); - } - break; + } break; case NetworkMessageType::BIND_COMMAND: { LOG_TRACE("BIND_COMMAND"); ExecBindMessage(pkt); - } - break; + } break; case NetworkMessageType::DESCRIBE_COMMAND: { LOG_TRACE("DESCRIBE_COMMAND"); return ExecDescribeMessage(pkt); @@ -1089,13 +1090,11 @@ ProcessResult PostgresProtocolHandler::ProcessNormalPacket( LOG_TRACE("SYNC_COMMAND"); SendReadyForQuery(txn_state_); SetFlushFlag(true); - } - break; + } break; case NetworkMessageType::CLOSE_COMMAND: { LOG_TRACE("CLOSE_COMMAND"); ExecCloseMessage(pkt); - } - break; + } break; case NetworkMessageType::TERMINATE_COMMAND: { LOG_TRACE("TERMINATE_COMMAND"); SetFlushFlag(true); @@ -1185,7 +1184,8 @@ void PostgresProtocolHandler::CompleteCommand(const QueryType &query_type, std::string tag = QueryTypeToString(query_type); switch (query_type) { /* After Begin, we enter a txn block */ - case QueryType::QUERY_BEGIN:txn_state_ = NetworkTransactionStateType::BLOCK; + case QueryType::QUERY_BEGIN: + txn_state_ = NetworkTransactionStateType::BLOCK; break; /* After commit, we end the txn block */ case QueryType::QUERY_COMMIT: @@ -1193,14 +1193,17 @@ void PostgresProtocolHandler::CompleteCommand(const QueryType &query_type, case QueryType::QUERY_ROLLBACK: txn_state_ = NetworkTransactionStateType::IDLE; break; - case QueryType::QUERY_INSERT:tag += " 0 " + std::to_string(rows); + case QueryType::QUERY_INSERT: + tag += " 0 " + std::to_string(rows); break; case QueryType::QUERY_CREATE_TABLE: case QueryType::QUERY_CREATE_DB: case QueryType::QUERY_CREATE_INDEX: case QueryType::QUERY_CREATE_TRIGGER: - case QueryType::QUERY_PREPARE:break; - default:tag += " " + std::to_string(rows); + case QueryType::QUERY_PREPARE: + break; + default: + tag += " " + 
std::to_string(rows); } PacketPutStringWithTerminator(pkt.get(), tag); responses_.push_back(std::move(pkt)); diff --git a/src/network/protocol_handler_factory.cpp b/src/network/protocol_handler_factory.cpp index 9e05939e8b5..9df0d5fad86 100644 --- a/src/network/protocol_handler_factory.cpp +++ b/src/network/protocol_handler_factory.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/protocol_handler_factory.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -26,5 +26,5 @@ std::unique_ptr ProtocolHandlerFactory::CreateProtocolHandler( return nullptr; } } -} -} +} // namespace network +} // namespace peloton diff --git a/src/network/service/connection_manager.cpp b/src/network/service/connection_manager.cpp index b374f9ce141..70693b5ce9a 100644 --- a/src/network/service/connection_manager.cpp +++ b/src/network/service/connection_manager.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/service/connection_manager.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// diff --git a/src/network/service/peloton_service.cpp b/src/network/service/peloton_service.cpp index 9e5095a0916..9367b898ff6 100644 --- a/src/network/service/peloton_service.cpp +++ b/src/network/service/peloton_service.cpp @@ -6,28 +6,28 @@ // // Identification: src/network/service/peloton_service.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// #include "network/service/peloton_service.h" -#include "network/service/peloton_endpoint.h" -#include 
"network/service/rpc_server.h" +#include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" #include "executor/plan_executor.h" +#include "network/service/peloton_endpoint.h" +#include "network/service/rpc_server.h" #include "planner/seq_scan_plan.h" #include "storage/tile.h" #include "storage/tuple.h" #include "type/serializeio.h" #include "type/serializer.h" -#include "common/internal_types.h" #include #include #include -#include #include +#include namespace peloton { namespace network { diff --git a/src/network/service/rpc_channel.cpp b/src/network/service/rpc_channel.cpp index f80d6223a3b..e3a8c6e0f68 100644 --- a/src/network/service/rpc_channel.cpp +++ b/src/network/service/rpc_channel.cpp @@ -6,23 +6,23 @@ // // Identification: src/network/service/rpc_channel.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// -#include "network/service/rpc_type.h" -#include "network/service/rpc_client.h" #include "network/service/rpc_channel.h" -#include "network/service/rpc_controller.h" -#include "network/service/tcp_connection.h" -#include "network/service/connection_manager.h" #include "common/logger.h" #include "common/macros.h" +#include "network/service/connection_manager.h" +#include "network/service/rpc_client.h" +#include "network/service/rpc_controller.h" +#include "network/service/rpc_type.h" +#include "network/service/tcp_connection.h" #include -#include #include +#include namespace peloton { namespace network { diff --git a/src/network/service/rpc_client.cpp b/src/network/service/rpc_client.cpp index 5bf066d5d45..3ca1d396b63 100644 --- a/src/network/service/rpc_client.cpp +++ b/src/network/service/rpc_client.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/service/rpc_client.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database 
Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// diff --git a/src/network/service/rpc_server.cpp b/src/network/service/rpc_server.cpp index f799c622ef8..540ef4ea23b 100644 --- a/src/network/service/rpc_server.cpp +++ b/src/network/service/rpc_server.cpp @@ -6,19 +6,19 @@ // // Identification: src/network/service/rpc_server.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// #include "network/service/rpc_server.h" -#include "network/service/rpc_controller.h" -#include "network/service/connection_manager.h" -#include "common/logger.h" #include #include +#include "common/logger.h" +#include "network/service/connection_manager.h" +#include "network/service/rpc_controller.h" -#include #include +#include namespace peloton { namespace network { diff --git a/src/network/service/rpc_utils.cpp b/src/network/service/rpc_utils.cpp index e585f76c1eb..80d4d9e7dc7 100644 --- a/src/network/service/rpc_utils.cpp +++ b/src/network/service/rpc_utils.cpp @@ -6,7 +6,7 @@ // // Identification: src/network/service/rpc_utils.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -22,5 +22,5 @@ namespace service { //===----------------------------------------------------------------------===// } // namespace service -} // namespace message +} // namespace network } // namespace peloton diff --git a/src/network/service/tcp_connection.cpp b/src/network/service/tcp_connection.cpp index e00138a3055..e42ead200e8 100644 --- a/src/network/service/tcp_connection.cpp +++ b/src/network/service/tcp_connection.cpp @@ -6,21 +6,21 @@ // 
// Identification: src/network/service/tcp_connection.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// -#include #include +#include #include #include -#include "network/service/tcp_connection.h" +#include "common/macros.h" #include "network/service/connection_manager.h" #include "network/service/peloton_service.h" #include "network/service/rpc_type.h" -#include "common/macros.h" +#include "network/service/tcp_connection.h" namespace peloton { namespace network { @@ -152,7 +152,8 @@ void *Connection::ProcessMessage(void *connection) { // Get the hashcode of the rpc method uint64_t opcode = 0; - PELOTON_MEMCPY((char *)(&opcode), buf + HEADERLEN + TYPELEN, sizeof(opcode)); + PELOTON_MEMCPY((char *)(&opcode), buf + HEADERLEN + TYPELEN, + sizeof(opcode)); // Get the rpc method meta info: method descriptor RpcMethod *rpc_method = conn->GetRpcServer()->FindMethod(opcode); diff --git a/src/network/service/tcp_listener.cpp b/src/network/service/tcp_listener.cpp index 319266d1030..433efca4b42 100644 --- a/src/network/service/tcp_listener.cpp +++ b/src/network/service/tcp_listener.cpp @@ -6,14 +6,14 @@ // // Identification: src/network/service/tcp_listener.cpp // -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// #include "network/service/tcp_listener.h" -#include "network/service/tcp_connection.h" -#include "network/service/rpc_type.h" #include "network/service/connection_manager.h" +#include "network/service/rpc_type.h" +#include "network/service/tcp_connection.h" #include "common/logger.h" #include "common/macros.h" @@ -101,7 +101,8 @@ void Listener::Run(void *arg) { void Listener::AcceptConnCb(struct evconnlistener 
*listener, evutil_socket_t fd, struct sockaddr *address, UNUSED_ATTRIBUTE int socklen, void *ctx) { - PELOTON_ASSERT(listener != NULL && address != NULL && socklen >= 0 && ctx != NULL); + PELOTON_ASSERT(listener != NULL && address != NULL && socklen >= 0 && + ctx != NULL); LOG_TRACE("Server: connection received"); From 334f30303405cd448fb4d1313b7e5b4004a51555 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Thu, 14 Jun 2018 16:25:30 -0400 Subject: [PATCH 04/48] Fix ssl and write bugs --- script/installation/packages.sh | 5 ++- src/include/network/marshal.h | 14 ++++---- src/include/network/network_io_wrappers.h | 27 ++++++++++----- src/network/connection_handle.cpp | 16 ++++----- src/network/network_io_wrapper_factory.cpp | 39 ++++++++++------------ src/network/network_io_wrappers.cpp | 34 +++++++++---------- test/codegen/csv_scan_translator_test.cpp | 2 +- test/network/exception_test.cpp | 2 -- test/network/ssl_test.cpp | 1 - 9 files changed, 71 insertions(+), 69 deletions(-) diff --git a/script/installation/packages.sh b/script/installation/packages.sh index 080f000944a..dec631fab98 100755 --- a/script/installation/packages.sh +++ b/script/installation/packages.sh @@ -47,6 +47,9 @@ TF_TYPE="cpu" function install_protobuf3.4.0() { # Install Relevant tooling # Remove any old versions of protobuf + # Note: Protobuf 3.5+ PPA available Ubuntu Bionic(18.04) onwards - Should be used + # when we retire 16.04 too: https://launchpad.net/~maarten-fonville/+archive/ubuntu/protobuf + # This PPA unfortunately doesnt have Protobuf 3.5 for 16.04, but does for 14.04/18.04+ DISTRIB=$1 # ubuntu/fedora if [ "$DISTRIB" == "ubuntu" ]; then sudo apt-get --yes --force-yes remove --purge libprotobuf-dev protobuf-compiler @@ -63,7 +66,7 @@ function install_protobuf3.4.0() { wget -O protobuf-cpp-3.4.0.tar.gz https://github.com/google/protobuf/releases/download/v3.4.0/protobuf-cpp-3.4.0.tar.gz tar -xzf protobuf-cpp-3.4.0.tar.gz cd protobuf-3.4.0 - ./autogen.sh && ./configure && make -j4 && 
sudo make install && sudo ldconfig + ./autogen.sh && ./configure && make -j4 && sudo make install && sudo ldconfig || exit 1 cd .. # Cleanup rm -rf protobuf-3.4.0 protobuf-cpp-3.4.0.tar.gz diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index c40031b1bac..56d29e57bbb 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -51,7 +51,7 @@ struct Buffer { /** * @param bytes The amount of bytes to check between the cursor and the end - * of the buffer (defaults to 1) + * of the buffer (defaults to any) * @return Whether there is any more bytes between the cursor and * the end of the buffer */ @@ -71,7 +71,7 @@ struct Buffer { /** * @return Capacity of the buffer (not actual size) */ - inline constexpr size_t Capacity() const { return SOCKET_BUFFER_SIZE; } + inline size_t Capacity() const { return SOCKET_BUFFER_SIZE; } /** * Shift contents to align the current cursor with start of the buffer, @@ -139,7 +139,8 @@ class ReadBuffer : public Buffer { /** * Read a value of type T off of the buffer, advancing cursor by appropriate - * amount + * amount. Does NOT convert from network bytes order. It is the caller's + * responsibility to do so. * @tparam T type of value to read off. Preferably a primitive type * @return the value of type T */ @@ -185,7 +186,7 @@ class WriteBuffer : public Buffer { * maximum capacity minus the capacity already in use. * @return Remaining capacity */ - inline size_t RemainingCapacity() { return Capacity() - size_ + 1; } + inline size_t RemainingCapacity() { return Capacity() - size_; } /** * @param bytes Desired number of bytes to write @@ -194,7 +195,7 @@ class WriteBuffer : public Buffer { inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } /** - * Append the desired range into current buffer + * Append the desired range into current buffer. * @tparam InputIt iterator type. 
* @param first beginning of range * @param len length of range @@ -206,7 +207,8 @@ class WriteBuffer : public Buffer { } /** - * Append the given value into the current buffer + * Append the given value into the current buffer. Does NOT convert to + * network byte order. It is up to the caller to do so. * @tparam T input type * @param val value to write into buffer */ diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 66c453626d8..9502cd94e2d 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -38,6 +38,7 @@ class NetworkIoWrapper { friend class NetworkIoWrapperFactory; public: + virtual bool SslAble() const = 0; // TODO(Tianyu): Change and document after we refactor protocol handler virtual Transition FillReadBuffer() = 0; virtual Transition FlushWriteBuffer() = 0; @@ -51,14 +52,15 @@ class NetworkIoWrapper { std::shared_ptr &wbuf) : sock_fd_(sock_fd), rbuf_(std::move(rbuf)), - wbuf_(std::move(wbuf)), - conn_ssl_context_(nullptr) {} - // It is worth noting that because of the way we are reinterpret-casting - // between derived types, it is necessary that they share the same members. 
+ wbuf_(std::move(wbuf)) {} + + DISALLOW_COPY(NetworkIoWrapper) + + NetworkIoWrapper(NetworkIoWrapper &&other) = default; + int sock_fd_; std::shared_ptr rbuf_; std::shared_ptr wbuf_; - SSL *conn_ssl_context_; }; /** @@ -69,6 +71,8 @@ class PosixSocketIoWrapper : public NetworkIoWrapper { PosixSocketIoWrapper(int sock_fd, std::shared_ptr rbuf, std::shared_ptr wbuf); + + inline bool SslAble() const override { return false; } Transition FillReadBuffer() override; Transition FlushWriteBuffer() override; inline Transition Close() override { @@ -82,14 +86,19 @@ class PosixSocketIoWrapper : public NetworkIoWrapper { */ class SslSocketIoWrapper : public NetworkIoWrapper { public: - // An SslSocketIoWrapper is always derived from a PosixSocketIoWrapper, - // as the handshake process happens over posix sockets. Use the method - // in NetworkIoWrapperFactory to get an SslSocketWrapper. - SslSocketIoWrapper() = delete; + // Realistically, an SslSocketIoWrapper is always derived from a + // PosixSocketIoWrapper, as the handshake process happens over posix sockets. 
+ SslSocketIoWrapper(NetworkIoWrapper &&other, SSL *ssl) + : NetworkIoWrapper(std::move(other)), conn_ssl_context_(ssl) {} + inline bool SslAble() const override { return true; } Transition FillReadBuffer() override; Transition FlushWriteBuffer() override; Transition Close() override; + + private: + friend class NetworkIoWrapperFactory; + SSL *conn_ssl_context_; }; } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index d14837770b6..3826e59640a 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -129,8 +129,8 @@ DEF_TRANSITION_GRAPH ON(WAKEUP) SET_STATE_TO(PROCESS) AND_INVOKE(GetResult) ON(PROCEED) SET_STATE_TO(WRITE) AND_INVOKE(TryWrite) ON(NEED_READ) SET_STATE_TO(READ) AND_INVOKE(TryRead) - // Client connections are ignored while we wait on peloton - // to execute the query + // Client connections are ignored while we wait on peloton + // to execute the query ON(NEED_RESULT) SET_STATE_TO(PROCESS) AND_WAIT_ON_PELOTON ON(NEED_SSL_HANDSHAKE) SET_STATE_TO(SSL_INIT) AND_INVOKE(TrySslHandshake) END_STATE_DEF @@ -145,8 +145,8 @@ DEF_TRANSITION_GRAPH END_DEF // clang-format on - void ConnectionHandle::StateMachine::Accept(Transition action, - ConnectionHandle &connection) { +void ConnectionHandle::StateMachine::Accept(Transition action, + ConnectionHandle &connection) { Transition next = action; while (next != Transition::NONE) { transition_result result = Delta_(current_state_, next); @@ -162,12 +162,8 @@ END_DEF } ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) - : conn_handler_(handler) { - // We will always handle connections using posix until (potentially) first SSL - // handshake. 
- io_wrapper_ = - NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd); -} + : conn_handler_(handler), + io_wrapper_(NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd)) {} Transition ConnectionHandle::TryWrite() { for (; next_response_ < protocol_handler_->responses_.size(); diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index 383dc48dae9..e2542f3909e 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include #include "network/network_io_wrapper_factory.h" namespace peloton { @@ -30,40 +31,35 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( // Construct new wrapper by reusing buffers from the old one. // The old one will be deallocated as we replace the last reference to it - // in the reusable_wrappers_ map - auto reused_wrapper = it->second; - reused_wrapper->rbuf_->Reset(); - reused_wrapper->wbuf_->Reset(); - reused_wrapper->sock_fd_ = conn_fd; - reused_wrapper->conn_ssl_context_ = nullptr; - // It is not necessary to have an explicit cast here because the reused - // wrapper always use Posix methods, as we never update their type in the - // reusable wrappers map. + // in the reusable_wrappers_ map. We still need to explicitly call the + // constructor so the flags are set properly on the new file descriptor. 
+ auto &reused_wrapper = it->second; + reused_wrapper = std::make_shared(conn_fd, + reused_wrapper->rbuf_, + reused_wrapper->wbuf_); return reused_wrapper; } Transition NetworkIoWrapperFactory::PerformSslHandshake( std::shared_ptr &io_wrapper) { - if (io_wrapper->conn_ssl_context_ == nullptr) { - // Initial handshake, the incoming type is a posix socket wrapper - auto *context = io_wrapper->conn_ssl_context_ = - SSL_new(PelotonServer::ssl_context); - // TODO(Tianyu): Is it the right thing here to throw exceptions? + SSL *context; + if (!io_wrapper->SslAble()) { + context = SSL_new(PelotonServer::ssl_context); if (context == nullptr) throw NetworkProcessException("ssl context for conn failed"); SSL_set_session_id_context(context, nullptr, 0); if (SSL_set_fd(context, io_wrapper->sock_fd_) == 0) throw NetworkProcessException("Failed to set ssl fd"); - - // ssl handshake is done, need to use new methods for the original wrappers; - // We do not update the type in the reusable wrappers map because it is not - // relevant. - io_wrapper.reset(reinterpret_cast(io_wrapper.get())); + io_wrapper = + std::make_shared(std::move(*io_wrapper), context); + } else { + auto ptr = std::dynamic_pointer_cast( + io_wrapper); + context = ptr->conn_ssl_context_; } // The wrapper already uses SSL methods. // Yuchen: "Post-connection verification?" 
- auto *context = io_wrapper->conn_ssl_context_; ERR_clear_error(); int ssl_accept_ret = SSL_accept(context); if (ssl_accept_ret > 0) return Transition::PROCEED; @@ -74,8 +70,7 @@ Transition NetworkIoWrapperFactory::PerformSslHandshake( return Transition::NEED_READ; case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; - default: - LOG_ERROR("SSL Error, error code %d", err); + default:LOG_ERROR("SSL Error, error code %d", err); return Transition::TERMINATE; } } diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index b90fd1a05dd..317a511e572 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -21,25 +21,25 @@ namespace peloton { namespace network { Transition NetworkIoWrapper::WritePacket(OutputPacket *pkt) { // Write Packet Header - if (pkt->skip_header_write) return Transition::PROCEED; + if (!pkt->skip_header_write) { + if (!wbuf_->HasSpaceFor(1 + sizeof(int32_t))) { + auto result = FlushWriteBuffer(); + if (FlushWriteBuffer() != Transition::PROCEED) + // Unable to flush buffer, socket presumably not ready for write + return result; + } - if (!wbuf_->HasSpaceFor(1 + sizeof(int32_t))) { - auto result = FlushWriteBuffer(); - if (FlushWriteBuffer() != Transition::PROCEED) - // Unable to flush buffer, socket presumably not ready for write - return result; + wbuf_->Append(static_cast(pkt->msg_type)); + if (!pkt->single_type_pkt) + // Need to convert bytes to network order + wbuf_->Append(htonl(pkt->len + sizeof(int32_t))); + pkt->skip_header_write = true; } - wbuf_->Append(static_cast(pkt->msg_type)); - if (!pkt->single_type_pkt) - // Need to convert bytes to network order - wbuf_->Append(htonl(pkt->len + sizeof(int32_t))); - pkt->skip_header_write = true; - // Write Packet Content for (size_t len = pkt->len; len != 0;) { - if (wbuf_->HasSpaceFor(pkt->len)) { - wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, pkt->len); + if (wbuf_->HasSpaceFor(len)) { + 
wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, len); break; } else { auto write_size = wbuf_->RemainingCapacity(); @@ -129,7 +129,7 @@ Transition SslSocketIoWrapper::FillReadBuffer() { // The SSL packet is partially loaded to the SSL buffer only, // More data is required in order to decode the wh`ole packet. case SSL_ERROR_WANT_READ: - return Transition::NEED_READ; + return result; case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; case SSL_ERROR_SYSCALL: @@ -146,7 +146,7 @@ Transition SslSocketIoWrapper::FillReadBuffer() { } Transition SslSocketIoWrapper::FlushWriteBuffer() { - while (wbuf_->Full()) { + while (wbuf_->HasMore()) { auto ret = wbuf_->WriteOutTo(conn_ssl_context_); switch (ret) { case SSL_ERROR_NONE: @@ -167,7 +167,7 @@ Transition SslSocketIoWrapper::FlushWriteBuffer() { throw NetworkProcessException("SSL write error"); } } - + wbuf_->Reset(); return Transition::PROCEED; } diff --git a/test/codegen/csv_scan_translator_test.cpp b/test/codegen/csv_scan_translator_test.cpp index 320db518117..87d029d0efa 100644 --- a/test/codegen/csv_scan_translator_test.cpp +++ b/test/codegen/csv_scan_translator_test.cpp @@ -36,7 +36,7 @@ class CSVScanTranslatorTest : public PelotonCodeGenTest { TEST_F(CSVScanTranslatorTest, IntCsvScan) { // The quoting character and a helper function to quote a given string const char quote = '"'; - const auto quote_string = [quote](std::string s) { + const auto quote_string = [=](std::string s) { return StringUtil::Format("%c%s%c", quote, s.c_str(), quote); }; diff --git a/test/network/exception_test.cpp b/test/network/exception_test.cpp index 1175de77da0..08ecc98c9a5 100644 --- a/test/network/exception_test.cpp +++ b/test/network/exception_test.cpp @@ -22,8 +22,6 @@ #include "network/protocol_handler_factory.h" #include "util/string_util.h" -#define NUM_THREADS 1 - namespace peloton { namespace test { diff --git a/test/network/ssl_test.cpp b/test/network/ssl_test.cpp index 00ed3582109..b9399ce7757 100644 --- 
a/test/network/ssl_test.cpp +++ b/test/network/ssl_test.cpp @@ -58,7 +58,6 @@ void *TestRoutine(int port) { pqxx::work txn1(C); - // basic test // create table and insert some data txn1.exec("DROP TABLE IF EXISTS employee;"); From 94e1cb54f1d5fdf114599e6ca409eff8af004bec Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Mon, 18 Jun 2018 17:18:17 -0400 Subject: [PATCH 05/48] Shorten travis build time --- .travis.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1f191a1f43b..521ede818df 100644 --- a/.travis.yml +++ b/.travis.yml @@ -75,16 +75,16 @@ script: # build - make -j4 # run tests - - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then make check -j4; fi - - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 make check -j4; fi +# - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then make check -j4; fi +# - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 make check -j4; fi # install peloton - make install # run psql tests - - bash ../script/testing/psql/psql_test.sh +# - bash ../script/testing/psql/psql_test.sh # run jdbc tests - - python ../script/validators/jdbc_validator.py +# - python ../script/validators/jdbc_validator.py # run junit tests - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then python ../script/testing/junit/run_junit.py; fi - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 python ../script/testing/junit/run_junit.py; fi # upload coverage info - - if [[ $COVERALLS == 'On' ]]; then make coveralls; fi +# - if [[ $COVERALLS == 'On' ]]; then make coveralls; fi From b10e01ccc5ca68a84dea4ba3451a3a887970aefb Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 19 Jun 2018 11:31:17 -0400 Subject: [PATCH 06/48] Buffer reset and shutdown ordering --- src/include/network/network_io_wrappers.h | 5 ++++- src/network/connection_handle.cpp | 11 ++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git 
a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 9502cd94e2d..b73c1c4bb67 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -52,7 +52,10 @@ class NetworkIoWrapper { std::shared_ptr &wbuf) : sock_fd_(sock_fd), rbuf_(std::move(rbuf)), - wbuf_(std::move(wbuf)) {} + wbuf_(std::move(wbuf)) { + rbuf_->Reset(); + wbuf_->Reset(); + } DISALLOW_COPY(NetworkIoWrapper) diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index 3826e59640a..dd7787c4f41 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -227,9 +227,18 @@ Transition ConnectionHandle::TrySslHandshake() { Transition ConnectionHandle::CloseConnection() { LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); + // TODO(Tianyu): Handle close failure + io_wrapper_->Close(); // Remove listening event + // Only after the connection is closed is it safe to remove events, + // after this point no object in the system has reference to this + // connection handle and we will need to destruct and exit. conn_handler_->UnregisterEvent(network_event_); - io_wrapper_->Close(); + conn_handler_->UnregisterEvent(workpool_event_); + // This object is essentially managed by libevent (which unfortunately does + // not accept shared_ptrs.) and thus as we shut down we need to manually + // deallocate this object. 
+ delete this; return Transition::NONE; } } // namespace network From 79bffbad853fe46be3f6f71d30aa2a042ffc1c48 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 19 Jun 2018 13:21:42 -0400 Subject: [PATCH 07/48] Add in termination state --- src/include/network/connection_handle.h | 2 +- src/include/network/network_io_wrappers.h | 4 ++-- src/include/network/network_state.h | 1 - src/network/connection_handle.cpp | 17 ++++++++++++----- src/network/network_io_wrapper_factory.cpp | 4 ++-- src/network/network_io_wrappers.cpp | 3 ++- 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 959d9af970e..84db833f102 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -98,7 +98,7 @@ class ConnectionHandle { Transition Process(); Transition GetResult(); Transition TrySslHandshake(); - Transition CloseConnection(); + Transition TryCloseConnection(); /** * Updates the event flags of the network event. 
This configures how the diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index b73c1c4bb67..1b100475ffd 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -80,7 +80,7 @@ class PosixSocketIoWrapper : public NetworkIoWrapper { Transition FlushWriteBuffer() override; inline Transition Close() override { peloton_close(sock_fd_); - return Transition::NONE; + return Transition::PROCEED; } }; @@ -104,4 +104,4 @@ class SslSocketIoWrapper : public NetworkIoWrapper { SSL *conn_ssl_context_; }; } // namespace network -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/network/network_state.h b/src/include/network/network_state.h index b580e0ff457..96373dbe919 100644 --- a/src/include/network/network_state.h +++ b/src/include/network/network_state.h @@ -36,7 +36,6 @@ enum class Transition { WAKEUP, PROCEED, NEED_READ, - // TODO(tianyu) generalize this symbol, this is currently only used in process NEED_RESULT, TERMINATE, NEED_SSL_HANDSHAKE, diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index dd7787c4f41..e87eabd74c3 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -103,7 +103,7 @@ namespace { } #define END_STATE_DEF \ - ON(TERMINATE) SET_STATE_TO(CLOSING) AND_INVOKE(CloseConnection) END_DEF + ON(TERMINATE) SET_STATE_TO(CLOSING) AND_INVOKE(TryCloseConnection) END_DEF } // namespace // clang-format off @@ -142,6 +142,12 @@ DEF_TRANSITION_GRAPH ON(NEED_WRITE) SET_STATE_TO(WRITE) AND_WAIT_ON_WRITE ON(PROCEED) SET_STATE_TO(PROCESS) AND_INVOKE(Process) END_STATE_DEF + + DEFINE_STATE(CLOSING) + ON(WAKEUP) SET_STATE_TO(CLOSING) AND_INVOKE(TryCloseConnection) + ON(NEED_READ) SET_STATE_TO(WRITE) AND_WAIT_ON_READ + ON(NEED_WRITE) SET_STATE_TO(WRITE) AND_WAIT_ON_WRITE + END_STATE_DEF END_DEF // clang-format on @@ -155,7 +161,7 @@ void 
ConnectionHandle::StateMachine::Accept(Transition action, next = result.second(connection); } catch (NetworkProcessException &e) { LOG_ERROR("%s\n", e.what()); - connection.CloseConnection(); + connection.TryCloseConnection(); return; } } @@ -199,7 +205,7 @@ Transition ConnectionHandle::Process() { case ProcessResult::PROCESSING: return Transition::NEED_RESULT; case ProcessResult::TERMINATE: - return Transition::TERMINATE; + throw NetworkProcessException("Error when processing"); case ProcessResult::NEED_SSL_HANDSHAKE: return Transition::NEED_SSL_HANDSHAKE; default: @@ -225,10 +231,11 @@ Transition ConnectionHandle::TrySslHandshake() { io_wrapper_); } -Transition ConnectionHandle::CloseConnection() { +Transition ConnectionHandle::TryCloseConnection() { LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); // TODO(Tianyu): Handle close failure - io_wrapper_->Close(); + Transition close = io_wrapper_->Close(); + if (close != Transition::PROCEED) return close; // Remove listening event // Only after the connection is closed is it safe to remove events, // after this point no object in the system has reference to this diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index e2542f3909e..2c675ea0d63 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -70,8 +70,8 @@ Transition NetworkIoWrapperFactory::PerformSslHandshake( return Transition::NEED_READ; case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; - default:LOG_ERROR("SSL Error, error code %d", err); - return Transition::TERMINATE; + default: + throw NetworkProcessException("SSL Error, error code" + std::to_string(err)); } } } // namespace network diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index 317a511e572..80bad466c0c 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -178,6 +178,7 @@ Transition 
SslSocketIoWrapper::Close() { int err = SSL_get_error(conn_ssl_context_, ret); switch (err) { case SSL_ERROR_WANT_WRITE: + return Transition::NEED_WRITE; case SSL_ERROR_WANT_READ: // More work to do before shutdown return Transition::NEED_READ; @@ -192,7 +193,7 @@ Transition SslSocketIoWrapper::Close() { SSL_free(conn_ssl_context_); conn_ssl_context_ = nullptr; peloton_close(sock_fd_); - return Transition::NONE; + return Transition::PROCEED; } } // namespace network From 96dba94f3c6dc3e1723d9027ac124c5d77b5d8b0 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 19 Jun 2018 13:49:38 -0400 Subject: [PATCH 08/48] Revert Travis changes --- .travis.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index 521ede818df..1f191a1f43b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -75,16 +75,16 @@ script: # build - make -j4 # run tests -# - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then make check -j4; fi -# - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 make check -j4; fi + - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then make check -j4; fi + - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 make check -j4; fi # install peloton - make install # run psql tests -# - bash ../script/testing/psql/psql_test.sh + - bash ../script/testing/psql/psql_test.sh # run jdbc tests -# - python ../script/validators/jdbc_validator.py + - python ../script/validators/jdbc_validator.py # run junit tests - if [[ $TRAVIS_OS_NAME != 'osx' ]]; then python ../script/testing/junit/run_junit.py; fi - if [[ $TRAVIS_OS_NAME == 'osx' ]]; then ASAN_OPTIONS=detect_container_overflow=0 python ../script/testing/junit/run_junit.py; fi # upload coverage info -# - if [[ $COVERALLS == 'On' ]]; then make coveralls; fi + - if [[ $COVERALLS == 'On' ]]; then make coveralls; fi From aa9cc366b86e9572f839070184491a3853a8733d Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 20 Jun 2018 13:29:32 -0400 Subject: 
[PATCH 09/48] Start new protocol layer --- .../network/connection_dispatcher_task.h | 2 +- src/include/network/connection_handle.h | 2 +- src/include/network/marshal.h | 37 ++++-- src/include/network/network_types.h | 54 ++++++++ src/include/network/peloton_server.h | 2 +- .../network/postgres_network_commands.h | 121 ++++++++++++++++++ src/include/network/postgres_wire_protocol.h | 75 +++++++++++ src/include/network/wire_protocol.h | 29 +++++ 8 files changed, 309 insertions(+), 13 deletions(-) create mode 100644 src/include/network/network_types.h create mode 100644 src/include/network/postgres_network_commands.h create mode 100644 src/include/network/postgres_wire_protocol.h create mode 100644 src/include/network/wire_protocol.h diff --git a/src/include/network/connection_dispatcher_task.h b/src/include/network/connection_dispatcher_task.h index 0b97147622a..6e89ef3cde6 100644 --- a/src/include/network/connection_dispatcher_task.h +++ b/src/include/network/connection_dispatcher_task.h @@ -15,7 +15,7 @@ #include "common/notifiable_task.h" #include "concurrency/epoch_manager_factory.h" #include "connection_handler_task.h" -#include "network_state.h" +#include "network_types.h" namespace peloton { namespace network { diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 84db833f102..92d85d2d36a 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -33,7 +33,7 @@ #include "marshal.h" #include "network/connection_handler_task.h" #include "network/network_io_wrappers.h" -#include "network_state.h" +#include "network_types.h" #include "protocol_handler.h" #include diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 56d29e57bbb..2efd5370992 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -20,13 +20,14 @@ #include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" -#include 
"network/network_state.h" +#include "common/exception.h" +#include "network/network_types.h" +#include "network/postgres_network_commands.h" #define BUFFER_INIT_SIZE 100 namespace peloton { namespace network { - /** * A plain old buffer with a movable cursor, the meaning of which is dependent * on the use case. @@ -47,6 +48,8 @@ struct Buffer { inline void Reset() { size_ = 0; offset_ = 0; + buf_.resize(SOCKET_BUFFER_SIZE); + buf_.shrink_to_fit(); } /** @@ -71,7 +74,7 @@ struct Buffer { /** * @return Capacity of the buffer (not actual size) */ - inline size_t Capacity() const { return SOCKET_BUFFER_SIZE; } + inline size_t Capacity() const { return capacity_; } /** * Shift contents to align the current cursor with start of the buffer, @@ -84,8 +87,15 @@ struct Buffer { offset_ = 0; } - // TODO(Tianyu): Make these protected once we refactor protocol handler - size_t size_ = 0, offset_ = 0; + inline void ExpandTo(size_t size) { + // We should never need to trim down the size past SOCKET_BUFFER_SIZE + PELOTON_ASSERT(size > SOCKET_BUFFER_SIZE); + capacity_ = size; + buf_.resize(capacity_); + } + + protected: + size_t size_ = 0, offset_ = 0, capacity_ = SOCKET_BUFFER_SIZE; ByteBuf buf_; }; @@ -115,7 +125,7 @@ class ReadBuffer : public Buffer { inline int FillBufferFrom(int fd) { ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); if (bytes_read > 0) size_ += bytes_read; - return (int)bytes_read; + return (int) bytes_read; } /** @@ -125,6 +135,13 @@ class ReadBuffer : public Buffer { */ inline size_t BytesAvailable() { return size_ - offset_; } + // TODO(Tianyu): Document + inline void Read(size_t bytes, ByteBuf::const_iterator &begin, ByteBuf::const_iterator &end) { + begin = buf_.begin() + offset_; + end = begin + bytes; + offset_ += bytes; + } + /** * Read the given number of bytes into destination, advancing cursor by that * number @@ -144,7 +161,7 @@ class ReadBuffer : public Buffer { * @tparam T type of value to read off. 
Preferably a primitive type * @return the value of type T */ - template + template inline T ReadValue() { T result; Read(sizeof(result), &result); @@ -178,7 +195,7 @@ class WriteBuffer : public Buffer { inline int WriteOutTo(int fd) { ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); if (bytes_written > 0) offset_ += bytes_written; - return (int)bytes_written; + return (int) bytes_written; } /** @@ -200,7 +217,7 @@ class WriteBuffer : public Buffer { * @param first beginning of range * @param len length of range */ - template + template inline void Append(InputIt first, size_t len) { std::copy(first, first + len, std::begin(buf_) + size_); size_ += len; @@ -212,7 +229,7 @@ class WriteBuffer : public Buffer { * @tparam T input type * @param val value to write into buffer */ - template + template inline void Append(T val) { Append(reinterpret_cast(&val), sizeof(T)); } diff --git a/src/include/network/network_types.h b/src/include/network/network_types.h new file mode 100644 index 00000000000..372f626a51a --- /dev/null +++ b/src/include/network/network_types.h @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// network_types.h +// +// Identification: src/include/network/network_types.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace peloton { +namespace network { +/** + * States used by ConnectionHandle::StateMachine. 
+ * @see ConnectionHandle::StateMachine + */ +enum class ConnState { + READ, // State that reads data from the network + WRITE, // State the writes data to the network + PROCESS, // State that runs the network protocol on received data + CLOSING, // State for closing the client connection + SSL_INIT, // State to flush out responses and doing (Real) SSL handshake +}; + +/** + * A transition is used to signal the result of an action to + * ConnectionHandle::StateMachine + * @see ConnectionHandle::StateMachine + */ +enum class Transition { + NONE, + WAKEUP, + PROCEED, + NEED_READ, + NEED_RESULT, + TERMINATE, + NEED_SSL_HANDSHAKE, + NEED_WRITE +}; + +enum class ResponseProtocol { + // No response required (for intermediate messgaes such as parse, bind, etc.) + NO, + // PSQL + SIMPLE, + // JDBC, PQXX, etc. + EXTENDED +}; +} // namespace network +} // namespace peloton diff --git a/src/include/network/peloton_server.h b/src/include/network/peloton_server.h index e0baed54ef1..eac292a6e1a 100644 --- a/src/include/network/peloton_server.h +++ b/src/include/network/peloton_server.h @@ -34,7 +34,7 @@ #include "common/logger.h" #include "common/notifiable_task.h" #include "connection_dispatcher_task.h" -#include "network_state.h" +#include "network_types.h" #include "protocol_handler.h" #include diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h new file mode 100644 index 00000000000..e0962ed9c90 --- /dev/null +++ b/src/include/network/postgres_network_commands.h @@ -0,0 +1,121 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_network_commands.h +// +// Identification: src/include/network/postgres_network_commands.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once +#include "common/internal_types.h" +#include "common/logger.h" 
+#include "common/macros.h" +#include "network/network_types.h" +#include "network/postgres_protocol_handler.h" +#include "traffic_cop/traffic_cop.h" +#include "network/marshal.h" + +#define DEFINE_COMMAND(name, protocol_type) \ +class name : public PostgresNetworkCommand { \ + public: \ + name(PostgresRawInputPacket &input_packet) \ + : PostgresNetworkCommand(std::move(input_packet), protocol_type) {} \ + virtual Transition Exec(PostgresProtocolHandler &handler) override; \ +} + +namespace peloton { +namespace network { + +struct PostgresRawInputPacket { + NetworkMessageType msg_type_ = NetworkMessageType::NULL_COMMAND; + size_t len_ = 0; + ByteBuf::const_iterator packet_head_, packet_tail_; + bool header_parsed_ = false; + + PostgresRawInputPacket() = default; + PostgresRawInputPacket(const PostgresRawInputPacket &) = default; + PostgresRawInputPacket(PostgresRawInputPacket &&) = default; + + inline void Clear() { + msg_type_ = NetworkMessageType::NULL_COMMAND; + len_ = 0; + header_parsed_ = false; + } + + int GetInt(uint8_t len) { + int value = 0; + std::copy(packet_head_, packet_head_ + len, + reinterpret_cast(&value)); + switch (len) { + case 1:break; + case 2:value = ntohs(value); + break; + case 4:value = ntohl(value); + break; + default: + throw NetworkProcessException( + "Error when de-serializing: Invalid int size"); + } + packet_head_ += len; + return value; + } + + inline uchar GetByte() { return (uchar) GetInt(1); } + + inline void GetBytes(size_t len, ByteBuf &result) { + result.insert(std::end(result), packet_head_, packet_head_ + len); + packet_head_ += len; + } + + std::string GetString(size_t len) { + // TODO(Tianyu): This looks broken, some broken-looking code depends on it + // though + if (len == 0) return ""; + // Nul character at end + auto result = std::string(packet_head_, packet_head_ + (len - 1)); + packet_head_ += len; + return result; + } + + std::string GetString() { + // Find nul-terminator + auto find_itr = std::find(packet_head_, 
packet_tail_, 0); + if (find_itr == packet_tail_) + throw NetworkProcessException("Expected nil at end of packet, none found"); + auto result = std::string(packet_head_, find_itr); + packet_head_ = find_itr + 1; + return result; + } +}; + +class PostgresNetworkCommand { + public: + virtual Transition Exec(PostgresProtocolHandler &handler) = 0; + protected: + PostgresNetworkCommand(PostgresRawInputPacket input_packet, + ResponseProtocol response_protocol) + : input_packet_(input_packet), + response_protocol_(response_protocol) {} + + PostgresRawInputPacket input_packet_; + const ResponseProtocol response_protocol_; +}; + +// TODO(Tianyu): Fix response types +DEFINE_COMMAND(SslInitCommand, ResponseProtocol::SIMPLE); +DEFINE_COMMAND(StartupCommand, ResponseProtocol::SIMPLE); +DEFINE_COMMAND(SimpleQueryCommand, ResponseProtocol::SIMPLE); +DEFINE_COMMAND(ParseCommand, ResponseProtocol::NO); +DEFINE_COMMAND(BindCommand, ResponseProtocol::NO); +DEFINE_COMMAND(DescribeCommand, ResponseProtocol::NO); +DEFINE_COMMAND(ExecuteCommand, ResponseProtocol::EXTENDED); +DEFINE_COMMAND(SyncCommand, ResponseProtocol::SIMPLE); +DEFINE_COMMAND(CloseCommand, ResponseProtocol::NO); +DEFINE_COMMAND(TerminateCommand, ResponseProtocol::NO); +DEFINE_COMMAND(NullCommand, ResponseProtocol::NO); + +} // namespace network +} // namespace peloton diff --git a/src/include/network/postgres_wire_protocol.h b/src/include/network/postgres_wire_protocol.h new file mode 100644 index 00000000000..be01f337db9 --- /dev/null +++ b/src/include/network/postgres_wire_protocol.h @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_wire_protocol.h +// +// Identification: src/include/network/postgres_wire_protocol.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once +#include "common/logger.h" +#include 
"network/wire_protocol.h" +#include "network/postgres_network_commands.h" + +#define SSL_MESSAGE_VERNO 80877103 +#define PROTO_MAJOR_VERSION(x) ((x) >> 16) + +namespace peloton { +namespace network { + +class PostgresWireProtocol : public WireProtocol { + public: + Transition Process(ReadBuffer &in, + WriteBuffer &out, + size_t thread_id) override; + private: + bool startup_ = true; + PostgresRawInputPacket curr_input_packet_; + + PostgresNetworkCommand PacketToCommand() { + if (startup_) { + int32_t proto_version = curr_input_packet_.GetInt(4); + LOG_INFO("protocol version: %d", proto_version); + if (proto_version == SSL_MESSAGE_VERNO) { + return SslInitCommand(curr_input_packet_); + } + } + curr_input_packet_.Clear(); + } + + bool BuildPacket(ReadBuffer &in) { + if (!ReadPacketHeader(in)) return false; + if (!in.HasMore(curr_input_packet_.len_)) return false; + in.Read(curr_input_packet_.len_, + curr_input_packet_.packet_head_, + curr_input_packet_.packet_tail_); + return true; + } + + bool ReadPacketHeader(ReadBuffer &in) { + if (curr_input_packet_.header_parsed_) return true; + // Header format: 1 byte message type (only if non-startup) + // + 4 byte message size (inclusive of these 4 bytes) + size_t header_size = startup_ ? 
sizeof(int32_t) : 1 + sizeof(int32_t); + // Make sure the entire header is readable + if (!in.HasMore(header_size)) return false; + + // The header is ready to be read, fill in fields accordingly + if (!startup_) + curr_input_packet_.msg_type_ = in.ReadValue(); + curr_input_packet_.len_ = ntohl(in.ReadValue()) - sizeof(int32_t); + // Extend the buffer as needed + if (curr_input_packet_.len_ > in.Capacity()) { + LOG_INFO("Extended Buffer size required for packet of size %ld", + curr_input_packet_.len_); + in.ExpandTo(curr_input_packet_.len_); + } + curr_input_packet_.header_parsed_ = true; + return true; + } +}; +} // namespace peloton +} // namespace network \ No newline at end of file diff --git a/src/include/network/wire_protocol.h b/src/include/network/wire_protocol.h new file mode 100644 index 00000000000..2629d7caced --- /dev/null +++ b/src/include/network/wire_protocol.h @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// wire_protocol.h +// +// Identification: src/include/network/wire_protocol.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once +#include +#include "network/marshal.h" + +namespace peloton { +namespace network { + +class WireProtocol { + public: + // TODO(Tianyu): What the hell is this thread_id thingy + virtual Transition Process(ReadBuffer &in, + WriteBuffer &out, + size_t thread_id) = 0; + +}; + +} // namespace network +} // namespace peloton \ No newline at end of file From d7bf5ed2b0ef74e79e8e2840f525161b20e8371f Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Thu, 21 Jun 2018 17:24:27 -0400 Subject: [PATCH 10/48] Restructure serialization logic to use builder-like structure. 
--- src/include/common/internal_types.h | 2 +- src/include/network/connection_handle.h | 1 + src/include/network/marshal.h | 205 +++++++++++++++--- .../network/network_io_wrapper_factory.h | 2 +- .../network/postgres_network_commands.h | 82 ++----- src/include/network/postgres_wire_protocol.h | 147 ++++++++++--- src/include/network/wire_protocol.h | 4 +- src/network/connection_handle.cpp | 5 +- src/network/network_io_wrapper_factory.cpp | 2 +- src/network/network_io_wrappers.cpp | 8 +- src/network/postgres_network_commands.cpp | 100 +++++++++ src/network/postgres_protocol_handler.cpp | 4 +- 12 files changed, 427 insertions(+), 135 deletions(-) create mode 100644 src/network/postgres_network_commands.cpp diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 22598226407..449bff3e373 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1419,7 +1419,7 @@ typedef std::map> column_map_type; //===--------------------------------------------------------------------===// // Wire protocol typedefs //===--------------------------------------------------------------------===// -#define SOCKET_BUFFER_SIZE 8192 +#define SOCKET_BUFFER_CAPACITY 8192 /* byte type */ typedef unsigned char uchar; diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 92d85d2d36a..f82c838ada5 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -190,6 +190,7 @@ class ConnectionHandle { StateMachine state_machine_; struct event *network_event_ = nullptr, *workpool_event_ = nullptr; std::unique_ptr protocol_handler_ = nullptr; + // TODO(Tianyu): Remove tcop from here in later refactor tcop::TrafficCop tcop_; // TODO(Tianyu): Put this into protocol handler in a later refactor unsigned int next_response_ = 0; diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 2efd5370992..bee00461039 100644 --- 
a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -25,7 +25,6 @@ #include "network/postgres_network_commands.h" #define BUFFER_INIT_SIZE 100 - namespace peloton { namespace network { /** @@ -35,12 +34,14 @@ namespace network { * The buffer has a fix capacity and one can write a variable amount of * meaningful bytes into it. We call this amount "size" of the buffer. */ -struct Buffer { +class Buffer { public: /** * Instantiates a new buffer and reserve default many bytes. */ - inline Buffer() { buf_.reserve(SOCKET_BUFFER_SIZE); } + inline Buffer(size_t capacity) : capacity_(capacity) { + buf_.reserve(capacity); + } /** * Reset the buffer pointer and clears content @@ -48,8 +49,6 @@ struct Buffer { inline void Reset() { size_ = 0; offset_ = 0; - buf_.resize(SOCKET_BUFFER_SIZE); - buf_.shrink_to_fit(); } /** @@ -87,16 +86,12 @@ struct Buffer { offset_ = 0; } - inline void ExpandTo(size_t size) { - // We should never need to trim down the size past SOCKET_BUFFER_SIZE - PELOTON_ASSERT(size > SOCKET_BUFFER_SIZE); - capacity_ = size; - buf_.resize(capacity_); - } - - protected: - size_t size_ = 0, offset_ = 0, capacity_ = SOCKET_BUFFER_SIZE; + // TODO(Tianyu): Fix this after protocol refactor +// protected: + size_t size_ = 0, offset_ = 0, capacity_; ByteBuf buf_; + private: + friend class WriteQueue; }; /** @@ -104,6 +99,8 @@ struct Buffer { */ class ReadBuffer : public Buffer { public: + inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} /** * Read as many bytes as possible using SSL read * @param context SSL context to read from @@ -128,6 +125,11 @@ class ReadBuffer : public Buffer { return (int) bytes_read; } + inline void FillBufferFrom(ReadBuffer &other, size_t size) { + other.Read(size, &buf_[size_]); + size_ += size; + } + /** * The number of bytes available to be consumed (i.e. 
meaningful bytes after * current read cursor) @@ -135,13 +137,6 @@ class ReadBuffer : public Buffer { */ inline size_t BytesAvailable() { return size_ - offset_; } - // TODO(Tianyu): Document - inline void Read(size_t bytes, ByteBuf::const_iterator &begin, ByteBuf::const_iterator &end) { - begin = buf_.begin() + offset_; - end = begin + bytes; - offset_ += bytes; - } - /** * Read the given number of bytes into destination, advancing cursor by that * number @@ -154,6 +149,42 @@ class ReadBuffer : public Buffer { offset_ += bytes; } + inline int ReadInt(uint8_t len) { + switch (len) { + case 1:return ReadRawValue(); + case 2:return ntohs(ReadRawValue()); + case 4:return ntohl(ReadRawValue()); + default: + throw NetworkProcessException( + "Error when de-serializing: Invalid int size"); + } + } + + // Inclusive of nul-terminator + inline std::string ReadString(size_t len) { + if (len == 0) return ""; + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + offset_ + (len - 1)); + offset_ += len; + return result; + } + + // Read until nul terminator + inline std::string ReadString() { + // search for the nul terminator + for (size_t i = offset_; i < size_; i++) { + if (buf_[i] == 0) { + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + i); + // +1 because we want to skip nul + offset_ = i + 1; + return result; + } + } + // No nul terminator found + throw NetworkProcessException("Expected nil in read buffer, none found"); + } + /** * Read a value of type T off of the buffer, advancing cursor by appropriate * amount. Does NOT convert from network bytes order. 
It is the caller's @@ -162,7 +193,7 @@ class ReadBuffer : public Buffer { * @return the value of type T */ template - inline T ReadValue() { + inline T ReadRawValue() { T result; Read(sizeof(result), &result); return result; @@ -174,6 +205,8 @@ class ReadBuffer : public Buffer { */ class WriteBuffer : public Buffer { public: + inline WriteBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} /** * Write as many bytes as possible using SSL write * @param context SSL context to write out to @@ -213,13 +246,13 @@ class WriteBuffer : public Buffer { /** * Append the desired range into current buffer. - * @tparam InputIt iterator type. - * @param first beginning of range - * @param len length of range + * @param src beginning of range + * @param len length of range, in bytes */ - template - inline void Append(InputIt first, size_t len) { - std::copy(first, first + len, std::begin(buf_) + size_); + inline void AppendRaw(const void *src, size_t len) { + if (len == 0) return; + auto bytes_src = reinterpret_cast(src); + std::copy(bytes_src, bytes_src + len, std::begin(buf_) + size_); size_ += len; } @@ -230,11 +263,123 @@ class WriteBuffer : public Buffer { * @param val value to write into buffer */ template - inline void Append(T val) { - Append(reinterpret_cast(&val), sizeof(T)); + inline void AppendRaw(T val) { + AppendRaw(&val, sizeof(T)); } }; +class WriteQueue { + friend class NetworkIoWrapper; + public: + inline WriteQueue() { + Reset(); + } + + inline void Reset() { + buffers_.resize(1); + flush_ = false; + if (buffers_[0] == nullptr) + buffers_[0] = std::make_shared(); + else + buffers_[0]->Reset(); + } + + inline void WriteSingleBytePacket(NetworkMessageType type) { + // No active packet being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + BufferWriteRawValue(type); + } + + inline WriteQueue &BeginPacket(NetworkMessageType type) { + // No active packet being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + 
BufferWriteRawValue(type); + // Remember the size field since we will need to modify it as we go along. + // It is important that our size field is contiguous and not broken between + // two buffers. + BufferWriteRawValue(0, false); + WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); + curr_packet_len_ = + reinterpret_cast(&tail.buf_[tail.size_ - sizeof(int32_t)]); + return *this; + } + + inline WriteQueue &AppendRaw(const void *src, size_t len) { + BufferWriteRaw(src, len); + // Add the size field to the len of the packet. Be mindful of byte + // ordering. We switch to network ordering only when the packet is finished + *curr_packet_len_ += len; + return *this; + } + + template + inline WriteQueue &AppendRawValue(T val) { + return AppendRaw(&val, sizeof(T)); + } + + inline WriteQueue &AppendInt(uint8_t len, uint32_t val) { + int32_t result; + switch (len) { + case 1: + result = val; + break; + case 2: + result = htons(val); + break; + case 4: + result = htonl(val); + break; + default: + throw NetworkProcessException("Error constructing packet: invalid int size"); + } + return AppendRaw(&result, len); + } + + inline WriteQueue &AppendString(const std::string &str, bool nul_terminate = true) { + return AppendRaw(str.data(), nul_terminate ? str.size() + 1 : str.size()); + } + + inline void EndPacket() { + PELOTON_ASSERT(curr_packet_len_ != nullptr); + // Switch to network byte ordering, add the 4 bytes of size field + *curr_packet_len_ = htonl(*curr_packet_len_ + sizeof(int32_t)); + curr_packet_len_ = nullptr; + } + + inline WriteQueue &ForceFlush() { + flush_ = true; + return *this; + } + + inline bool ShouldFlush() { return flush_ || buffers_.size() > 1; } + + private: + + void BufferWriteRaw(const void *src, size_t len, bool breakup = true) { + WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); + if (tail.HasSpaceFor(len)) + tail.AppendRaw(src, len); + else { + // Only write partially if we are allowed to + size_t written = breakup ? 
tail.RemainingCapacity() : 0; + tail.AppendRaw(src, written); + buffers_.push_back(std::make_shared()); + BufferWriteRaw(reinterpret_cast(src) + written, len - written); + } + } + + template + inline void BufferWriteRawValue(T val, bool breakup = true) { + BufferWriteRaw(&val, sizeof(T), breakup); + } + + std::vector> buffers_; + bool flush_ = false; + // In network byte order. + uint32_t *curr_packet_len_ = nullptr; + +}; + class InputPacket { public: NetworkMessageType msg_type; // header diff --git a/src/include/network/network_io_wrapper_factory.h b/src/include/network/network_io_wrapper_factory.h index 979e6a18afd..d5170fe4202 100644 --- a/src/include/network/network_io_wrapper_factory.h +++ b/src/include/network/network_io_wrapper_factory.h @@ -57,7 +57,7 @@ class NetworkIoWrapperFactory { * NEED_DATA when the SSL handshake is partially done due to network * latency */ - Transition PerformSslHandshake(std::shared_ptr &io_wrapper); + Transition TryUseSsl(std::shared_ptr &io_wrapper); private: std::unordered_map> reusable_wrappers_; diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index e0962ed9c90..23094d3acd5 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -14,98 +14,56 @@ #include "common/logger.h" #include "common/macros.h" #include "network/network_types.h" -#include "network/postgres_protocol_handler.h" #include "traffic_cop/traffic_cop.h" #include "network/marshal.h" -#define DEFINE_COMMAND(name, protocol_type) \ -class name : public PostgresNetworkCommand { \ - public: \ - name(PostgresRawInputPacket &input_packet) \ - : PostgresNetworkCommand(std::move(input_packet), protocol_type) {} \ - virtual Transition Exec(PostgresProtocolHandler &handler) override; \ +#define DEFINE_COMMAND(name, protocol_type) \ +class name : public PostgresNetworkCommand { \ + public: \ + explicit name(PostgresRawInputPacket &input_packet) \ + : 
PostgresNetworkCommand(std::move(input_packet), protocol_type) {} \ + virtual Transition Exec(PostgresWireProtocol &, WriteQueue &, size_t) override; \ } namespace peloton { namespace network { -struct PostgresRawInputPacket { +class PostgresWireProtocol; + +struct PostgresInputPacket { NetworkMessageType msg_type_ = NetworkMessageType::NULL_COMMAND; size_t len_ = 0; - ByteBuf::const_iterator packet_head_, packet_tail_; - bool header_parsed_ = false; + std::shared_ptr buf_; + bool header_parsed_ = false, extended_ = false; - PostgresRawInputPacket() = default; - PostgresRawInputPacket(const PostgresRawInputPacket &) = default; - PostgresRawInputPacket(PostgresRawInputPacket &&) = default; + PostgresInputPacket() = default; + PostgresInputPacket(const PostgresInputPacket &) = default; + PostgresInputPacket(PostgresInputPacket &&) = default; inline void Clear() { msg_type_ = NetworkMessageType::NULL_COMMAND; len_ = 0; + buf_ = nullptr; header_parsed_ = false; } - - int GetInt(uint8_t len) { - int value = 0; - std::copy(packet_head_, packet_head_ + len, - reinterpret_cast(&value)); - switch (len) { - case 1:break; - case 2:value = ntohs(value); - break; - case 4:value = ntohl(value); - break; - default: - throw NetworkProcessException( - "Error when de-serializing: Invalid int size"); - } - packet_head_ += len; - return value; - } - - inline uchar GetByte() { return (uchar) GetInt(1); } - - inline void GetBytes(size_t len, ByteBuf &result) { - result.insert(std::end(result), packet_head_, packet_head_ + len); - packet_head_ += len; - } - - std::string GetString(size_t len) { - // TODO(Tianyu): This looks broken, some broken-looking code depends on it - // though - if (len == 0) return ""; - // Nul character at end - auto result = std::string(packet_head_, packet_head_ + (len - 1)); - packet_head_ += len; - return result; - } - - std::string GetString() { - // Find nul-terminator - auto find_itr = std::find(packet_head_, packet_tail_, 0); - if (find_itr == 
packet_tail_) - throw NetworkProcessException("Expected nil at end of packet, none found"); - auto result = std::string(packet_head_, find_itr); - packet_head_ = find_itr + 1; - return result; - } }; class PostgresNetworkCommand { public: - virtual Transition Exec(PostgresProtocolHandler &handler) = 0; + virtual Transition Exec(PostgresWireProtocol &protocol_obj, + WriteQueue &out, + size_t thread_id) = 0; protected: - PostgresNetworkCommand(PostgresRawInputPacket input_packet, + PostgresNetworkCommand(PostgresInputPacket input_packet, ResponseProtocol response_protocol) : input_packet_(input_packet), response_protocol_(response_protocol) {} - PostgresRawInputPacket input_packet_; + PostgresInputPacket input_packet_; const ResponseProtocol response_protocol_; }; // TODO(Tianyu): Fix response types -DEFINE_COMMAND(SslInitCommand, ResponseProtocol::SIMPLE); DEFINE_COMMAND(StartupCommand, ResponseProtocol::SIMPLE); DEFINE_COMMAND(SimpleQueryCommand, ResponseProtocol::SIMPLE); DEFINE_COMMAND(ParseCommand, ResponseProtocol::NO); diff --git a/src/include/network/postgres_wire_protocol.h b/src/include/network/postgres_wire_protocol.h index be01f337db9..3f1ee66246d 100644 --- a/src/include/network/postgres_wire_protocol.h +++ b/src/include/network/postgres_wire_protocol.h @@ -16,60 +16,149 @@ #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) +#define MAKE_COMMAND(type) \ + std::static_pointer_cast( \ + std::make_shared(curr_input_packet_)) namespace peloton { namespace network { - class PostgresWireProtocol : public WireProtocol { public: - Transition Process(ReadBuffer &in, - WriteBuffer &out, - size_t thread_id) override; - private: - bool startup_ = true; - PostgresRawInputPacket curr_input_packet_; - - PostgresNetworkCommand PacketToCommand() { - if (startup_) { - int32_t proto_version = curr_input_packet_.GetInt(4); - LOG_INFO("protocol version: %d", proto_version); - if (proto_version == SSL_MESSAGE_VERNO) { - return 
SslInitCommand(curr_input_packet_); - } - } + // TODO(Tianyu): Remove tcop when tcop refactor complete + PostgresWireProtocol(tcop::TrafficCop *tcop) : tcop_(tcop) {} + + inline Transition Process(std::shared_ptr &in, + WriteQueue &out, + size_t thread_id) override { + if (!BuildPacket(in)) return Transition::NEED_READ; + std::shared_ptr command = PacketToCommand(); curr_input_packet_.Clear(); + return command->Exec(*this, out, thread_id); + } + + inline void AddCommandLineOption(std::string name, std::string val) { + cmdline_options_[name] = val; + } + + inline void FinishStartup() { startup_ = false; } + + std::shared_ptr PacketToCommand() { + if (startup_) return MAKE_COMMAND(StartupCommand); + switch (curr_input_packet_.msg_type_) { + case NetworkMessageType::SIMPLE_QUERY_COMMAND: + return MAKE_COMMAND(SimpleQueryCommand); + case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); + case NetworkMessageType::BIND_COMMAND:return MAKE_COMMAND(BindCommand); + case NetworkMessageType::DESCRIBE_COMMAND: + return MAKE_COMMAND(DescribeCommand); + case NetworkMessageType::EXECUTE_COMMAND: + return MAKE_COMMAND(ExecuteCommand); + case NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); + case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); + case NetworkMessageType::TERMINATE_COMMAND: + return MAKE_COMMAND(TerminateCommand); + case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); + default: + throw NetworkProcessException("Unexpected Packet Type: " + + std::to_string(static_cast(curr_input_packet_.msg_type_))); + } } + // TODO(Tianyu): Remove this when tcop refactor complete + tcop::TrafficCop *tcop_; + private: + bool startup_ = true; + PostgresInputPacket curr_input_packet_{}; + std::unordered_map cmdline_options_; - bool BuildPacket(ReadBuffer &in) { + bool BuildPacket(std::shared_ptr &in) { if (!ReadPacketHeader(in)) return false; - if (!in.HasMore(curr_input_packet_.len_)) return false; - 
in.Read(curr_input_packet_.len_, - curr_input_packet_.packet_head_, - curr_input_packet_.packet_tail_); + + size_t size_needed = curr_input_packet_.extended_ + ? curr_input_packet_.len_ + - curr_input_packet_.buf_->BytesAvailable() + : curr_input_packet_.len_; + if (!in->HasMore(size_needed)) return false; + + if (curr_input_packet_.extended_) + curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); return true; } - bool ReadPacketHeader(ReadBuffer &in) { + bool ReadPacketHeader(std::shared_ptr &in) { if (curr_input_packet_.header_parsed_) return true; + // Header format: 1 byte message type (only if non-startup) // + 4 byte message size (inclusive of these 4 bytes) size_t header_size = startup_ ? sizeof(int32_t) : 1 + sizeof(int32_t); // Make sure the entire header is readable - if (!in.HasMore(header_size)) return false; + if (!in->HasMore(header_size)) return false; // The header is ready to be read, fill in fields accordingly if (!startup_) - curr_input_packet_.msg_type_ = in.ReadValue(); - curr_input_packet_.len_ = ntohl(in.ReadValue()) - sizeof(int32_t); + curr_input_packet_.msg_type_ = in->ReadRawValue(); + curr_input_packet_.len_ = in->ReadInt(sizeof(int32_t)) - sizeof(int32_t); + // Extend the buffer as needed - if (curr_input_packet_.len_ > in.Capacity()) { + if (curr_input_packet_.len_ > in->Capacity()) { LOG_INFO("Extended Buffer size required for packet of size %ld", curr_input_packet_.len_); - in.ExpandTo(curr_input_packet_.len_); + // Allocate a larger buffer and copy bytes off from the I/O layer's buffer + curr_input_packet_.buf_ = + std::make_shared(curr_input_packet_.len_); + curr_input_packet_.extended_ = true; + } else { + curr_input_packet_.buf_ = in; } + curr_input_packet_.header_parsed_ = true; return true; } }; -} // namespace peloton -} // namespace network \ No newline at end of file + +class PostgresWireUtilities { + public: + PostgresWireUtilities() = delete; + + static inline void SendErrorResponse( + WriteQueue &out, + std::vector> 
error_status) { + out.BeginPacket(NetworkMessageType::ERROR_RESPONSE); + for (const auto &entry : error_status) { + out.AppendRawValue(entry.first); + out.AppendString(entry.second); + } + // Nul-terminate packet + out.AppendRawValue(0) + .EndPacket(); + } + + static inline void SendStartupResponse(WriteQueue &out) { + // auth-ok + out.BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST).EndPacket(); + + // parameter status map + for (auto &entry : parameter_status_map_) + out.BeginPacket(NetworkMessageType::PARAMETER_STATUS) + .AppendString(entry.first) + .AppendString(entry.second) + .EndPacket(); + + // ready-for-query + SendReadyForQuery(NetworkTransactionStateType::IDLE, out); + } + + static inline void SendReadyForQuery(NetworkTransactionStateType txn_status, + WriteQueue &out) { + out.BeginPacket(NetworkMessageType::READY_FOR_QUERY) + .AppendRawValue(txn_status) + .EndPacket(); + } + + private: + // TODO(Tianyu): It looks broken that this never changes. + // TODO(Tianyu): Also, Initialize. 
+ static const std::unordered_map + parameter_status_map_; +}; +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/wire_protocol.h b/src/include/network/wire_protocol.h index 2629d7caced..1e2b3e64ae1 100644 --- a/src/include/network/wire_protocol.h +++ b/src/include/network/wire_protocol.h @@ -19,8 +19,8 @@ namespace network { class WireProtocol { public: // TODO(Tianyu): What the hell is this thread_id thingy - virtual Transition Process(ReadBuffer &in, - WriteBuffer &out, + virtual Transition Process(std::shared_ptr &in, + WriteQueue &out, size_t thread_id) = 0; }; diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index e87eabd74c3..f94d690ddda 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -161,8 +161,7 @@ void ConnectionHandle::StateMachine::Accept(Transition action, next = result.second(connection); } catch (NetworkProcessException &e) { LOG_ERROR("%s\n", e.what()); - connection.TryCloseConnection(); - return; + next = Transition::TERMINATE; } } } @@ -227,7 +226,7 @@ Transition ConnectionHandle::TrySslHandshake() { auto write_ret = TryWrite(); if (write_ret != Transition::PROCEED) return write_ret; } - return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake( + return NetworkIoWrapperFactory::GetInstance().TryUseSsl( io_wrapper_); } diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index 2c675ea0d63..88efad09216 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -40,7 +40,7 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( return reused_wrapper; } -Transition NetworkIoWrapperFactory::PerformSslHandshake( +Transition NetworkIoWrapperFactory::TryUseSsl( std::shared_ptr &io_wrapper) { SSL *context; if (!io_wrapper->SslAble()) { diff --git a/src/network/network_io_wrappers.cpp 
b/src/network/network_io_wrappers.cpp index 80bad466c0c..b914fd06051 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -29,21 +29,21 @@ Transition NetworkIoWrapper::WritePacket(OutputPacket *pkt) { return result; } - wbuf_->Append(static_cast(pkt->msg_type)); + wbuf_->AppendRaw(static_cast(pkt->msg_type)); if (!pkt->single_type_pkt) // Need to convert bytes to network order - wbuf_->Append(htonl(pkt->len + sizeof(int32_t))); + wbuf_->AppendRaw(htonl(pkt->len + sizeof(int32_t))); pkt->skip_header_write = true; } // Write Packet Content for (size_t len = pkt->len; len != 0;) { if (wbuf_->HasSpaceFor(len)) { - wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, len); + wbuf_->AppendRaw(std::begin(pkt->buf) + pkt->write_ptr, len); break; } else { auto write_size = wbuf_->RemainingCapacity(); - wbuf_->Append(std::begin(pkt->buf) + pkt->write_ptr, write_size); + wbuf_->AppendRaw(std::begin(pkt->buf) + pkt->write_ptr, write_size); len -= write_size; pkt->write_ptr += write_size; auto result = FlushWriteBuffer(); diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp new file mode 100644 index 00000000000..c1d3928a909 --- /dev/null +++ b/src/network/postgres_network_commands.cpp @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_network_commands.cpp +// +// Identification: src/network/postgres_network_commands.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#include "parser/postgresparser.h" +#include "network/postgres_wire_protocol.h" +#include "network/peloton_server.h" +#include "network/postgres_network_commands.h" + +#define SSL_MESSAGE_VERNO 80877103 +#define PROTO_MAJOR_VERSION(x) ((x) >> 16) + +namespace peloton { +namespace network { + +Transition 
StartupCommand::Exec(PostgresWireProtocol &protocol_object, + WriteQueue &out, + size_t) { + // Always flush startup response + out.ForceFlush(); + int32_t proto_version = input_packet_.buf_->ReadInt(sizeof(int32_t)); + LOG_INFO("protocol version: %d", proto_version); + if (proto_version == SSL_MESSAGE_VERNO) { + // SSL Handshake initialization + // TODO(Tianyu): This static method probably needs to be moved into + // settings manager + bool ssl_able = (PelotonServer::GetSSLLevel() != SSLLevel::SSL_DISABLE); + out.WriteSingleBytePacket(ssl_able + ? NetworkMessageType::SSL_YES + : NetworkMessageType::SSL_NO); + return ssl_able ? Transition::NEED_SSL_HANDSHAKE : Transition::PROCEED; + } else { + // Normal Initialization + if (PROTO_MAJOR_VERSION(proto_version) != 3) { + // Only protocol version 3 is supported + LOG_ERROR("Protocol error: Only protocol version 3 is supported."); + PostgresWireUtilities::SendErrorResponse( + out, {{NetworkMessageType::HUMAN_READABLE_ERROR, + "Protocol Version Not Support"}}); + return Transition::TERMINATE; + } + + std::string token, value; + // TODO(Yuchen): check for more malformed cases + // Read out startup package info + while (input_packet_.buf_->HasMore()) { + token = input_packet_.buf_->ReadString(); + LOG_TRACE("Option key is %s", token.c_str()); + // TODO(Tianyu): Why does this commented out line need to be here? + // if (!input_packet_.buf_->HasMore()) break; + value = input_packet_.buf_->ReadString(); + LOG_TRACE("Option value is %s", value.c_str()); + // TODO(Tianyu): We never seem to use this crap? 
+ protocol_object.AddCommandLineOption(token, value); + // TODO(Tianyu): Do this after we are done refactoring traffic cop +// if (token.compare("database") == 0) { +// traffic_cop_->SetDefaultDatabaseName(value); +// } + } + + // Startup Response, for now we do not do any authentication + PostgresWireUtilities::SendStartupResponse(out); + protocol_object.FinishStartup(); + return Transition::PROCEED; + } +} + +Transition SimpleQueryCommand::Exec(PostgresWireProtocol &protocol_object, + WriteQueue &out, + size_t thread_id) { + out.ForceFlush(); + std::string query = input_packet_.buf_->ReadString(input_packet_.len_); + LOG_TRACE("Execute query: %s", query.c_str()); + std::unique_ptr sql_stmt_list; + try { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + sql_stmt_list = peloton_parser.BuildParseTree(query); + // When the query is empty(such as ";" or ";;", still valid), + // the pare tree is empty, parser will return nullptr. + if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) + throw ParserException("Error Parsing SQL statement"); + } catch (Exception &e) { + protocol_object.tcop_->ProcessInvalidStatement(); + std::string error_message = e.what(); + PostgresWireUtilities::SendErrorResponse( + out, {{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); + PostgresWireUtilities::SendReadyForQuery() + + } +} + +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index 6f03a617667..aa61fece3aa 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -929,12 +929,12 @@ bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, size_t header_size = startup ? 
sizeof(int32_t) : sizeof(int32_t) + 1; // check if header bytes are available if (!rbuf.HasMore(header_size)) return false; - if (!startup) rpkt.msg_type = rbuf.ReadValue(); + if (!startup) rpkt.msg_type = rbuf.ReadRawValue(); // get packet size from the header // extract packet contents size // content lengths should exclude the length bytes - rpkt.len = ntohl(rbuf.ReadValue()) - sizeof(uint32_t); + rpkt.len = rbuf.ReadInt(sizeof(int32_t)) - sizeof(uint32_t); // do we need to use the extended buffer for this packet? rpkt.is_extended = (rpkt.len > rbuf.Capacity()); From c69f5d6dc9df90511bf1a98fe1b25906c30ea196 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 22 Jun 2018 15:21:20 -0400 Subject: [PATCH 11/48] Start infrastructure for new protocol layer --- src/codegen/util/buffered_io.cpp | 100 +++++ src/include/network/buffered_io.h | 353 +++++++++++++++++ src/include/network/marshal.h | 365 +----------------- .../network/postgres_network_commands.h | 39 +- .../network/postgres_protocol_interpreter.h | 284 ++++++++++++++ src/include/network/postgres_wire_protocol.h | 164 -------- src/include/network/protocol_interpreter.h | 30 ++ src/include/traffic_cop/tcop.h | 22 ++ src/network/buffered_io.cpp | 51 +++ src/network/postgres_network_commands.cpp | 27 +- 10 files changed, 861 insertions(+), 574 deletions(-) create mode 100644 src/codegen/util/buffered_io.cpp create mode 100644 src/include/network/buffered_io.h create mode 100644 src/include/network/postgres_protocol_interpreter.h delete mode 100644 src/include/network/postgres_wire_protocol.h create mode 100644 src/include/network/protocol_interpreter.h create mode 100644 src/include/traffic_cop/tcop.h create mode 100644 src/network/buffered_io.cpp diff --git a/src/codegen/util/buffered_io.cpp b/src/codegen/util/buffered_io.cpp new file mode 100644 index 00000000000..8fc1e567b00 --- /dev/null +++ b/src/codegen/util/buffered_io.cpp @@ -0,0 +1,100 @@ 
+//===----------------------------------------------------------------------===// +// +// Peloton +// +// buffer.cpp +// +// Identification: src/codegen/util/buffer.cpp +// +// Copyright (c) 2015-2017, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "codegen/util/buffer.h" + +#include + +#include "common/logger.h" +#include "common/timer.h" +#include "storage/backend_manager.h" + +namespace peloton { +namespace codegen { +namespace util { + +Buffer::Buffer() + : buffer_start_(nullptr), buffer_pos_(nullptr), buffer_end_(nullptr) { + auto &backend_manager = storage::BackendManager::GetInstance(); + buffer_start_ = reinterpret_cast( + backend_manager.Allocate(BackendType::MM, kInitialBufferSize)); + buffer_pos_ = buffer_start_; + buffer_end_ = buffer_start_ + kInitialBufferSize; + + LOG_DEBUG("Initialized buffer with size %.2lf KB", + kInitialBufferSize / 1024.0); +} + +Buffer::~Buffer() { + if (buffer_start_ != nullptr) { + LOG_DEBUG("Releasing %.2lf KB of memory", AllocatedSpace() / 1024.0); + auto &backend_manager = storage::BackendManager::GetInstance(); + backend_manager.Release(BackendType::MM, buffer_start_); + } + buffer_start_ = buffer_pos_ = buffer_end_ = nullptr; +} + +void Buffer::Init(Buffer &buffer) { new (&buffer) Buffer(); } + +void Buffer::Destroy(Buffer &buffer) { buffer.~Buffer(); } + +char *Buffer::Append(uint32_t num_bytes) { + MakeRoomForBytes(num_bytes); + char *ret = buffer_pos_; + buffer_pos_ += num_bytes; + return ret; +} + +void Buffer::Reset() { buffer_pos_ = buffer_start_; } + +void Buffer::MakeRoomForBytes(uint64_t num_bytes) { + bool has_room = + (buffer_start_ != nullptr && buffer_pos_ + num_bytes < buffer_end_); + if (has_room) { + return; + } + + // Need to allocate some space + uint64_t curr_alloc_size = AllocatedSpace(); + uint64_t curr_used_size = UsedSpace(); + + // Ensure the current size is a power of two + PELOTON_ASSERT(curr_alloc_size 
% 2 == 0); + + // Allocate double the buffer room + uint64_t next_alloc_size = curr_alloc_size; + do { + next_alloc_size *= 2; + } while (next_alloc_size < num_bytes); + LOG_DEBUG("Resizing buffer from %.2lf bytes to %.2lf KB ...", + curr_alloc_size / 1024.0, next_alloc_size / 1024.0); + + auto &backend_manager = storage::BackendManager::GetInstance(); + auto *new_buffer = reinterpret_cast( + backend_manager.Allocate(BackendType::MM, next_alloc_size)); + + // Now copy the previous buffer into the new area + PELOTON_MEMCPY(new_buffer, buffer_start_, curr_used_size); + + // Set pointers + char *old_buffer_start = buffer_start_; + buffer_start_ = new_buffer; + buffer_pos_ = buffer_start_ + curr_used_size; + buffer_end_ = buffer_start_ + next_alloc_size; + + // Release old buffer + backend_manager.Release(BackendType::MM, old_buffer_start); +} + +} // namespace util +} // namespace codegen +} // namespace peloton diff --git a/src/include/network/buffered_io.h b/src/include/network/buffered_io.h new file mode 100644 index 00000000000..3aff3535f6c --- /dev/null +++ b/src/include/network/buffered_io.h @@ -0,0 +1,353 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// buffered_io.h +// +// Identification: src/include/network/buffered_io.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include + +#include +#include +#include "common/internal_types.h" + +namespace peloton { +namespace network { +/** + * A plain old buffer with a movable cursor, the meaning of which is dependent + * on the use case. + * + * The buffer has a fix capacity and one can write a variable amount of + * meaningful bytes into it. We call this amount "size" of the buffer. + */ +class Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. 
+ */ + inline Buffer(size_t capacity) : capacity_(capacity) { + buf_.reserve(capacity); + } + + /** + * Reset the buffer pointer and clears content + */ + inline void Reset() { + size_ = 0; + offset_ = 0; + } + + /** + * @param bytes The amount of bytes to check between the cursor and the end + * of the buffer (defaults to any) + * @return Whether there is any more bytes between the cursor and + * the end of the buffer + */ + inline bool HasMore(size_t bytes = 1) { return offset_ + bytes <= size_; } + + /** + * @return Whether the buffer is at capacity. (All usable space is filled + * with meaningful bytes) + */ + inline bool Full() { return size_ == Capacity(); } + + /** + * @return Iterator to the beginning of the buffer + */ + inline ByteBuf::const_iterator Begin() { return std::begin(buf_); } + + /** + * @return Capacity of the buffer (not actual size) + */ + inline size_t Capacity() const { return capacity_; } + + /** + * Shift contents to align the current cursor with start of the buffer, + * remove all bytes before the cursor. + */ + inline void MoveContentToHead() { + auto unprocessed_len = size_ - offset_; + std::memmove(&buf_[0], &buf_[offset_], unprocessed_len); + size_ = unprocessed_len; + offset_ = 0; + } + + // TODO(Tianyu): Fix this after protocol refactor +// protected: + size_t size_ = 0, offset_ = 0, capacity_; + ByteBuf buf_; + private: + friend class WriteQueue; +}; + +/** + * A buffer specialize for read + */ +class ReadBuffer : public Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. 
+ */ + inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} + /** + * Read as many bytes as possible using SSL read + * @param context SSL context to read from + * @return the return value of ssl read + */ + inline int FillBufferFrom(SSL *context) { + ERR_clear_error(); + ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); + int err = SSL_get_error(context, bytes_read); + if (err == SSL_ERROR_NONE) size_ += bytes_read; + return err; + }; + + /** + * Read as many bytes as possible using Posix from an fd + * @param fd the file descriptor to read from + * @return the return value of posix read + */ + inline int FillBufferFrom(int fd) { + ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); + if (bytes_read > 0) size_ += bytes_read; + return (int) bytes_read; + } + + /** + * Read the specified amount of bytes off from another read buffer. The bytes + * will be consumed (cursor moved) on the other buffer and appended to the end + * of this buffer + * @param other The other buffer to read from + * @param size Number of bytes to read + */ + inline void FillBufferFrom(ReadBuffer &other, size_t size) { + other.Read(size, &buf_[size_]); + size_ += size; + } + + /** + * The number of bytes available to be consumed (i.e. meaningful bytes after + * current read cursor) + * @return The number of bytes available to be consumed + */ + inline size_t BytesAvailable() { return size_ - offset_; } + + /** + * Read the given number of bytes into destination, advancing cursor by that + * number. It is up to the caller to ensure that there are enough bytes + * available in the read buffer at this point. 
+ * @param bytes Number of bytes to read + * @param dest Desired memory location to read into + */ + inline void Read(size_t bytes, void *dest) { + std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, + reinterpret_cast(dest)); + offset_ += bytes; + } + + /** + * Read an integer of specified length off of the read buffer (1, 2, + * or 4 bytes). It is assumed that the bytes in the buffer are in network + * byte ordering and will be converted to the correct host ordering. It is up + * to the caller to ensure that there are enough bytes available in the read + * buffer at this point. + * @param len Length of the integer, either 1, 2, or 4 bytes. + * @return value of integer switched from network byte order + */ + int ReadInt(uint8_t len); + + /** + * Read a block of bytes off the read buffer as a string. + * @param len Length of the string, inclusive of nul-terminator + * @return string of specified length at head of read buffer + */ + std::string ReadString(size_t len); + + /** + * Read a nul-terminated string off the read buffer, or throw an exception + * if no nul-terminator is found within packet range. + * @return string at head of read buffer + */ + std::string ReadString(); + + /** + * Read a value of type T off of the buffer, advancing cursor by appropriate + * amount. Does NOT convert from network bytes order. It is the caller's + * responsibility to do so if needed. + * @tparam T type of value to read off. Preferably a primitive type. + * @return the value of type T + */ + template + inline T ReadRawValue() { + T result; + Read(sizeof(result), &result); + return result; + } +}; + +/** + * A buffer specialized for write + */ +class WriteBuffer : public Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. 
+ */ + inline WriteBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} + + /** + * Write as many bytes as possible using SSL write + * @param context SSL context to write out to + * @return return value of SSL write + */ + inline int WriteOutTo(SSL *context) { + ERR_clear_error(); + ssize_t bytes_written = SSL_write(context, &buf_[offset_], size_ - offset_); + int err = SSL_get_error(context, bytes_written); + if (err == SSL_ERROR_NONE) offset_ += bytes_written; + return err; + } + + /** + * Write as many bytes as possible using Posix write to fd + * @param fd File descriptor to write out to + * @return return value of Posix write + */ + inline int WriteOutTo(int fd) { + ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); + if (bytes_written > 0) offset_ += bytes_written; + return (int) bytes_written; + } + + /** + * The remaining capacity of this buffer. This value is equal to the + * maximum capacity minus the capacity already in use. + * @return Remaining capacity + */ + inline size_t RemainingCapacity() { return Capacity() - size_; } + + /** + * @param bytes Desired number of bytes to write + * @return Whether the buffer can accommodate the number of bytes given + */ + inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } + + /** + * Append the desired range into current buffer. + * @param src beginning of range + * @param len length of range, in bytes + */ + inline void AppendRaw(const void *src, size_t len) { + if (len == 0) return; + auto bytes_src = reinterpret_cast(src); + std::copy(bytes_src, bytes_src + len, std::begin(buf_) + size_); + size_ += len; + } + + /** + * Append the given value into the current buffer. Does NOT convert to + * network byte order. It is up to the caller to do so. 
+ * @tparam T input type + * @param val value to write into buffer + */ + template + inline void AppendRaw(T val) { + AppendRaw(&val, sizeof(T)); + } +}; + +/** + * A WriteQueue is a series of WriteBuffers that can buffer an uncapped amount + * of writes without the need to copy and resize. + * + * It is expected that a specific protocol will wrap this to expose a better + * API for protocol-specific behavior. + */ +class WriteQueue { + public: + /** + * Instantiates a new WriteQueue. By default this holds one buffer. + */ + inline WriteQueue() { + Reset(); + } + + /** + * Reset the write queue to its default state. + */ + inline void Reset() { + buffers_.resize(1); + flush_ = false; + if (buffers_[0] == nullptr) + buffers_[0] = std::make_shared(); + else + buffers_[0]->Reset(); + } + + /** + * Force this WriteQueue to be flushed next time the network layer + * is available to do so. + */ + inline void ForceFlush() { flush_ = true; } + + /** + * Whether this WriteQueue should be flushed out to network or not. + * A WriteQueue should be flushed either when the first buffer is full + * or when manually set to do so (e.g. when the client is waiting for + * a small response) + * @return whether we should flush this write queue + */ + inline bool ShouldFlush() { return flush_ || buffers_.size() > 1; } + + /** + * Write len many bytes starting from src into the write queue, allocating + * a new buffer if need be. The write is split up between two buffers + * if breakup is set to true (which is by default) + * @param src write head + * @param len number of bytes to write + * @param breakup whether to split write into two buffers if need be. + */ + void BufferWriteRaw(const void *src, size_t len, bool breakup = true) { + WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); + if (tail.HasSpaceFor(len)) + tail.AppendRaw(src, len); + else { + // Only write partially if we are allowed to + size_t written = breakup ? 
tail.RemainingCapacity() : 0; + tail.AppendRaw(src, written); + buffers_.push_back(std::make_shared()); + BufferWriteRaw(reinterpret_cast(src) + written, len - written); + } + } + + /** + * Write val into the write queue, allocating a new buffer if need be. + * The write is split up between two buffers if breakup is set to true + * (which is by default). No conversion of byte ordering is performed. It is + * up to the caller to do so if needed. + * @tparam T type of value to write + * @param val value to write + * @param breakup whether to split write into two buffers if need be. + */ + template + inline void BufferWriteRawValue(T val, bool breakup = true) { + BufferWriteRaw(&val, sizeof(T), breakup); + } + + private: + friend class PostgresPacketWriter; + std::vector> buffers_; + bool flush_ = false; +}; + +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index bee00461039..5fc42ad3f74 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -12,374 +12,11 @@ #pragma once -#include -#include - -#include -#include -#include "common/internal_types.h" -#include "common/logger.h" -#include "common/macros.h" -#include "common/exception.h" -#include "network/network_types.h" -#include "network/postgres_network_commands.h" +#include "network/buffered_io.h" #define BUFFER_INIT_SIZE 100 namespace peloton { namespace network { -/** - * A plain old buffer with a movable cursor, the meaning of which is dependent - * on the use case. - * - * The buffer has a fix capacity and one can write a variable amount of - * meaningful bytes into it. We call this amount "size" of the buffer. - */ -class Buffer { - public: - /** - * Instantiates a new buffer and reserve default many bytes. 
- */ - inline Buffer(size_t capacity) : capacity_(capacity) { - buf_.reserve(capacity); - } - - /** - * Reset the buffer pointer and clears content - */ - inline void Reset() { - size_ = 0; - offset_ = 0; - } - - /** - * @param bytes The amount of bytes to check between the cursor and the end - * of the buffer (defaults to any) - * @return Whether there is any more bytes between the cursor and - * the end of the buffer - */ - inline bool HasMore(size_t bytes = 1) { return offset_ + bytes <= size_; } - - /** - * @return Whether the buffer is at capacity. (All usable space is filled - * with meaningful bytes) - */ - inline bool Full() { return size_ == Capacity(); } - - /** - * @return Iterator to the beginning of the buffer - */ - inline ByteBuf::const_iterator Begin() { return std::begin(buf_); } - - /** - * @return Capacity of the buffer (not actual size) - */ - inline size_t Capacity() const { return capacity_; } - - /** - * Shift contents to align the current cursor with start of the buffer, - * remove all bytes before the cursor. 
- */ - inline void MoveContentToHead() { - auto unprocessed_len = size_ - offset_; - std::memmove(&buf_[0], &buf_[offset_], unprocessed_len); - size_ = unprocessed_len; - offset_ = 0; - } - - // TODO(Tianyu): Fix this after protocol refactor -// protected: - size_t size_ = 0, offset_ = 0, capacity_; - ByteBuf buf_; - private: - friend class WriteQueue; -}; - -/** - * A buffer specialize for read - */ -class ReadBuffer : public Buffer { - public: - inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) - : Buffer(capacity) {} - /** - * Read as many bytes as possible using SSL read - * @param context SSL context to read from - * @return the return value of ssl read - */ - inline int FillBufferFrom(SSL *context) { - ERR_clear_error(); - ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); - int err = SSL_get_error(context, bytes_read); - if (err == SSL_ERROR_NONE) size_ += bytes_read; - return err; - }; - - /** - * Read as many bytes as possible using Posix from an fd - * @param fd the file descriptor to read from - * @return the return value of posix read - */ - inline int FillBufferFrom(int fd) { - ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); - if (bytes_read > 0) size_ += bytes_read; - return (int) bytes_read; - } - - inline void FillBufferFrom(ReadBuffer &other, size_t size) { - other.Read(size, &buf_[size_]); - size_ += size; - } - - /** - * The number of bytes available to be consumed (i.e. 
meaningful bytes after - * current read cursor) - * @return The number of bytes available to be consumed - */ - inline size_t BytesAvailable() { return size_ - offset_; } - - /** - * Read the given number of bytes into destination, advancing cursor by that - * number - * @param bytes Number of bytes to read - * @param dest Desired memory location to read into - */ - inline void Read(size_t bytes, void *dest) { - std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, - reinterpret_cast(dest)); - offset_ += bytes; - } - - inline int ReadInt(uint8_t len) { - switch (len) { - case 1:return ReadRawValue(); - case 2:return ntohs(ReadRawValue()); - case 4:return ntohl(ReadRawValue()); - default: - throw NetworkProcessException( - "Error when de-serializing: Invalid int size"); - } - } - - // Inclusive of nul-terminator - inline std::string ReadString(size_t len) { - if (len == 0) return ""; - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + offset_ + (len - 1)); - offset_ += len; - return result; - } - - // Read until nul terminator - inline std::string ReadString() { - // search for the nul terminator - for (size_t i = offset_; i < size_; i++) { - if (buf_[i] == 0) { - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + i); - // +1 because we want to skip nul - offset_ = i + 1; - return result; - } - } - // No nul terminator found - throw NetworkProcessException("Expected nil in read buffer, none found"); - } - - /** - * Read a value of type T off of the buffer, advancing cursor by appropriate - * amount. Does NOT convert from network bytes order. It is the caller's - * responsibility to do so. - * @tparam T type of value to read off. 
Preferably a primitive type - * @return the value of type T - */ - template - inline T ReadRawValue() { - T result; - Read(sizeof(result), &result); - return result; - } -}; - -/** - * A buffer specialized for write - */ -class WriteBuffer : public Buffer { - public: - inline WriteBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) - : Buffer(capacity) {} - /** - * Write as many bytes as possible using SSL write - * @param context SSL context to write out to - * @return return value of SSL write - */ - inline int WriteOutTo(SSL *context) { - ERR_clear_error(); - ssize_t bytes_written = SSL_write(context, &buf_[offset_], size_ - offset_); - int err = SSL_get_error(context, bytes_written); - if (err == SSL_ERROR_NONE) offset_ += bytes_written; - return err; - } - - /** - * Write as many bytes as possible using Posix write to fd - * @param fd File descriptor to write out to - * @return return value of Posix write - */ - inline int WriteOutTo(int fd) { - ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); - if (bytes_written > 0) offset_ += bytes_written; - return (int) bytes_written; - } - - /** - * The remaining capacity of this buffer. This value is equal to the - * maximum capacity minus the capacity already in use. - * @return Remaining capacity - */ - inline size_t RemainingCapacity() { return Capacity() - size_; } - - /** - * @param bytes Desired number of bytes to write - * @return Whether the buffer can accommodate the number of bytes given - */ - inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } - - /** - * Append the desired range into current buffer. - * @param src beginning of range - * @param len length of range, in bytes - */ - inline void AppendRaw(const void *src, size_t len) { - if (len == 0) return; - auto bytes_src = reinterpret_cast(src); - std::copy(bytes_src, bytes_src + len, std::begin(buf_) + size_); - size_ += len; - } - - /** - * Append the given value into the current buffer. 
Does NOT convert to - * network byte order. It is up to the caller to do so. - * @tparam T input type - * @param val value to write into buffer - */ - template - inline void AppendRaw(T val) { - AppendRaw(&val, sizeof(T)); - } -}; - -class WriteQueue { - friend class NetworkIoWrapper; - public: - inline WriteQueue() { - Reset(); - } - - inline void Reset() { - buffers_.resize(1); - flush_ = false; - if (buffers_[0] == nullptr) - buffers_[0] = std::make_shared(); - else - buffers_[0]->Reset(); - } - - inline void WriteSingleBytePacket(NetworkMessageType type) { - // No active packet being constructed - PELOTON_ASSERT(curr_packet_len_ == nullptr); - BufferWriteRawValue(type); - } - - inline WriteQueue &BeginPacket(NetworkMessageType type) { - // No active packet being constructed - PELOTON_ASSERT(curr_packet_len_ == nullptr); - BufferWriteRawValue(type); - // Remember the size field since we will need to modify it as we go along. - // It is important that our size field is contiguous and not broken between - // two buffers. - BufferWriteRawValue(0, false); - WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); - curr_packet_len_ = - reinterpret_cast(&tail.buf_[tail.size_ - sizeof(int32_t)]); - return *this; - } - - inline WriteQueue &AppendRaw(const void *src, size_t len) { - BufferWriteRaw(src, len); - // Add the size field to the len of the packet. Be mindful of byte - // ordering. 
We switch to network ordering only when the packet is finished - *curr_packet_len_ += len; - return *this; - } - - template - inline WriteQueue &AppendRawValue(T val) { - return AppendRaw(&val, sizeof(T)); - } - - inline WriteQueue &AppendInt(uint8_t len, uint32_t val) { - int32_t result; - switch (len) { - case 1: - result = val; - break; - case 2: - result = htons(val); - break; - case 4: - result = htonl(val); - break; - default: - throw NetworkProcessException("Error constructing packet: invalid int size"); - } - return AppendRaw(&result, len); - } - - inline WriteQueue &AppendString(const std::string &str, bool nul_terminate = true) { - return AppendRaw(str.data(), nul_terminate ? str.size() + 1 : str.size()); - } - - inline void EndPacket() { - PELOTON_ASSERT(curr_packet_len_ != nullptr); - // Switch to network byte ordering, add the 4 bytes of size field - *curr_packet_len_ = htonl(*curr_packet_len_ + sizeof(int32_t)); - curr_packet_len_ = nullptr; - } - - inline WriteQueue &ForceFlush() { - flush_ = true; - return *this; - } - - inline bool ShouldFlush() { return flush_ || buffers_.size() > 1; } - - private: - - void BufferWriteRaw(const void *src, size_t len, bool breakup = true) { - WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); - if (tail.HasSpaceFor(len)) - tail.AppendRaw(src, len); - else { - // Only write partially if we are allowed to - size_t written = breakup ? tail.RemainingCapacity() : 0; - tail.AppendRaw(src, written); - buffers_.push_back(std::make_shared()); - BufferWriteRaw(reinterpret_cast(src) + written, len - written); - } - } - - template - inline void BufferWriteRawValue(T val, bool breakup = true) { - BufferWriteRaw(&val, sizeof(T), breakup); - } - - std::vector> buffers_; - bool flush_ = false; - // In network byte order. 
- uint32_t *curr_packet_len_ = nullptr; - -}; - class InputPacket { public: NetworkMessageType msg_type; // header diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 23094d3acd5..f99276a3ae2 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -17,18 +17,18 @@ #include "traffic_cop/traffic_cop.h" #include "network/marshal.h" -#define DEFINE_COMMAND(name, protocol_type) \ +#define DEFINE_COMMAND(name) \ class name : public PostgresNetworkCommand { \ public: \ - explicit name(PostgresRawInputPacket &input_packet) \ - : PostgresNetworkCommand(std::move(input_packet), protocol_type) {} \ - virtual Transition Exec(PostgresWireProtocol &, WriteQueue &, size_t) override; \ + explicit name(PostgresInputPacket &input_packet) \ + : PostgresNetworkCommand(std::move(input_packet)) {} \ + virtual Transition Exec(PostgresProtocolInterpreter &, WriteQueue &, size_t) override; \ } namespace peloton { namespace network { -class PostgresWireProtocol; +class PostgresProtocolInterpreter; struct PostgresInputPacket { NetworkMessageType msg_type_ = NetworkMessageType::NULL_COMMAND; @@ -50,30 +50,27 @@ struct PostgresInputPacket { class PostgresNetworkCommand { public: - virtual Transition Exec(PostgresWireProtocol &protocol_obj, + virtual Transition Exec(PostgresProtocolInterpreter &protocol_obj, WriteQueue &out, size_t thread_id) = 0; protected: - PostgresNetworkCommand(PostgresInputPacket input_packet, - ResponseProtocol response_protocol) - : input_packet_(input_packet), - response_protocol_(response_protocol) {} + PostgresNetworkCommand(PostgresInputPacket input_packet) + : input_packet_(input_packet) {} PostgresInputPacket input_packet_; - const ResponseProtocol response_protocol_; }; // TODO(Tianyu): Fix response types -DEFINE_COMMAND(StartupCommand, ResponseProtocol::SIMPLE); -DEFINE_COMMAND(SimpleQueryCommand, ResponseProtocol::SIMPLE); 
-DEFINE_COMMAND(ParseCommand, ResponseProtocol::NO); -DEFINE_COMMAND(BindCommand, ResponseProtocol::NO); -DEFINE_COMMAND(DescribeCommand, ResponseProtocol::NO); -DEFINE_COMMAND(ExecuteCommand, ResponseProtocol::EXTENDED); -DEFINE_COMMAND(SyncCommand, ResponseProtocol::SIMPLE); -DEFINE_COMMAND(CloseCommand, ResponseProtocol::NO); -DEFINE_COMMAND(TerminateCommand, ResponseProtocol::NO); -DEFINE_COMMAND(NullCommand, ResponseProtocol::NO); +DEFINE_COMMAND(StartupCommand); +DEFINE_COMMAND(SimpleQueryCommand); +DEFINE_COMMAND(ParseCommand); +DEFINE_COMMAND(BindCommand); +DEFINE_COMMAND(DescribeCommand); +DEFINE_COMMAND(ExecuteCommand); +DEFINE_COMMAND(SyncCommand); +DEFINE_COMMAND(CloseCommand); +DEFINE_COMMAND(TerminateCommand); +DEFINE_COMMAND(NullCommand); } // namespace network } // namespace peloton diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h new file mode 100644 index 00000000000..7d03f6996d0 --- /dev/null +++ b/src/include/network/postgres_protocol_interpreter.h @@ -0,0 +1,284 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_protocol_interpreter.h +// +// Identification: src/include/network/postgres_protocol_interpreter.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once +#include "common/logger.h" +#include "network/protocol_interpreter.h" +#include "network/postgres_network_commands.h" +#include "network/buffered_io.h" + +#define SSL_MESSAGE_VERNO 80877103 +#define PROTO_MAJOR_VERSION(x) ((x) >> 16) +#define MAKE_COMMAND(type) \ + std::static_pointer_cast( \ + std::make_shared(curr_input_packet_)) + +namespace peloton { +namespace network { + +class PostgresProtocolInterpreter : public ProtocolInterpreter { + public: + inline Transition Process(std::shared_ptr &in, + WriteQueue &out, + size_t thread_id) override
{ + if (!BuildPacket(in)) return Transition::NEED_READ; + std::shared_ptr command = PacketToCommand(); + curr_input_packet_.Clear(); + return command->Exec(*this, out, thread_id); + } + + inline void AddCommandLineOption(std::string name, std::string val) { + cmdline_options_[name] = val; + } + + inline void FinishStartup() { startup_ = false; } + + std::shared_ptr PacketToCommand() { + if (startup_) return MAKE_COMMAND(StartupCommand); + switch (curr_input_packet_.msg_type_) { + case NetworkMessageType::SIMPLE_QUERY_COMMAND: + return MAKE_COMMAND(SimpleQueryCommand); + case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); + case NetworkMessageType::BIND_COMMAND + :return MAKE_COMMAND(BindCommand); + case NetworkMessageType::DESCRIBE_COMMAND: + return MAKE_COMMAND(DescribeCommand); + case NetworkMessageType::EXECUTE_COMMAND: + return MAKE_COMMAND(ExecuteCommand); + case NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); + case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); + case NetworkMessageType::TERMINATE_COMMAND: + return MAKE_COMMAND(TerminateCommand); + case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); + default: + throw NetworkProcessException("Unexpected Packet Type: " + + std::to_string(static_cast(curr_input_packet_.msg_type_))); + } + } + + private: + bool startup_ = true; + PostgresInputPacket curr_input_packet_{}; + std::unordered_map cmdline_options_; + + bool BuildPacket(std::shared_ptr &in) { + if (!ReadPacketHeader(in)) return false; + + size_t size_needed = curr_input_packet_.extended_ + ? 
curr_input_packet_.len_ + - curr_input_packet_.buf_->BytesAvailable() + : curr_input_packet_.len_; + if (!in->HasMore(size_needed)) return false; + + if (curr_input_packet_.extended_) + curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); + return true; + } + + bool ReadPacketHeader(std::shared_ptr &in) { + if (curr_input_packet_.header_parsed_) return true; + + // Header format: 1 byte message type (only if non-startup) + // + 4 byte message size (inclusive of these 4 bytes) + size_t header_size = startup_ ? sizeof(int32_t) : 1 + sizeof(int32_t); + // Make sure the entire header is readable + if (!in->HasMore(header_size)) return false; + + // The header is ready to be read, fill in fields accordingly + if (!startup_) + curr_input_packet_.msg_type_ = in->ReadRawValue(); + curr_input_packet_.len_ = in->ReadInt(sizeof(int32_t)) - sizeof(int32_t); + + // Extend the buffer as needed + if (curr_input_packet_.len_ > in->Capacity()) { + LOG_INFO("Extended Buffer size required for packet of size %ld", + curr_input_packet_.len_); + // Allocate a larger buffer and copy bytes off from the I/O layer's buffer + curr_input_packet_.buf_ = + std::make_shared(curr_input_packet_.len_); + curr_input_packet_.extended_ = true; + } else { + curr_input_packet_.buf_ = in; + } + + curr_input_packet_.header_parsed_ = true; + return true; + } +}; + +/** + * Wrapper around an I/O layer WriteQueue to provide Postgres-specific + * helper methods. + */ +class PostgresPacketWriter { + public: + /* + * Instantiates a new PostgresPacketWriter backed by the given WriteQueue + */ + PostgresPacketWriter(WriteQueue &write_queue) : queue_(write_queue) {} + + ~PostgresPacketWriter() { + // Make sure no packet is being written on destruction, otherwise we are + // left with a malformed write buffer + PELOTON_ASSERT(curr_packet_len_ == nullptr); + } + + /** + * Write out a packet with a single byte (e.g. SSL_YES or SSL_NO). This is a + * special case since no size field is provided.
+ * @param type Type of message to write out + */ + inline void WriteSingleBytePacket(NetworkMessageType type) { + // Make sure no active packet is being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + queue_.BufferWriteRawValue(type); + } + + /** + * Begin writing a new packet. The caller fills it in with the Append methods and finishes it with EndPacket. + * @param type Type of message to write out + * @return self-reference for chaining + */ + PostgresPacketWriter &BeginPacket(NetworkMessageType type) { + // No active packet being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + queue_.BufferWriteRawValue(type); + // Remember the size field since we will need to modify it as we go along. + // It is important that our size field is contiguous and not broken between + // two buffers. + queue_.BufferWriteRawValue(0, false); + WriteBuffer &tail = *(queue_.buffers_[queue_.buffers_.size() - 1]); + curr_packet_len_ = + reinterpret_cast(&tail.buf_[tail.size_ - sizeof(int32_t)]); + return *this; + } + + inline PostgresPacketWriter &AppendRaw(const void *src, size_t len) { + queue_.BufferWriteRaw(src, len); + // Add the size field to the len of the packet. Be mindful of byte + // ordering. We switch to network ordering only when the packet is finished + *curr_packet_len_ += len; + return *this; + } + + template + inline PostgresPacketWriter &AppendRawValue(T val) { + return AppendRaw(&val, sizeof(T)); + } + + PostgresPacketWriter &AppendInt(uint8_t len, uint32_t val) { + int32_t result; + switch (len) { + case 1:result = val; + break; + case 2:result = htons(val); + break; + case 4:result = htonl(val); + break; + default: + throw NetworkProcessException( + "Error constructing packet: invalid int size"); + } + return AppendRaw(&result, len); + } + + inline PostgresPacketWriter &AppendString(const std::string &str, + bool nul_terminate = true) { + return AppendRaw(str.data(), nul_terminate ?
str.size() + 1 : str.size()); + } + + inline void EndPacket() { + PELOTON_ASSERT(curr_packet_len_ != nullptr); + // Switch to network byte ordering, add the 4 bytes of size field + *curr_packet_len_ = htonl(*curr_packet_len_ + sizeof(int32_t)); + curr_packet_len_ = nullptr; + } + private: + // We need to keep track of the size field of the current packet, + // so we can update it as more bytes are written into this packet. + uint32_t *curr_packet_len_ = nullptr; + // Underlying WriteQueue backing this writer + WriteQueue &queue_; +}; + +class PostgresWireUtilities { + public: + PostgresWireUtilities() = delete; + + static inline void SendErrorResponse( + PostgresPacketWriter &writer, + std::vector> error_status) { + writer.BeginPacket(NetworkMessageType::ERROR_RESPONSE); + for (const auto &entry : error_status) { + writer.AppendRawValue(entry.first); + writer.AppendString(entry.second); + } + // Nul-terminate packet + writer.AppendRawValue(0) + .EndPacket(); + } + + static inline void SendStartupResponse(PostgresPacketWriter &writer) { + // auth-ok + writer.BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST).EndPacket(); + + // parameter status map + for (auto &entry : parameter_status_map_) + writer.BeginPacket(NetworkMessageType::PARAMETER_STATUS) + .AppendString(entry.first) + .AppendString(entry.second) + .EndPacket(); + + // ready-for-query + SendReadyForQuery(writer, NetworkTransactionStateType::IDLE); + } + + static inline void SendReadyForQuery(PostgresPacketWriter &writer, + NetworkTransactionStateType txn_status) { + writer.BeginPacket(NetworkMessageType::READY_FOR_QUERY) + .AppendRawValue(txn_status) + .EndPacket(); + } + + static inline void SendEmptyQueryResponse(PostgresPacketWriter &writer) { + writer.BeginPacket(NetworkMessageType::EMPTY_QUERY_RESPONSE).EndPacket(); + } + + static inline void SendCommandCompleteResponse(PostgresPacketWriter &writer, + const QueryType &query_type, + int rows) { + std::string tag = QueryTypeToString(query_type); + 
switch (query_type) { + case QueryType::QUERY_INSERT:tag += " 0 " + std::to_string(rows); + break; + case QueryType::QUERY_BEGIN: + case QueryType::QUERY_COMMIT: + case QueryType::QUERY_ROLLBACK: + case QueryType::QUERY_CREATE_TABLE: + case QueryType::QUERY_CREATE_DB: + case QueryType::QUERY_CREATE_INDEX: + case QueryType::QUERY_CREATE_TRIGGER: + case QueryType::QUERY_PREPARE:break; + default:tag += " " + std::to_string(rows); + } + writer.BeginPacket(NetworkMessageType::COMMAND_COMPLETE) + .AppendString(tag) + .EndPacket(); + } + + private: + // TODO(Tianyu): It looks broken that this never changes. + // TODO(Tianyu): Also, Initialize. + static const std::unordered_map + parameter_status_map_; +}; +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/postgres_wire_protocol.h b/src/include/network/postgres_wire_protocol.h deleted file mode 100644 index 3f1ee66246d..00000000000 --- a/src/include/network/postgres_wire_protocol.h +++ /dev/null @@ -1,164 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// postgres_wire_protocol.h -// -// Identification: src/include/network/postgres_wire_protocol.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// -#pragma once -#include "common/logger.h" -#include "network/wire_protocol.h" -#include "network/postgres_network_commands.h" - -#define SSL_MESSAGE_VERNO 80877103 -#define PROTO_MAJOR_VERSION(x) ((x) >> 16) -#define MAKE_COMMAND(type) \ - std::static_pointer_cast( \ - std::make_shared(curr_input_packet_)) - -namespace peloton { -namespace network { -class PostgresWireProtocol : public WireProtocol { - public: - // TODO(Tianyu): Remove tcop when tcop refactor complete - PostgresWireProtocol(tcop::TrafficCop *tcop) : tcop_(tcop) {} - - inline Transition Process(std::shared_ptr &in, - WriteQueue &out, - 
size_t thread_id) override { - if (!BuildPacket(in)) return Transition::NEED_READ; - std::shared_ptr command = PacketToCommand(); - curr_input_packet_.Clear(); - return command->Exec(*this, out, thread_id); - } - - inline void AddCommandLineOption(std::string name, std::string val) { - cmdline_options_[name] = val; - } - - inline void FinishStartup() { startup_ = false; } - - std::shared_ptr PacketToCommand() { - if (startup_) return MAKE_COMMAND(StartupCommand); - switch (curr_input_packet_.msg_type_) { - case NetworkMessageType::SIMPLE_QUERY_COMMAND: - return MAKE_COMMAND(SimpleQueryCommand); - case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); - case NetworkMessageType::BIND_COMMAND:return MAKE_COMMAND(BindCommand); - case NetworkMessageType::DESCRIBE_COMMAND: - return MAKE_COMMAND(DescribeCommand); - case NetworkMessageType::EXECUTE_COMMAND: - return MAKE_COMMAND(ExecuteCommand); - case NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); - case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); - case NetworkMessageType::TERMINATE_COMMAND: - return MAKE_COMMAND(TerminateCommand); - case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); - default: - throw NetworkProcessException("Unexpected Packet Type: " + - std::to_string(static_cast(curr_input_packet_.msg_type_))); - } - } - // TODO(Tianyu): Remove this when tcop refactor complete - tcop::TrafficCop *tcop_; - private: - bool startup_ = true; - PostgresInputPacket curr_input_packet_{}; - std::unordered_map cmdline_options_; - - bool BuildPacket(std::shared_ptr &in) { - if (!ReadPacketHeader(in)) return false; - - size_t size_needed = curr_input_packet_.extended_ - ? 
curr_input_packet_.len_ - - curr_input_packet_.buf_->BytesAvailable() - : curr_input_packet_.len_; - if (!in->HasMore(size_needed)) return false; - - if (curr_input_packet_.extended_) - curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); - return true; - } - - bool ReadPacketHeader(std::shared_ptr &in) { - if (curr_input_packet_.header_parsed_) return true; - - // Header format: 1 byte message type (only if non-startup) - // + 4 byte message size (inclusive of these 4 bytes) - size_t header_size = startup_ ? sizeof(int32_t) : 1 + sizeof(int32_t); - // Make sure the entire header is readable - if (!in->HasMore(header_size)) return false; - - // The header is ready to be read, fill in fields accordingly - if (!startup_) - curr_input_packet_.msg_type_ = in->ReadRawValue(); - curr_input_packet_.len_ = in->ReadInt(sizeof(int32_t)) - sizeof(int32_t); - - // Extend the buffer as needed - if (curr_input_packet_.len_ > in->Capacity()) { - LOG_INFO("Extended Buffer size required for packet of size %ld", - curr_input_packet_.len_); - // Allocate a larger buffer and copy bytes off from the I/O layer's buffer - curr_input_packet_.buf_ = - std::make_shared(curr_input_packet_.len_); - curr_input_packet_.extended_ = true; - } else { - curr_input_packet_.buf_ = in; - } - - curr_input_packet_.header_parsed_ = true; - return true; - } -}; - -class PostgresWireUtilities { - public: - PostgresWireUtilities() = delete; - - static inline void SendErrorResponse( - WriteQueue &out, - std::vector> error_status) { - out.BeginPacket(NetworkMessageType::ERROR_RESPONSE); - for (const auto &entry : error_status) { - out.AppendRawValue(entry.first); - out.AppendString(entry.second); - } - // Nul-terminate packet - out.AppendRawValue(0) - .EndPacket(); - } - - static inline void SendStartupResponse(WriteQueue &out) { - // auth-ok - out.BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST).EndPacket(); - - // parameter status map - for (auto &entry : parameter_status_map_) - 
out.BeginPacket(NetworkMessageType::PARAMETER_STATUS) - .AppendString(entry.first) - .AppendString(entry.second) - .EndPacket(); - - // ready-for-query - SendReadyForQuery(NetworkTransactionStateType::IDLE, out); - } - - static inline void SendReadyForQuery(NetworkTransactionStateType txn_status, - WriteQueue &out) { - out.BeginPacket(NetworkMessageType::READY_FOR_QUERY) - .AppendRawValue(txn_status) - .EndPacket(); - } - - private: - // TODO(Tianyu): It looks broken that this never changes. - // TODO(Tianyu): Also, Initialize. - static const std::unordered_map - parameter_status_map_; -}; -} // namespace network -} // namespace peloton \ No newline at end of file diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h new file mode 100644 index 00000000000..8194ee0a29f --- /dev/null +++ b/src/include/network/protocol_interpreter.h @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// protocol_interpreter.h +// +// Identification: src/include/network/protocol_interpreter.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once +#include +#include "network/network_types.h" +#include "network/buffered_io.h" + +namespace peloton { +namespace network { + +class ProtocolInterpreter { + public: + // TODO(Tianyu): What the hell is this thread_id thingy + virtual Transition Process(std::shared_ptr &in, + WriteQueue &out, + size_t thread_id) = 0; + +}; + +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h new file mode 100644 index 00000000000..7e4a2346f9d --- /dev/null +++ b/src/include/traffic_cop/tcop.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// t=cop.h +// +// 
Identification: src/include/traffic_cop/tcop.h +// +// Copyright (c) 2015-18, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// +#pragma once + +namespace peloton { +namespace tcop { +struct JobHandle { + bool is_queueing_; +}; + + +} // namespace tcop +} // namespace peloton \ No newline at end of file diff --git a/src/network/buffered_io.cpp b/src/network/buffered_io.cpp new file mode 100644 index 00000000000..04f779d46e3 --- /dev/null +++ b/src/network/buffered_io.cpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// buffered_io.cpp +// +// Identification: src/network/buffered_io.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "network/buffered_io.h" +#include "common/exception.h" + +namespace peloton { +namespace network { +int ReadBuffer::ReadInt(uint8_t len) { + switch (len) { + case 1: return ReadRawValue(); + case 2: return ntohs(ReadRawValue()); + case 4: return ntohl(ReadRawValue()); + default: throw NetworkProcessException( + "Error when de-serializing: Invalid int size"); + } +} + +std::string ReadBuffer::ReadString(size_t len) { + if (len == 0) throw NetworkProcessException("Unexpected string size: 0"); + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + offset_ + (len - 1)); + offset_ += len; + return result; +} + +std::string ReadBuffer::ReadString() { + // search for the nul terminator + for (size_t i = offset_; i < size_; i++) { + if (buf_[i] == 0) { + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + i); + // +1 because we want to skip nul + offset_ = i + 1; + return result; + } + } + // No nul terminator found + throw NetworkProcessException("Expected nil in read buffer, none found"); +} +} // namespace network +} // namespace peloton \ 
No newline at end of file diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index c1d3928a909..24a88b885b5 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// #include "parser/postgresparser.h" -#include "network/postgres_wire_protocol.h" +#include "network/postgres_protocol_interpreter.h" #include "network/peloton_server.h" #include "network/postgres_network_commands.h" @@ -20,7 +20,7 @@ namespace peloton { namespace network { -Transition StartupCommand::Exec(PostgresWireProtocol &protocol_object, +Transition StartupCommand::Exec(PostgresProtocolInterpreter &protocol_object, WriteQueue &out, size_t) { // Always flush startup response @@ -72,29 +72,6 @@ Transition StartupCommand::Exec(PostgresWireProtocol &protocol_object, } } -Transition SimpleQueryCommand::Exec(PostgresWireProtocol &protocol_object, - WriteQueue &out, - size_t thread_id) { - out.ForceFlush(); - std::string query = input_packet_.buf_->ReadString(input_packet_.len_); - LOG_TRACE("Execute query: %s", query.c_str()); - std::unique_ptr sql_stmt_list; - try { - auto &peloton_parser = parser::PostgresParser::GetInstance(); - sql_stmt_list = peloton_parser.BuildParseTree(query); - // When the query is empty(such as ";" or ";;", still valid), - // the pare tree is empty, parser will return nullptr. 
- if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) - throw ParserException("Error Parsing SQL statement"); - } catch (Exception &e) { - protocol_object.tcop_->ProcessInvalidStatement(); - std::string error_message = e.what(); - PostgresWireUtilities::SendErrorResponse( - out, {{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); - PostgresWireUtilities::SendReadyForQuery() - - } -} } // namespace network } // namespace peloton \ No newline at end of file From 01d63a01fcc10ab3263a6a51a4e1151e936467e0 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 22 Jun 2018 16:48:58 -0400 Subject: [PATCH 12/48] Switch to templates for marshalling ints --- .../{buffered_io.cpp => netork_io_utils.cpp} | 0 src/include/network/marshal.h | 2 +- src/include/network/netork_io_utils.h | 397 ++++++++++++++++++ src/include/network/network_io_wrappers.h | 1 + .../network/postgres_network_commands.h | 25 +- .../network/postgres_protocol_interpreter.h | 110 +++-- src/include/network/protocol_interpreter.h | 2 +- src/include/network/wire_protocol.h | 29 -- src/network/buffered_io.cpp | 51 --- src/network/postgres_network_commands.cpp | 102 ++--- src/network/postgres_protocol_handler.cpp | 2 +- 11 files changed, 541 insertions(+), 180 deletions(-) rename src/codegen/util/{buffered_io.cpp => netork_io_utils.cpp} (100%) create mode 100644 src/include/network/netork_io_utils.h delete mode 100644 src/include/network/wire_protocol.h delete mode 100644 src/network/buffered_io.cpp diff --git a/src/codegen/util/buffered_io.cpp b/src/codegen/util/netork_io_utils.cpp similarity index 100% rename from src/codegen/util/buffered_io.cpp rename to src/codegen/util/netork_io_utils.cpp diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 5fc42ad3f74..8b255ae98fa 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -12,7 +12,7 @@ #pragma once -#include "network/buffered_io.h" +#include "network/netork_io_utils.h" #define 
BUFFER_INIT_SIZE 100 namespace peloton { diff --git a/src/include/network/netork_io_utils.h b/src/include/network/netork_io_utils.h new file mode 100644 index 00000000000..f953650c02a --- /dev/null +++ b/src/include/network/netork_io_utils.h @@ -0,0 +1,397 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// network_io_utils.h +// +// Identification: src/include/network/network_io_utils.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include + +#include +#include +#include "common/internal_types.h" +#include "common/exception.h" + +namespace peloton { +namespace network { +/** + * A plain old buffer with a movable cursor, the meaning of which is dependent + * on the use case. + * + * The buffer has a fix capacity and one can write a variable amount of + * meaningful bytes into it. We call this amount "size" of the buffer. + */ +class Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. + */ + inline Buffer(size_t capacity) : capacity_(capacity) { + buf_.reserve(capacity); + } + + /** + * Reset the buffer pointer and clears content + */ + inline void Reset() { + size_ = 0; + offset_ = 0; + } + + /** + * @param bytes The amount of bytes to check between the cursor and the end + * of the buffer (defaults to any) + * @return Whether there is any more bytes between the cursor and + * the end of the buffer + */ + inline bool HasMore(size_t bytes = 1) { return offset_ + bytes <= size_; } + + /** + * @return Whether the buffer is at capacity. 
(All usable space is filled + * with meaningful bytes) + */ + inline bool Full() { return size_ == Capacity(); } + + /** + * @return Iterator to the beginning of the buffer + */ + inline ByteBuf::const_iterator Begin() { return std::begin(buf_); } + + /** + * @return Capacity of the buffer (not actual size) + */ + inline size_t Capacity() const { return capacity_; } + + /** + * Shift contents to align the current cursor with start of the buffer, + * remove all bytes before the cursor. + */ + inline void MoveContentToHead() { + auto unprocessed_len = size_ - offset_; + std::memmove(&buf_[0], &buf_[offset_], unprocessed_len); + size_ = unprocessed_len; + offset_ = 0; + } + + // TODO(Tianyu): Fix this after protocol refactor +// protected: + size_t size_ = 0, offset_ = 0, capacity_; + ByteBuf buf_; + private: + friend class WriteQueue; +}; + +/** + * A buffer specialize for read + */ +class ReadBuffer : public Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. + */ + inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} + /** + * Read as many bytes as possible using SSL read + * @param context SSL context to read from + * @return the return value of ssl read + */ + inline int FillBufferFrom(SSL *context) { + ERR_clear_error(); + ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); + int err = SSL_get_error(context, bytes_read); + if (err == SSL_ERROR_NONE) size_ += bytes_read; + return err; + }; + + /** + * Read as many bytes as possible using Posix from an fd + * @param fd the file descriptor to read from + * @return the return value of posix read + */ + inline int FillBufferFrom(int fd) { + ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); + if (bytes_read > 0) size_ += bytes_read; + return (int) bytes_read; + } + + /** + * Read the specified amount of bytes off from another read buffer. 
The bytes + * will be consumed (cursor moved) on the other buffer and appended to the end + * of this buffer + * @param other The other buffer to read from + * @param size Number of bytes to read + */ + inline void FillBufferFrom(ReadBuffer &other, size_t size) { + other.Read(size, &buf_[size_]); + size_ += size; + } + + /** + * The number of bytes available to be consumed (i.e. meaningful bytes after + * current read cursor) + * @return The number of bytes available to be consumed + */ + inline size_t BytesAvailable() { return size_ - offset_; } + + /** + * Read the given number of bytes into destination, advancing cursor by that + * number. It is up to the caller to ensure that there are enough bytes + * available in the read buffer at this point. + * @param bytes Number of bytes to read + * @param dest Desired memory location to read into + */ + inline void Read(size_t bytes, void *dest) { + std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, + reinterpret_cast(dest)); + offset_ += bytes; + } + + /** + * Read an integer of specified length off of the read buffer (1, 2, + * 4, or 8 bytes). It is assumed that the bytes in the buffer are in network + * byte ordering and will be converted to the correct host ordering. It is up + * to the caller to ensure that there are enough bytes available in the read + * buffer at this point. + * @tparam T type of value to read off. Has to be size 1, 2, 4, or 8. + * @return value of integer switched from network byte order + */ + template + inline T ReadInt() { + // We only want to allow for certain type sizes to be used + // After the static assert, the compiler should be smart enough to throw + // away the other cases and only leave the relevant return statement. 
+ static_assert(sizeof(T) == 1 + || sizeof(T) == 2 + || sizeof(T) == 4 + || sizeof(T) == 8, "Invalid size for integer"); + switch (sizeof(T)) { + case 1: return ReadRawValue(); + case 2: return ntohs(ReadRawValue()); + case 4: return ntohl(ReadRawValue()); + case 8: return ntohll(ReadRawValue()); + // Will never be here due to compiler optimization + default: throw NetworkProcessException(""); + } + } + + /** + * Read a block of bytes off the read buffer as a string. + * @param len Length of the string, inclusive of nul-terminator + * @return string of specified length at head of read buffer + */ + std::string ReadString(size_t len) { + if (len == 0) throw NetworkProcessException("Unexpected string size: 0"); + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + offset_ + (len - 1)); + offset_ += len; + return result; + } + + /** + * Read a nul-terminated string off the read buffer, or throw an exception + * if no nul-terminator is found within packet range. + * @return string at head of read buffer + */ + std::string ReadString() { + // search for the nul terminator + for (size_t i = offset_; i < size_; i++) { + if (buf_[i] == 0) { + auto result = std::string(buf_.begin() + offset_, + buf_.begin() + i); + // +1 because we want to skip nul + offset_ = i + 1; + return result; + } + } + // No nul terminator found + throw NetworkProcessException("Expected nil in read buffer, none found"); + } + + /** + * Read a value of type T off of the buffer, advancing cursor by appropriate + * amount. Does NOT convert from network bytes order. It is the caller's + * responsibility to do so if needed. + * @tparam T type of value to read off. Preferably a primitive type. 
+ * @return the value of type T + */ + template + inline T ReadRawValue() { + T result; + Read(sizeof(result), &result); + return result; + } +}; + +/** + * A buffer specialized for write + */ +class WriteBuffer : public Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. + */ + inline WriteBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} + + /** + * Write as many bytes as possible using SSL write + * @param context SSL context to write out to + * @return return value of SSL write + */ + inline int WriteOutTo(SSL *context) { + ERR_clear_error(); + ssize_t bytes_written = SSL_write(context, &buf_[offset_], size_ - offset_); + int err = SSL_get_error(context, bytes_written); + if (err == SSL_ERROR_NONE) offset_ += bytes_written; + return err; + } + + /** + * Write as many bytes as possible using Posix write to fd + * @param fd File descriptor to write out to + * @return return value of Posix write + */ + inline int WriteOutTo(int fd) { + ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); + if (bytes_written > 0) offset_ += bytes_written; + return (int) bytes_written; + } + + /** + * The remaining capacity of this buffer. This value is equal to the + * maximum capacity minus the capacity already in use. + * @return Remaining capacity + */ + inline size_t RemainingCapacity() { return Capacity() - size_; } + + /** + * @param bytes Desired number of bytes to write + * @return Whether the buffer can accommodate the number of bytes given + */ + inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } + + /** + * Append the desired range into current buffer. 
+ * @param src beginning of range + * @param len length of range, in bytes + */ + inline void AppendRaw(const void *src, size_t len) { + if (len == 0) return; + auto bytes_src = reinterpret_cast(src); + std::copy(bytes_src, bytes_src + len, std::begin(buf_) + size_); + size_ += len; + } + + // TODO(Tianyu): Just for io wrappers for now. Probably can remove later. + inline void AppendRaw(ByteBuf::const_iterator src, size_t len) { + if (len == 0) return; + std::copy(src, src + len, std::begin(buf_) + size_); + size_ += len; + } + + /** + * Append the given value into the current buffer. Does NOT convert to + * network byte order. It is up to the caller to do so. + * @tparam T input type + * @param val value to write into buffer + */ + template + inline void AppendRaw(T val) { + AppendRaw(&val, sizeof(T)); + } +}; + +/** + * A WriteQueue is a series of WriteBuffers that can buffer an uncapped amount + * of writes without the need to copy and resize. + * + * It is expected that a specific protocol will wrap this to expose a better + * API for protocol-specific behavior. + */ +class WriteQueue { + public: + /** + * Instantiates a new WriteQueue. By default this holds one buffer. + */ + inline WriteQueue() { + Reset(); + } + + /** + * Reset the write queue to its default state. + */ + inline void Reset() { + buffers_.resize(1); + flush_ = false; + if (buffers_[0] == nullptr) + buffers_[0] = std::make_shared(); + else + buffers_[0]->Reset(); + } + + /** + * Force this WriteQueue to be flushed next time the network layer + * is available to do so. + */ + inline void ForceFlush() { flush_ = true; } + + /** + * Whether this WriteQueue should be flushed out to network or not. + * A WriteQueue should be flushed either when the first buffer is full + * or when manually set to do so (e.g. 
when the client is waiting for + * a small response) + * @return whether we should flush this write queue + */ + inline bool ShouldFlush() { return flush_ || buffers_.size() > 1; } + + /** + * Write len many bytes starting from src into the write queue, allocating + * a new buffer if need be. The write is split up between two buffers + * if breakup is set to true (which is by default) + * @param src write head + * @param len number of bytes to write + * @param breakup whether to split write into two buffers if need be. + */ + void BufferWriteRaw(const void *src, size_t len, bool breakup = true) { + WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); + if (tail.HasSpaceFor(len)) + tail.AppendRaw(src, len); + else { + // Only write partially if we are allowed to + size_t written = breakup ? tail.RemainingCapacity() : 0; + tail.AppendRaw(src, written); + buffers_.push_back(std::make_shared()); + BufferWriteRaw(reinterpret_cast(src) + written, len - written); + } + } + + /** + * Write val into the write queue, allocating a new buffer if need be. + * The write is split up between two buffers if breakup is set to true + * (which is by default). No conversion of byte ordering is performed. It is + * up to the caller to do so if needed. + * @tparam T type of value to write + * @param val value to write + * @param breakup whether to split write into two buffers if need be. 
+ */ + template + inline void BufferWriteRawValue(T val, bool breakup = true) { + BufferWriteRaw(&val, sizeof(T), breakup); + } + + private: + friend class PostgresPacketWriter; + std::vector> buffers_; + bool flush_ = false; +}; + +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 1b100475ffd..c091a1d37d5 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -17,6 +17,7 @@ #include #include "common/exception.h" #include "common/utility.h" +#include "network/network_types.h" #include "network/marshal.h" namespace peloton { diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index f99276a3ae2..c373d2e418e 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -20,9 +20,9 @@ #define DEFINE_COMMAND(name) \ class name : public PostgresNetworkCommand { \ public: \ - explicit name(PostgresInputPacket &input_packet) \ + explicit name(PostgresInputPacket &input_packet) \ : PostgresNetworkCommand(std::move(input_packet)) {} \ - virtual Transition Exec(PostgresProtocolInterpreter &, WriteQueue &, size_t) override; \ + virtual Transition Exec(PostgresProtocolInterpreter &, WriteQueue &, size_t) override {} \ } namespace peloton { @@ -60,17 +60,16 @@ class PostgresNetworkCommand { PostgresInputPacket input_packet_; }; -// TODO(Tianyu): Fix response types -DEFINE_COMMAND(StartupCommand); -DEFINE_COMMAND(SimpleQueryCommand); -DEFINE_COMMAND(ParseCommand); -DEFINE_COMMAND(BindCommand); -DEFINE_COMMAND(DescribeCommand); -DEFINE_COMMAND(ExecuteCommand); -DEFINE_COMMAND(SyncCommand); -DEFINE_COMMAND(CloseCommand); -DEFINE_COMMAND(TerminateCommand); -DEFINE_COMMAND(NullCommand); +//DEFINE_COMMAND(StartupCommand); +//DEFINE_COMMAND(SimpleQueryCommand); +//DEFINE_COMMAND(ParseCommand); 
+//DEFINE_COMMAND(BindCommand); +//DEFINE_COMMAND(DescribeCommand); +//DEFINE_COMMAND(ExecuteCommand); +//DEFINE_COMMAND(SyncCommand); +//DEFINE_COMMAND(CloseCommand); +//DEFINE_COMMAND(TerminateCommand); +//DEFINE_COMMAND(NullCommand); } // namespace network } // namespace peloton diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 7d03f6996d0..d5e476a41c3 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -10,10 +10,11 @@ // //===----------------------------------------------------------------------===// #pragma once +#include #include "common/logger.h" #include "network/protocol_interpreter.h" #include "network/postgres_network_commands.h" -#include "network/buffered_io.h" +#include "network/netork_io_utils.h" #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) @@ -35,29 +36,33 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { return command->Exec(*this, out, thread_id); } - inline void AddCommandLineOption(std::string name, std::string val) { - cmdline_options_[name] = val; + inline void AddCommandLineOption(const std::string &name, std::string val) { + cmdline_options_[name] = std::move(val); } inline void FinishStartup() { startup_ = false; } std::shared_ptr PacketToCommand() { - if (startup_) return MAKE_COMMAND(StartupCommand); +// if (startup_) return MAKE_COMMAND(StartupCommand); switch (curr_input_packet_.msg_type_) { - case NetworkMessageType::SIMPLE_QUERY_COMMAND: - return MAKE_COMMAND(SimpleQueryCommand); - case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); - case NetworkMessageType::BIND_COMMAND - :return MAKE_COMMAND(BindCommand); - case NetworkMessageType::DESCRIBE_COMMAND: - return MAKE_COMMAND(DescribeCommand); - case NetworkMessageType::EXECUTE_COMMAND: - return MAKE_COMMAND(ExecuteCommand); - case 
NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); - case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); - case NetworkMessageType::TERMINATE_COMMAND: - return MAKE_COMMAND(TerminateCommand); - case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); +// case NetworkMessageType::SIMPLE_QUERY_COMMAND: +// return MAKE_COMMAND(SimpleQueryCommand); +// case NetworkMessageType::PARSE_COMMAND: +// return MAKE_COMMAND(ParseCommand); +// case NetworkMessageType::BIND_COMMAND: +// return MAKE_COMMAND(BindCommand); +// case NetworkMessageType::DESCRIBE_COMMAND: +// return MAKE_COMMAND(DescribeCommand); +// case NetworkMessageType::EXECUTE_COMMAND: +// return MAKE_COMMAND(ExecuteCommand); +// case NetworkMessageType::SYNC_COMMAND: +// return MAKE_COMMAND(SyncCommand); +// case NetworkMessageType::CLOSE_COMMAND: +// return MAKE_COMMAND(CloseCommand); +// case NetworkMessageType::TERMINATE_COMMAND: +// return MAKE_COMMAND(TerminateCommand); +// case NetworkMessageType::NULL_COMMAND: +// return MAKE_COMMAND(NullCommand); default: throw NetworkProcessException("Unexpected Packet Type: " + std::to_string(static_cast(curr_input_packet_.msg_type_))); @@ -95,7 +100,7 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { // The header is ready to be read, fill in fields accordingly if (!startup_) curr_input_packet_.msg_type_ = in->ReadRawValue(); - curr_input_packet_.len_ = in->ReadInt(sizeof(int32_t)) - sizeof(int32_t); + curr_input_packet_.len_ = in->ReadInt() - sizeof(uint32_t); // Extend the buffer as needed if (curr_input_packet_.len_ > in->Capacity()) { @@ -143,7 +148,9 @@ class PostgresPacketWriter { } /** - * Begin writing a new packet. Caller can use other + * Begin writing a new packet. Caller can use other append methods to write + * contents to the packet. An explicit call to end packet must be made to + * make these writes valid. 
* @param type * @return self-reference for chaining */ @@ -161,7 +168,15 @@ class PostgresPacketWriter { return *this; } + /** + * Append raw bytes from specified memory location into the write queue. + * There must be a packet active in the writer. + * @param src memory location to write from + * @param len number of bytes to write + * @return self-reference for chaining + */ inline PostgresPacketWriter &AppendRaw(const void *src, size_t len) { + PELOTON_ASSERT(curr_packet_len_ != nullptr); queue_.BufferWriteRaw(src, len); // Add the size field to the len of the packet. Be mindful of byte // ordering. We switch to network ordering only when the packet is finished @@ -169,32 +184,61 @@ class PostgresPacketWriter { return *this; } + /** + * Append a value onto the write queue. There must be a packet active in the + * writer. No byte order conversion is performed. It is up to the caller to + * do so if needed. + * @tparam T type of value to write + * @param val value to write + * @return self-reference for chaining + */ template inline PostgresPacketWriter &AppendRawValue(T val) { return AppendRaw(&val, sizeof(T)); } - PostgresPacketWriter &AppendInt(uint8_t len, uint32_t val) { - int32_t result; - switch (len) { - case 1:result = val; - break; - case 2:result = htons(val); - break; - case 4:result = htonl(val); - break; - default: - throw NetworkProcessException( - "Error constructing packet: invalid int size"); + /** + * Append an integer of specified length onto the write queue. (1, 2, 4, or 8 + * bytes). It is assumed that these bytes need to be converted to network + * byte ordering. + * @tparam T type of value to read off. Has to be size 1, 2, 4, or 8. 
+ * @param val value to write + * @return self-reference for chaining + */ + template + PostgresPacketWriter &AppendInt(T val) { + // We only want to allow for certain type sizes to be used + // After the static assert, the compiler should be smart enough to throw + // away the other cases and only leave the relevant return statement. + static_assert(sizeof(T) == 1 + || sizeof(T) == 2 + || sizeof(T) == 4 + || sizeof(T) == 8, "Invalid size for integer"); + switch (sizeof(T)) { + case 1: return AppendRawValue(val); + case 2: return AppendRawValue(ntohs(val)); + case 4: return AppendRawValue(ntohl(val)); + case 8: return AppendRawValue(ntohll(val)); + // Will never be here due to compiler optimization + default: throw NetworkProcessException(""); } - return AppendRaw(&result, len); } + /** + * Append a string onto the write queue. + * @param str the string to append + * @param nul_terminate whether the nul terminaor should be written as well + * @return self-reference for chaining + */ inline PostgresPacketWriter &AppendString(const std::string &str, bool nul_terminate = true) { return AppendRaw(str.data(), nul_terminate ? str.size() + 1 : str.size()); } + /** + * End the packet. A packet write must be in progress and said write is not + * well-formed until this method is called. 
+ */ inline void EndPacket() { PELOTON_ASSERT(curr_packet_len_ != nullptr); // Switch to network byte ordering, add the 4 bytes of size field diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 8194ee0a29f..76d71f608c7 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -12,7 +12,7 @@ #pragma once #include #include "network/network_types.h" -#include "network/buffered_io.h" +#include "network/netork_io_utils.h" namespace peloton { namespace network { diff --git a/src/include/network/wire_protocol.h b/src/include/network/wire_protocol.h deleted file mode 100644 index 1e2b3e64ae1..00000000000 --- a/src/include/network/wire_protocol.h +++ /dev/null @@ -1,29 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// wire_protocol.h -// -// Identification: src/include/network/wire_protocol.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// -#pragma once -#include -#include "network/marshal.h" - -namespace peloton { -namespace network { - -class WireProtocol { - public: - // TODO(Tianyu): What the hell is this thread_id thingy - virtual Transition Process(std::shared_ptr &in, - WriteQueue &out, - size_t thread_id) = 0; - -}; - -} // namespace network -} // namespace peloton \ No newline at end of file diff --git a/src/network/buffered_io.cpp b/src/network/buffered_io.cpp deleted file mode 100644 index 04f779d46e3..00000000000 --- a/src/network/buffered_io.cpp +++ /dev/null @@ -1,51 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// buffered_io.cpp -// -// Identification: src/network/buffered_io.cpp -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// 
-//===----------------------------------------------------------------------===// - -#include "network/buffered_io.h" -#include "common/exception.h" - -namespace peloton { -namespace network { -int ReadBuffer::ReadInt(uint8_t len) { - switch (len) { - case 1: return ReadRawValue(); - case 2: return ntohs(ReadRawValue()); - case 4: return ntohl(ReadRawValue()); - default: throw NetworkProcessException( - "Error when de-serializing: Invalid int size"); - } -} - -std::string ReadBuffer::ReadString(size_t len) { - if (len == 0) throw NetworkProcessException("Unexpected string size: 0"); - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + offset_ + (len - 1)); - offset_ += len; - return result; -} - -std::string ReadBuffer::ReadString() { - // search for the nul terminator - for (size_t i = offset_; i < size_; i++) { - if (buf_[i] == 0) { - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + i); - // +1 because we want to skip nul - offset_ = i + 1; - return result; - } - } - // No nul terminator found - throw NetworkProcessException("Expected nil in read buffer, none found"); -} -} // namespace network -} // namespace peloton \ No newline at end of file diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 24a88b885b5..8741231c945 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -20,57 +20,57 @@ namespace peloton { namespace network { -Transition StartupCommand::Exec(PostgresProtocolInterpreter &protocol_object, - WriteQueue &out, - size_t) { - // Always flush startup response - out.ForceFlush(); - int32_t proto_version = input_packet_.buf_->ReadInt(sizeof(int32_t)); - LOG_INFO("protocol version: %d", proto_version); - if (proto_version == SSL_MESSAGE_VERNO) { - // SSL Handshake initialization - // TODO(Tianyu): This static method probably needs to be moved into - // settings manager - bool ssl_able = (PelotonServer::GetSSLLevel() 
!= SSLLevel::SSL_DISABLE); - out.WriteSingleBytePacket(ssl_able - ? NetworkMessageType::SSL_YES - : NetworkMessageType::SSL_NO); - return ssl_able ? Transition::NEED_SSL_HANDSHAKE : Transition::PROCEED; - } else { - // Normal Initialization - if (PROTO_MAJOR_VERSION(proto_version) != 3) { - // Only protocol version 3 is supported - LOG_ERROR("Protocol error: Only protocol version 3 is supported."); - PostgresWireUtilities::SendErrorResponse( - out, {{NetworkMessageType::HUMAN_READABLE_ERROR, - "Protocol Version Not Support"}}); - return Transition::TERMINATE; - } - - std::string token, value; - // TODO(Yuchen): check for more malformed cases - // Read out startup package info - while (input_packet_.buf_->HasMore()) { - token = input_packet_.buf_->ReadString(); - LOG_TRACE("Option key is %s", token.c_str()); - // TODO(Tianyu): Why does this commented out line need to be here? - // if (!input_packet_.buf_->HasMore()) break; - value = input_packet_.buf_->ReadString(); - LOG_TRACE("Option value is %s", value.c_str()); - // TODO(Tianyu): We never seem to use this crap? 
- protocol_object.AddCommandLineOption(token, value); - // TODO(Tianyu): Do this after we are done refactoring traffic cop -// if (token.compare("database") == 0) { -// traffic_cop_->SetDefaultDatabaseName(value); -// } - } - - // Startup Response, for now we do not do any authentication - PostgresWireUtilities::SendStartupResponse(out); - protocol_object.FinishStartup(); - return Transition::PROCEED; - } -} +//Transition StartupCommand::Exec(PostgresProtocolInterpreter &protocol_object, +// WriteQueue &out, +// size_t) { +// // Always flush startup response +// out.ForceFlush(); +// int32_t proto_version = input_packet_.buf_->ReadInt(); +// LOG_INFO("protocol version: %d", proto_version); +// if (proto_version == SSL_MESSAGE_VERNO) { +// // SSL Handshake initialization +// // TODO(Tianyu): This static method probably needs to be moved into +// // settings manager +// bool ssl_able = (PelotonServer::GetSSLLevel() != SSLLevel::SSL_DISABLE); +// out.WriteSingleBytePacket(ssl_able +// ? NetworkMessageType::SSL_YES +// : NetworkMessageType::SSL_NO); +// return ssl_able ? Transition::NEED_SSL_HANDSHAKE : Transition::PROCEED; +// } else { +// // Normal Initialization +// if (PROTO_MAJOR_VERSION(proto_version) != 3) { +// // Only protocol version 3 is supported +// LOG_ERROR("Protocol error: Only protocol version 3 is supported."); +// PostgresWireUtilities::SendErrorResponse( +// out, {{NetworkMessageType::HUMAN_READABLE_ERROR, +// "Protocol Version Not Support"}}); +// return Transition::TERMINATE; +// } +// +// std::string token, value; +// // TODO(Yuchen): check for more malformed cases +// // Read out startup package info +// while (input_packet_.buf_->HasMore()) { +// token = input_packet_.buf_->ReadString(); +// LOG_TRACE("Option key is %s", token.c_str()); +// // TODO(Tianyu): Why does this commented out line need to be here? 
+// // if (!input_packet_.buf_->HasMore()) break; +// value = input_packet_.buf_->ReadString(); +// LOG_TRACE("Option value is %s", value.c_str()); +// // TODO(Tianyu): We never seem to use this crap? +// protocol_object.AddCommandLineOption(token, value); +// // TODO(Tianyu): Do this after we are done refactoring traffic cop +//// if (token.compare("database") == 0) { +//// traffic_cop_->SetDefaultDatabaseName(value); +//// } +// } +// +// // Startup Response, for now we do not do any authentication +// PostgresWireUtilities::SendStartupResponse(out); +// protocol_object.FinishStartup(); +// return Transition::PROCEED; +// } +//} } // namespace network diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index aa61fece3aa..a4dde2a4468 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -934,7 +934,7 @@ bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, // get packet size from the header // extract packet contents size // content lengths should exclude the length bytes - rpkt.len = rbuf.ReadInt(sizeof(int32_t)) - sizeof(uint32_t); + rpkt.len = rbuf.ReadInt() - sizeof(uint32_t); // do we need to use the extended buffer for this packet? rpkt.is_extended = (rpkt.len > rbuf.Capacity()); From 4014dad73fc32561ccbc11854f07d45599d8305c Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Sat, 23 Jun 2018 16:58:50 -0400 Subject: [PATCH 13/48] Move things around a bit more. 
Make layer separation clear --- src/codegen/util/netork_io_utils.cpp | 100 ----- .../util/{buffer.cpp => network_io_utils.cpp} | 0 src/include/network/buffered_io.h | 353 ------------------ src/include/network/connection_handle.h | 7 +- src/include/network/marshal.h | 2 +- .../{netork_io_utils.h => network_io_utils.h} | 0 .../network/postgres_network_commands.h | 66 ++-- .../network/postgres_protocol_interpreter.h | 279 ++------------ src/include/network/postgres_protocol_utils.h | 211 +++++++++++ src/include/network/protocol_interpreter.h | 2 +- src/include/traffic_cop/tcop.h | 44 ++- src/network/connection_handle.cpp | 25 +- src/network/postgres_network_commands.cpp | 105 +++--- 13 files changed, 391 insertions(+), 803 deletions(-) delete mode 100644 src/codegen/util/netork_io_utils.cpp rename src/codegen/util/{buffer.cpp => network_io_utils.cpp} (100%) delete mode 100644 src/include/network/buffered_io.h rename src/include/network/{netork_io_utils.h => network_io_utils.h} (100%) create mode 100644 src/include/network/postgres_protocol_utils.h diff --git a/src/codegen/util/netork_io_utils.cpp b/src/codegen/util/netork_io_utils.cpp deleted file mode 100644 index 8fc1e567b00..00000000000 --- a/src/codegen/util/netork_io_utils.cpp +++ /dev/null @@ -1,100 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// buffer.cpp -// -// Identification: src/codegen/util/buffer.cpp -// -// Copyright (c) 2015-2017, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include "codegen/util/buffer.h" - -#include - -#include "common/logger.h" -#include "common/timer.h" -#include "storage/backend_manager.h" - -namespace peloton { -namespace codegen { -namespace util { - -Buffer::Buffer() - : buffer_start_(nullptr), buffer_pos_(nullptr), buffer_end_(nullptr) { - auto &backend_manager = storage::BackendManager::GetInstance(); - buffer_start_ = 
reinterpret_cast( - backend_manager.Allocate(BackendType::MM, kInitialBufferSize)); - buffer_pos_ = buffer_start_; - buffer_end_ = buffer_start_ + kInitialBufferSize; - - LOG_DEBUG("Initialized buffer with size %.2lf KB", - kInitialBufferSize / 1024.0); -} - -Buffer::~Buffer() { - if (buffer_start_ != nullptr) { - LOG_DEBUG("Releasing %.2lf KB of memory", AllocatedSpace() / 1024.0); - auto &backend_manager = storage::BackendManager::GetInstance(); - backend_manager.Release(BackendType::MM, buffer_start_); - } - buffer_start_ = buffer_pos_ = buffer_end_ = nullptr; -} - -void Buffer::Init(Buffer &buffer) { new (&buffer) Buffer(); } - -void Buffer::Destroy(Buffer &buffer) { buffer.~Buffer(); } - -char *Buffer::Append(uint32_t num_bytes) { - MakeRoomForBytes(num_bytes); - char *ret = buffer_pos_; - buffer_pos_ += num_bytes; - return ret; -} - -void Buffer::Reset() { buffer_pos_ = buffer_start_; } - -void Buffer::MakeRoomForBytes(uint64_t num_bytes) { - bool has_room = - (buffer_start_ != nullptr && buffer_pos_ + num_bytes < buffer_end_); - if (has_room) { - return; - } - - // Need to allocate some space - uint64_t curr_alloc_size = AllocatedSpace(); - uint64_t curr_used_size = UsedSpace(); - - // Ensure the current size is a power of two - PELOTON_ASSERT(curr_alloc_size % 2 == 0); - - // Allocate double the buffer room - uint64_t next_alloc_size = curr_alloc_size; - do { - next_alloc_size *= 2; - } while (next_alloc_size < num_bytes); - LOG_DEBUG("Resizing buffer from %.2lf bytes to %.2lf KB ...", - curr_alloc_size / 1024.0, next_alloc_size / 1024.0); - - auto &backend_manager = storage::BackendManager::GetInstance(); - auto *new_buffer = reinterpret_cast( - backend_manager.Allocate(BackendType::MM, next_alloc_size)); - - // Now copy the previous buffer into the new area - PELOTON_MEMCPY(new_buffer, buffer_start_, curr_used_size); - - // Set pointers - char *old_buffer_start = buffer_start_; - buffer_start_ = new_buffer; - buffer_pos_ = buffer_start_ + curr_used_size; 
- buffer_end_ = buffer_start_ + next_alloc_size; - - // Release old buffer - backend_manager.Release(BackendType::MM, old_buffer_start); -} - -} // namespace util -} // namespace codegen -} // namespace peloton diff --git a/src/codegen/util/buffer.cpp b/src/codegen/util/network_io_utils.cpp similarity index 100% rename from src/codegen/util/buffer.cpp rename to src/codegen/util/network_io_utils.cpp diff --git a/src/include/network/buffered_io.h b/src/include/network/buffered_io.h deleted file mode 100644 index 3aff3535f6c..00000000000 --- a/src/include/network/buffered_io.h +++ /dev/null @@ -1,353 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// buffered_io.h -// -// Identification: src/include/network/buffered_io.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once -#include -#include - -#include -#include -#include "common/internal_types.h" - -namespace peloton { -namespace network { -/** - * A plain old buffer with a movable cursor, the meaning of which is dependent - * on the use case. - * - * The buffer has a fix capacity and one can write a variable amount of - * meaningful bytes into it. We call this amount "size" of the buffer. - */ -class Buffer { - public: - /** - * Instantiates a new buffer and reserve capacity many bytes. 
- */ - inline Buffer(size_t capacity) : capacity_(capacity) { - buf_.reserve(capacity); - } - - /** - * Reset the buffer pointer and clears content - */ - inline void Reset() { - size_ = 0; - offset_ = 0; - } - - /** - * @param bytes The amount of bytes to check between the cursor and the end - * of the buffer (defaults to any) - * @return Whether there is any more bytes between the cursor and - * the end of the buffer - */ - inline bool HasMore(size_t bytes = 1) { return offset_ + bytes <= size_; } - - /** - * @return Whether the buffer is at capacity. (All usable space is filled - * with meaningful bytes) - */ - inline bool Full() { return size_ == Capacity(); } - - /** - * @return Iterator to the beginning of the buffer - */ - inline ByteBuf::const_iterator Begin() { return std::begin(buf_); } - - /** - * @return Capacity of the buffer (not actual size) - */ - inline size_t Capacity() const { return capacity_; } - - /** - * Shift contents to align the current cursor with start of the buffer, - * remove all bytes before the cursor. - */ - inline void MoveContentToHead() { - auto unprocessed_len = size_ - offset_; - std::memmove(&buf_[0], &buf_[offset_], unprocessed_len); - size_ = unprocessed_len; - offset_ = 0; - } - - // TODO(Tianyu): Fix this after protocol refactor -// protected: - size_t size_ = 0, offset_ = 0, capacity_; - ByteBuf buf_; - private: - friend class WriteQueue; -}; - -/** - * A buffer specialize for read - */ -class ReadBuffer : public Buffer { - public: - /** - * Instantiates a new buffer and reserve capacity many bytes. 
- */ - inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) - : Buffer(capacity) {} - /** - * Read as many bytes as possible using SSL read - * @param context SSL context to read from - * @return the return value of ssl read - */ - inline int FillBufferFrom(SSL *context) { - ERR_clear_error(); - ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); - int err = SSL_get_error(context, bytes_read); - if (err == SSL_ERROR_NONE) size_ += bytes_read; - return err; - }; - - /** - * Read as many bytes as possible using Posix from an fd - * @param fd the file descriptor to read from - * @return the return value of posix read - */ - inline int FillBufferFrom(int fd) { - ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); - if (bytes_read > 0) size_ += bytes_read; - return (int) bytes_read; - } - - /** - * Read the specified amount of bytes off from another read buffer. The bytes - * will be consumed (cursor moved) on the other buffer and appended to the end - * of this buffer - * @param other The other buffer to read from - * @param size Number of bytes to read - */ - inline void FillBufferFrom(ReadBuffer &other, size_t size) { - other.Read(size, &buf_[size_]); - size_ += size; - } - - /** - * The number of bytes available to be consumed (i.e. meaningful bytes after - * current read cursor) - * @return The number of bytes available to be consumed - */ - inline size_t BytesAvailable() { return size_ - offset_; } - - /** - * Read the given number of bytes into destination, advancing cursor by that - * number. It is up to the caller to ensure that there are enough bytes - * available in the read buffer at this point. 
- * @param bytes Number of bytes to read - * @param dest Desired memory location to read into - */ - inline void Read(size_t bytes, void *dest) { - std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, - reinterpret_cast(dest)); - offset_ += bytes; - } - - /** - * Read an integer of specified length off of the read buffer (1, 2, - * or 4 bytes). It is assumed that the bytes in the buffer are in network - * byte ordering and will be converted to the correct host ordering. It is up - * to the caller to ensure that there are enough bytes available in the read - * buffer at this point. - * @param len Length of the integer, either 1, 2, or 4 bytes. - * @return value of integer switched from network byte order - */ - int ReadInt(uint8_t len); - - /** - * Read a block of bytes off the read buffer as a string. - * @param len Length of the string, inclusive of nul-terminator - * @return string of specified length at head of read buffer - */ - std::string ReadString(size_t len); - - /** - * Read a nul-terminated string off the read buffer, or throw an exception - * if no nul-terminator is found within packet range. - * @return string at head of read buffer - */ - std::string ReadString(); - - /** - * Read a value of type T off of the buffer, advancing cursor by appropriate - * amount. Does NOT convert from network bytes order. It is the caller's - * responsibility to do so if needed. - * @tparam T type of value to read off. Preferably a primitive type. - * @return the value of type T - */ - template - inline T ReadRawValue() { - T result; - Read(sizeof(result), &result); - return result; - } -}; - -/** - * A buffer specialized for write - */ -class WriteBuffer : public Buffer { - public: - /** - * Instantiates a new buffer and reserve capacity many bytes. 
- */ - inline WriteBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) - : Buffer(capacity) {} - - /** - * Write as many bytes as possible using SSL write - * @param context SSL context to write out to - * @return return value of SSL write - */ - inline int WriteOutTo(SSL *context) { - ERR_clear_error(); - ssize_t bytes_written = SSL_write(context, &buf_[offset_], size_ - offset_); - int err = SSL_get_error(context, bytes_written); - if (err == SSL_ERROR_NONE) offset_ += bytes_written; - return err; - } - - /** - * Write as many bytes as possible using Posix write to fd - * @param fd File descriptor to write out to - * @return return value of Posix write - */ - inline int WriteOutTo(int fd) { - ssize_t bytes_written = write(fd, &buf_[offset_], size_ - offset_); - if (bytes_written > 0) offset_ += bytes_written; - return (int) bytes_written; - } - - /** - * The remaining capacity of this buffer. This value is equal to the - * maximum capacity minus the capacity already in use. - * @return Remaining capacity - */ - inline size_t RemainingCapacity() { return Capacity() - size_; } - - /** - * @param bytes Desired number of bytes to write - * @return Whether the buffer can accommodate the number of bytes given - */ - inline bool HasSpaceFor(size_t bytes) { return RemainingCapacity() >= bytes; } - - /** - * Append the desired range into current buffer. - * @param src beginning of range - * @param len length of range, in bytes - */ - inline void AppendRaw(const void *src, size_t len) { - if (len == 0) return; - auto bytes_src = reinterpret_cast(src); - std::copy(bytes_src, bytes_src + len, std::begin(buf_) + size_); - size_ += len; - } - - /** - * Append the given value into the current buffer. Does NOT convert to - * network byte order. It is up to the caller to do so. 
- * @tparam T input type - * @param val value to write into buffer - */ - template - inline void AppendRaw(T val) { - AppendRaw(&val, sizeof(T)); - } -}; - -/** - * A WriteQueue is a series of WriteBuffers that can buffer an uncapped amount - * of writes without the need to copy and resize. - * - * It is expected that a specific protocol will wrap this to expose a better - * API for protocol-specific behavior. - */ -class WriteQueue { - public: - /** - * Instantiates a new WriteQueue. By default this holds one buffer. - */ - inline WriteQueue() { - Reset(); - } - - /** - * Reset the write queue to its default state. - */ - inline void Reset() { - buffers_.resize(1); - flush_ = false; - if (buffers_[0] == nullptr) - buffers_[0] = std::make_shared(); - else - buffers_[0]->Reset(); - } - - /** - * Force this WriteQueue to be flushed next time the network layer - * is available to do so. - */ - inline void ForceFlush() { flush_ = true; } - - /** - * Whether this WriteQueue should be flushed out to network or not. - * A WriteQueue should be flushed either when the first buffer is full - * or when manually set to do so (e.g. when the client is waiting for - * a small response) - * @return whether we should flush this write queue - */ - inline bool ShouldFlush() { return flush_ || buffers_.size() > 1; } - - /** - * Write len many bytes starting from src into the write queue, allocating - * a new buffer if need be. The write is split up between two buffers - * if breakup is set to true (which is by default) - * @param src write head - * @param len number of bytes to write - * @param breakup whether to split write into two buffers if need be. - */ - void BufferWriteRaw(const void *src, size_t len, bool breakup = true) { - WriteBuffer &tail = *(buffers_[buffers_.size() - 1]); - if (tail.HasSpaceFor(len)) - tail.AppendRaw(src, len); - else { - // Only write partially if we are allowed to - size_t written = breakup ? 
tail.RemainingCapacity() : 0; - tail.AppendRaw(src, written); - buffers_.push_back(std::make_shared()); - BufferWriteRaw(reinterpret_cast(src) + written, len - written); - } - } - - /** - * Write val into the write queue, allocating a new buffer if need be. - * The write is split up between two buffers if breakup is set to true - * (which is by default). No conversion of byte ordering is performed. It is - * up to the caller to do so if needed. - * @tparam T type of value to write - * @param val value to write - * @param breakup whether to split write into two buffers if need be. - */ - template - inline void BufferWriteRawValue(T val, bool breakup = true) { - BufferWriteRaw(&val, sizeof(T), breakup); - } - - private: - friend class PostgresPacketWriter; - std::vector> buffers_; - bool flush_ = false; -}; - -} // namespace network -} // namespace peloton \ No newline at end of file diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index f82c838ada5..08467b95ab3 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -43,9 +43,10 @@ namespace peloton { namespace network { /** - * @brief A ConnectionHandle encapsulates all information about a client - * connection for its entire duration. This includes a state machine and the - * necessary libevent infrastructure for a handler to work on this connection. + * A ConnectionHandle encapsulates all information we need to do IO about + * a client connection for its entire duration. This includes a state machine + * and the necessary libevent infrastructure for a handler to work on this + * connection. 
*/ class ConnectionHandle { public: diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 8b255ae98fa..be207594af4 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -12,7 +12,7 @@ #pragma once -#include "network/netork_io_utils.h" +#include "network/network_io_utils.h" #define BUFFER_INIT_SIZE 100 namespace peloton { diff --git a/src/include/network/netork_io_utils.h b/src/include/network/network_io_utils.h similarity index 100% rename from src/include/network/netork_io_utils.h rename to src/include/network/network_io_utils.h diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index c373d2e418e..70fcce7bf05 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -10,19 +10,23 @@ // //===----------------------------------------------------------------------===// #pragma once +#include #include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" #include "network/network_types.h" #include "traffic_cop/traffic_cop.h" #include "network/marshal.h" +#include "network/postgres_protocol_utils.h" -#define DEFINE_COMMAND(name) \ -class name : public PostgresNetworkCommand { \ - public: \ - explicit name(PostgresInputPacket &input_packet) \ - : PostgresNetworkCommand(std::move(input_packet)) {} \ - virtual Transition Exec(PostgresProtocolInterpreter &, WriteQueue &, size_t) override {} \ +#define DEFINE_COMMAND(name, flush) \ +class name : public PostgresNetworkCommand { \ + public: \ + explicit name(std::shared_ptr in) \ + : PostgresNetworkCommand(std::move(in), flush) {} \ + virtual Transition Exec(PostgresProtocolInterpreter &, \ + PostgresPacketWriter &, \ + size_t) override; \ } namespace peloton { @@ -30,46 +34,32 @@ namespace network { class PostgresProtocolInterpreter; -struct PostgresInputPacket { - NetworkMessageType msg_type_ = NetworkMessageType::NULL_COMMAND; 
- size_t len_ = 0; - std::shared_ptr buf_; - bool header_parsed_ = false, extended_ = false; - - PostgresInputPacket() = default; - PostgresInputPacket(const PostgresInputPacket &) = default; - PostgresInputPacket(PostgresInputPacket &&) = default; - - inline void Clear() { - msg_type_ = NetworkMessageType::NULL_COMMAND; - len_ = 0; - buf_ = nullptr; - header_parsed_ = false; - } -}; - class PostgresNetworkCommand { public: virtual Transition Exec(PostgresProtocolInterpreter &protocol_obj, - WriteQueue &out, + PostgresPacketWriter &out, size_t thread_id) = 0; + + inline bool FlushOnComplete() { return flush_on_complete_; } protected: - PostgresNetworkCommand(PostgresInputPacket input_packet) - : input_packet_(input_packet) {} + explicit PostgresNetworkCommand(std::shared_ptr in, bool flush) + : in_(std::move(in)), flush_on_complete_(flush) {} - PostgresInputPacket input_packet_; + std::shared_ptr in_; + private: + bool flush_on_complete_; }; -//DEFINE_COMMAND(StartupCommand); -//DEFINE_COMMAND(SimpleQueryCommand); -//DEFINE_COMMAND(ParseCommand); -//DEFINE_COMMAND(BindCommand); -//DEFINE_COMMAND(DescribeCommand); -//DEFINE_COMMAND(ExecuteCommand); -//DEFINE_COMMAND(SyncCommand); -//DEFINE_COMMAND(CloseCommand); -//DEFINE_COMMAND(TerminateCommand); -//DEFINE_COMMAND(NullCommand); +DEFINE_COMMAND(StartupCommand, true); +DEFINE_COMMAND(SimpleQueryCommand, true); +DEFINE_COMMAND(ParseCommand, false); +DEFINE_COMMAND(BindCommand, false); +DEFINE_COMMAND(DescribeCommand, false); +DEFINE_COMMAND(ExecuteCommand, false); +DEFINE_COMMAND(SyncCommand, true); +DEFINE_COMMAND(CloseCommand, false); +DEFINE_COMMAND(TerminateCommand, true); +DEFINE_COMMAND(NullCommand, true); } // namespace network } // namespace peloton diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index d5e476a41c3..1ca8584f344 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ 
b/src/include/network/postgres_protocol_interpreter.h @@ -14,13 +14,12 @@ #include "common/logger.h" #include "network/protocol_interpreter.h" #include "network/postgres_network_commands.h" -#include "network/netork_io_utils.h" +#include "network/network_io_utils.h" +#include "traffic_cop/tcop.h" -#define SSL_MESSAGE_VERNO 80877103 -#define PROTO_MAJOR_VERSION(x) ((x) >> 16) #define MAKE_COMMAND(type) \ std::static_pointer_cast( \ - std::make_shared(curr_input_packet_)) + std::make_shared(std::move(curr_input_packet_.buf_))) namespace peloton { namespace network { @@ -30,52 +29,30 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { inline Transition Process(std::shared_ptr &in, WriteQueue &out, size_t thread_id) override { - if (!BuildPacket(in)) return Transition::NEED_READ; + if (!TryBuildPacket(in)) return Transition::NEED_READ; std::shared_ptr command = PacketToCommand(); curr_input_packet_.Clear(); - return command->Exec(*this, out, thread_id); + PostgresPacketWriter writer(out); + if (command->FlushOnComplete()) out.ForceFlush(); + return command->Exec(*this, writer, thread_id); } - inline void AddCommandLineOption(const std::string &name, std::string val) { - cmdline_options_[name] = std::move(val); + inline void AddCmdlineOption(std::string key, std::string value) { + cmdline_options_[key] = std::move(value); } inline void FinishStartup() { startup_ = false; } - std::shared_ptr PacketToCommand() { -// if (startup_) return MAKE_COMMAND(StartupCommand); - switch (curr_input_packet_.msg_type_) { -// case NetworkMessageType::SIMPLE_QUERY_COMMAND: -// return MAKE_COMMAND(SimpleQueryCommand); -// case NetworkMessageType::PARSE_COMMAND: -// return MAKE_COMMAND(ParseCommand); -// case NetworkMessageType::BIND_COMMAND: -// return MAKE_COMMAND(BindCommand); -// case NetworkMessageType::DESCRIBE_COMMAND: -// return MAKE_COMMAND(DescribeCommand); -// case NetworkMessageType::EXECUTE_COMMAND: -// return MAKE_COMMAND(ExecuteCommand); -// case 
NetworkMessageType::SYNC_COMMAND: -// return MAKE_COMMAND(SyncCommand); -// case NetworkMessageType::CLOSE_COMMAND: -// return MAKE_COMMAND(CloseCommand); -// case NetworkMessageType::TERMINATE_COMMAND: -// return MAKE_COMMAND(TerminateCommand); -// case NetworkMessageType::NULL_COMMAND: -// return MAKE_COMMAND(NullCommand); - default: - throw NetworkProcessException("Unexpected Packet Type: " + - std::to_string(static_cast(curr_input_packet_.msg_type_))); - } - } + inline tcop::ClientProcessState &ClientProcessState() { return state_; } private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; std::unordered_map cmdline_options_; + tcop::ClientProcessState state_; - bool BuildPacket(std::shared_ptr &in) { - if (!ReadPacketHeader(in)) return false; + bool TryBuildPacket(std::shared_ptr &in) { + if (!TryReadPacketHeader(in)) return false; size_t size_needed = curr_input_packet_.extended_ ? curr_input_packet_.len_ @@ -83,12 +60,14 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { : curr_input_packet_.len_; if (!in->HasMore(size_needed)) return false; + // copy bytes only if the packet is longer than the read buffer, + // otherwise we can use the read buffer to save space if (curr_input_packet_.extended_) curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); return true; } - bool ReadPacketHeader(std::shared_ptr &in) { + bool TryReadPacketHeader(std::shared_ptr &in) { if (curr_input_packet_.header_parsed_) return true; // Header format: 1 byte message type (only if non-startup) @@ -117,212 +96,34 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { curr_input_packet_.header_parsed_ = true; return true; } -}; - -/** - * Wrapper around an I/O layer WriteQueue to provide Postgres-sprcific - * helper methods. 
- */ -class PostgresPacketWriter { - public: - /* - * Instantiates a new PostgresPacketWriter backed by the given WriteQueue - */ - PostgresPacketWriter(WriteQueue &write_queue) : queue_(write_queue) {} - - ~PostgresPacketWriter() { - // Make sure no packet is being written on destruction, otherwise we are - // malformed write buffer - PELOTON_ASSERT(curr_packet_len_ == nullptr); - } - /** - * Write out a packet with a single byte (e.g. SSL_YES or SSL_NO). This is a - * special case since no size field is provided. - * @param type Type of message to write out - */ - inline void WriteSingleBytePacket(NetworkMessageType type) { - // Make sure no active packet being constructed - PELOTON_ASSERT(curr_packet_len_ == nullptr); - queue_.BufferWriteRawValue(type); - } - - /** - * Begin writing a new packet. Caller can use other append methods to write - * contents to the packet. An explicit call to end packet must be made to - * make these writes valid. - * @param type - * @return self-reference for chaining - */ - PostgresPacketWriter &BeginPacket(NetworkMessageType type) { - // No active packet being constructed - PELOTON_ASSERT(curr_packet_len_ == nullptr); - queue_.BufferWriteRawValue(type); - // Remember the size field since we will need to modify it as we go along. - // It is important that our size field is contiguous and not broken between - // two buffers. - queue_.BufferWriteRawValue(0, false); - WriteBuffer &tail = *(queue_.buffers_[queue_.buffers_.size() - 1]); - curr_packet_len_ = - reinterpret_cast(&tail.buf_[tail.size_ - sizeof(int32_t)]); - return *this; - } - - /** - * Append raw bytes from specified memory location into the write queue. - * There must be a packet active in the writer. 
- * @param src memory location to write from - * @param len number of bytes to write - * @return self-reference for chaining - */ - inline PostgresPacketWriter &AppendRaw(const void *src, size_t len) { - PELOTON_ASSERT(curr_packet_len_ != nullptr); - queue_.BufferWriteRaw(src, len); - // Add the size field to the len of the packet. Be mindful of byte - // ordering. We switch to network ordering only when the packet is finished - *curr_packet_len_ += len; - return *this; - } - - /** - * Append a value onto the write queue. There must be a packet active in the - * writer. No byte order conversion is performed. It is up to the caller to - * do so if needed. - * @tparam T type of value to write - * @param val value to write - * @return self-reference for chaining - */ - template - inline PostgresPacketWriter &AppendRawValue(T val) { - return AppendRaw(&val, sizeof(T)); - } - - /** - * Append an integer of specified length onto the write queue. (1, 2, 4, or 8 - * bytes). It is assumed that these bytes need to be converted to network - * byte ordering. - * @tparam T type of value to read off. Has to be size 1, 2, 4, or 8. - * @param val value to write - * @return self-reference for chaining - */ - template - PostgresPacketWriter &AppendInt(T val) { - // We only want to allow for certain type sizes to be used - // After the static assert, the compiler should be smart enough to throw - // away the other cases and only leave the relevant return statement. 
- static_assert(sizeof(T) == 1 - || sizeof(T) == 2 - || sizeof(T) == 4 - || sizeof(T) == 8, "Invalid size for integer"); - switch (sizeof(T)) { - case 1: return AppendRawValue(val); - case 2: return AppendRawValue(ntohs(val)); - case 4: return AppendRawValue(ntohl(val)); - case 8: return AppendRawValue(ntohll(val)); - // Will never be here due to compiler optimization - default: throw NetworkProcessException(""); + std::shared_ptr PacketToCommand() { + if (startup_) return MAKE_COMMAND(StartupCommand); + switch (curr_input_packet_.msg_type_) { + case NetworkMessageType::SIMPLE_QUERY_COMMAND: + return MAKE_COMMAND(SimpleQueryCommand); + case NetworkMessageType::PARSE_COMMAND: + return MAKE_COMMAND(ParseCommand); + case NetworkMessageType::BIND_COMMAND: + return MAKE_COMMAND(BindCommand); + case NetworkMessageType::DESCRIBE_COMMAND: + return MAKE_COMMAND(DescribeCommand); + case NetworkMessageType::EXECUTE_COMMAND: + return MAKE_COMMAND(ExecuteCommand); + case NetworkMessageType::SYNC_COMMAND: + return MAKE_COMMAND(SyncCommand); + case NetworkMessageType::CLOSE_COMMAND: + return MAKE_COMMAND(CloseCommand); + case NetworkMessageType::TERMINATE_COMMAND: + return MAKE_COMMAND(TerminateCommand); + case NetworkMessageType::NULL_COMMAND: + return MAKE_COMMAND(NullCommand); + default: + throw NetworkProcessException("Unexpected Packet Type: " + + std::to_string(static_cast(curr_input_packet_.msg_type_))); } } - - /** - * Append a string onto the write queue. - * @param str the string to append - * @param nul_terminate whether the nul terminaor should be written as well - * @return self-reference for chaining - */ - inline PostgresPacketWriter &AppendString(const std::string &str, - bool nul_terminate = true) { - return AppendRaw(str.data(), nul_terminate ? str.size() + 1 : str.size()); - } - - /** - * End the packet. A packet write must be in progress and said write is not - * well-formed until this method is called. 
- */ - inline void EndPacket() { - PELOTON_ASSERT(curr_packet_len_ != nullptr); - // Switch to network byte ordering, add the 4 bytes of size field - *curr_packet_len_ = htonl(*curr_packet_len_ + sizeof(int32_t)); - curr_packet_len_ = nullptr; - } - private: - // We need to keep track of the size field of the current packet, - // so we can update it as more bytes are written into this packet. - uint32_t *curr_packet_len_ = nullptr; - // Underlying WriteQueue backing this writer - WriteQueue &queue_; }; -class PostgresWireUtilities { - public: - PostgresWireUtilities() = delete; - - static inline void SendErrorResponse( - PostgresPacketWriter &writer, - std::vector> error_status) { - writer.BeginPacket(NetworkMessageType::ERROR_RESPONSE); - for (const auto &entry : error_status) { - writer.AppendRawValue(entry.first); - writer.AppendString(entry.second); - } - // Nul-terminate packet - writer.AppendRawValue(0) - .EndPacket(); - } - - static inline void SendStartupResponse(PostgresPacketWriter &writer) { - // auth-ok - writer.BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST).EndPacket(); - - // parameter status map - for (auto &entry : parameter_status_map_) - writer.BeginPacket(NetworkMessageType::PARAMETER_STATUS) - .AppendString(entry.first) - .AppendString(entry.second) - .EndPacket(); - - // ready-for-query - SendReadyForQuery(writer, NetworkTransactionStateType::IDLE); - } - - static inline void SendReadyForQuery(PostgresPacketWriter &writer, - NetworkTransactionStateType txn_status) { - writer.BeginPacket(NetworkMessageType::READY_FOR_QUERY) - .AppendRawValue(txn_status) - .EndPacket(); - } - - static inline void SendEmptyQueryResponse(PostgresPacketWriter &writer) { - writer.BeginPacket(NetworkMessageType::EMPTY_QUERY_RESPONSE).EndPacket(); - } - - static inline void SendCommandCompleteResponse(PostgresPacketWriter &writer, - const QueryType &query_type, - int rows) { - std::string tag = QueryTypeToString(query_type); - switch (query_type) { - case 
QueryType::QUERY_INSERT:tag += " 0 " + std::to_string(rows); - break; - case QueryType::QUERY_BEGIN: - case QueryType::QUERY_COMMIT: - case QueryType::QUERY_ROLLBACK: - case QueryType::QUERY_CREATE_TABLE: - case QueryType::QUERY_CREATE_DB: - case QueryType::QUERY_CREATE_INDEX: - case QueryType::QUERY_CREATE_TRIGGER: - case QueryType::QUERY_PREPARE:break; - default:tag += " " + std::to_string(rows); - } - writer.BeginPacket(NetworkMessageType::COMMAND_COMPLETE) - .AppendString(tag) - .EndPacket(); - } - - private: - // TODO(Tianyu): It looks broken that this never changes. - // TODO(Tianyu): Also, Initialize. - static const std::unordered_map - parameter_status_map_; -}; } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h new file mode 100644 index 00000000000..84ad424fc89 --- /dev/null +++ b/src/include/network/postgres_protocol_utils.h @@ -0,0 +1,211 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_protocol_utils.h +// +// Identification: src/include/network/postgres_protocol_utils.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once +#include "network/network_io_utils.h" + +namespace peloton { +namespace network { + +// TODO(Tianyu): It looks very broken that this never changes. +// TODO(Tianyu): Also, Initialize. 
+const std::unordered_map + parameter_status_map; + +/** + * Encapsulates an input packet + */ +struct PostgresInputPacket { + NetworkMessageType msg_type_ = NetworkMessageType::NULL_COMMAND; + size_t len_ = 0; + std::shared_ptr buf_; + bool header_parsed_ = false, extended_ = false; + + PostgresInputPacket() = default; + PostgresInputPacket(const PostgresInputPacket &) = default; + PostgresInputPacket(PostgresInputPacket &&) = default; + + inline void Clear() { + msg_type_ = NetworkMessageType::NULL_COMMAND; + len_ = 0; + buf_ = nullptr; + header_parsed_ = false; + } +}; + +/** + * Wrapper around an I/O layer WriteQueue to provide Postgres-sprcific + * helper methods. + */ +class PostgresPacketWriter { + public: + /* + * Instantiates a new PostgresPacketWriter backed by the given WriteQueue + */ + PostgresPacketWriter(WriteQueue &write_queue) : queue_(write_queue) {} + + ~PostgresPacketWriter() { + // Make sure no packet is being written on destruction, otherwise we are + // malformed write buffer + PELOTON_ASSERT(curr_packet_len_ == nullptr); + } + + /** + * Write out a packet with a single byte (e.g. SSL_YES or SSL_NO). This is a + * special case since no size field is provided. + * @param type Type of message to write out + */ + inline void WriteSingleBytePacket(NetworkMessageType type) { + // Make sure no active packet being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + queue_.BufferWriteRawValue(type); + } + + /** + * Begin writing a new packet. Caller can use other append methods to write + * contents to the packet. An explicit call to end packet must be made to + * make these writes valid. + * @param type + * @return self-reference for chaining + */ + PostgresPacketWriter &BeginPacket(NetworkMessageType type) { + // No active packet being constructed + PELOTON_ASSERT(curr_packet_len_ == nullptr); + queue_.BufferWriteRawValue(type); + // Remember the size field since we will need to modify it as we go along. 
+ // It is important that our size field is contiguous and not broken between + // two buffers. + queue_.BufferWriteRawValue(0, false); + WriteBuffer &tail = *(queue_.buffers_[queue_.buffers_.size() - 1]); + curr_packet_len_ = + reinterpret_cast(&tail.buf_[tail.size_ - sizeof(int32_t)]); + return *this; + } + + /** + * Append raw bytes from specified memory location into the write queue. + * There must be a packet active in the writer. + * @param src memory location to write from + * @param len number of bytes to write + * @return self-reference for chaining + */ + inline PostgresPacketWriter &AppendRaw(const void *src, size_t len) { + PELOTON_ASSERT(curr_packet_len_ != nullptr); + queue_.BufferWriteRaw(src, len); + // Add the size field to the len of the packet. Be mindful of byte + // ordering. We switch to network ordering only when the packet is finished + *curr_packet_len_ += len; + return *this; + } + + /** + * Append a value onto the write queue. There must be a packet active in the + * writer. No byte order conversion is performed. It is up to the caller to + * do so if needed. + * @tparam T type of value to write + * @param val value to write + * @return self-reference for chaining + */ + template + inline PostgresPacketWriter &AppendRawValue(T val) { + return AppendRaw(&val, sizeof(T)); + } + + /** + * Append an integer of specified length onto the write queue. (1, 2, 4, or 8 + * bytes). It is assumed that these bytes need to be converted to network + * byte ordering. + * @tparam T type of value to read off. Has to be size 1, 2, 4, or 8. + * @param val value to write + * @return self-reference for chaining + */ + template + inline PostgresPacketWriter &AppendInt(T val) { + // We only want to allow for certain type sizes to be used + // After the static assert, the compiler should be smart enough to throw + // away the other cases and only leave the relevant return statement. 
+ static_assert(sizeof(T) == 1 + || sizeof(T) == 2 + || sizeof(T) == 4 + || sizeof(T) == 8, "Invalid size for integer"); + switch (sizeof(T)) { + case 1: return AppendRawValue(val); + case 2: return AppendRawValue(ntohs(val)); + case 4: return AppendRawValue(ntohl(val)); + case 8: return AppendRawValue(ntohll(val)); + // Will never be here due to compiler optimization + default: throw NetworkProcessException(""); + } + } + + /** + * Append a string onto the write queue. + * @param str the string to append + * @param nul_terminate whether the nul terminaor should be written as well + * @return self-reference for chaining + */ + inline PostgresPacketWriter &AppendString(const std::string &str, + bool nul_terminate = true) { + return AppendRaw(str.data(), nul_terminate ? str.size() + 1 : str.size()); + } + + inline void WriteErrorResponse( + std::vector> error_status) { + BeginPacket(NetworkMessageType::ERROR_RESPONSE); + + for (const auto &entry : error_status) + AppendRawValue(entry.first) + .AppendString(entry.second); + + // Nul-terminate packet + AppendRawValue(0) + .EndPacket(); + } + + inline void WriteReadyForQuery(NetworkTransactionStateType txn_status) { + BeginPacket(NetworkMessageType::READY_FOR_QUERY) + .AppendRawValue(txn_status) + .EndPacket(); + } + + inline void WriteStartupResponse() { + BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST) + .EndPacket(); + + for (auto &entry : parameter_status_map) + BeginPacket(NetworkMessageType::PARAMETER_STATUS) + .AppendString(entry.first) + .AppendString(entry.second) + .EndPacket(); + WriteReadyForQuery(NetworkTransactionStateType::IDLE); + } + + /** + * End the packet. A packet write must be in progress and said write is not + * well-formed until this method is called. 
+ */ + inline void EndPacket() { + PELOTON_ASSERT(curr_packet_len_ != nullptr); + // Switch to network byte ordering, add the 4 bytes of size field + *curr_packet_len_ = htonl(*curr_packet_len_ + sizeof(int32_t)); + curr_packet_len_ = nullptr; + } + private: + // We need to keep track of the size field of the current packet, + // so we can update it as more bytes are written into this packet. + uint32_t *curr_packet_len_ = nullptr; + // Underlying WriteQueue backing this writer + WriteQueue &queue_; +}; + +} // namespace network +} // namespace peloton diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 76d71f608c7..d7f462624b4 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -12,7 +12,7 @@ #pragma once #include #include "network/network_types.h" -#include "network/netork_io_utils.h" +#include "network/network_io_utils.h" namespace peloton { namespace network { diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 7e4a2346f9d..da13be044a7 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,12 +11,52 @@ //===----------------------------------------------------------------------===// #pragma once +#include "network/connection_handle.h" +#include "parser/postgresparser.h" +#include "parser/sql_statement.h" namespace peloton { namespace tcop { -struct JobHandle { - bool is_queueing_; + +// TODO(Tianyu): Probably need a better name +// TODO(Tianyu): We can probably get rid of a bunch of fields from here +struct ClientProcessState { + bool is_queuing_; + std::string error_message_, db_name_ = DEFAULT_DB_NAME; + std::vector param_values_; + std::vector results_; + // This save currnet statement in the traffic cop + std::shared_ptr statement_; + // Default database name + int rows_affected_; + // The optimizer used for this connection + std::unique_ptr optimizer_; + // flag of single statement txn 
+ bool single_statement_txn_; + std::vector result_; + network::ConnectionHandle conn_handle_; }; +inline std::unique_ptr ParseQuery(ClientProcessState state, + const std::string &query) { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + // TODO(Tianyu): Parser result seems undocumented and I cannot tell + // at a glance what any of these mean + auto result = peloton_parser.BuildParseTree(query); + if (result != nullptr && !result->is_valid) + throw ParserException("Error parsing SQL statement"); + return std::move(result); +} + +std::shared_ptr PrepareStatement(ClientProcessState state, + const std::string &stmt_name, + const std::string &query_string, + std::unique_ptr sql_stmt_list) { + if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) + return std::make_shared(stmt_name, + QueryType::QUERY_INVALID, + query_string, std::move(sql_stmt_list)); + +} } // namespace tcop } // namespace peloton \ No newline at end of file diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index f94d690ddda..3a02391e01c 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -144,12 +144,12 @@ DEF_TRANSITION_GRAPH END_STATE_DEF DEFINE_STATE(CLOSING) - ON(WAKEUP) SET_STATE_TO(CLOSING) AND_INVOKE(TryCloseConnection) - ON(NEED_READ) SET_STATE_TO(WRITE) AND_WAIT_ON_READ - ON(NEED_WRITE) SET_STATE_TO(WRITE) AND_WAIT_ON_WRITE + ON(WAKEUP) SET_STATE_TO(CLOSING) AND_INVOKE(TryCloseConnection) + ON(NEED_READ) SET_STATE_TO(WRITE) AND_WAIT_ON_READ + ON(NEED_WRITE) SET_STATE_TO(WRITE) AND_WAIT_ON_WRITE END_STATE_DEF END_DEF - // clang-format on +// clang-format on void ConnectionHandle::StateMachine::Accept(Transition action, ConnectionHandle &connection) { @@ -194,21 +194,16 @@ Transition ConnectionHandle::Process() { ProtocolHandlerType::Postgres, &tcop_); ProcessResult status = protocol_handler_->Process( - *(io_wrapper_->rbuf_), (size_t)conn_handler_->Id()); + *(io_wrapper_->rbuf_), 
(size_t) conn_handler_->Id()); switch (status) { - case ProcessResult::MORE_DATA_REQUIRED: - return Transition::NEED_READ; - case ProcessResult::COMPLETE: - return Transition::PROCEED; - case ProcessResult::PROCESSING: - return Transition::NEED_RESULT; + case ProcessResult::MORE_DATA_REQUIRED:return Transition::NEED_READ; + case ProcessResult::COMPLETE:return Transition::PROCEED; + case ProcessResult::PROCESSING:return Transition::NEED_RESULT; case ProcessResult::TERMINATE: throw NetworkProcessException("Error when processing"); - case ProcessResult::NEED_SSL_HANDSHAKE: - return Transition::NEED_SSL_HANDSHAKE; - default: - LOG_ERROR("Unknown process result"); + case ProcessResult::NEED_SSL_HANDSHAKE:return Transition::NEED_SSL_HANDSHAKE; + default:LOG_ERROR("Unknown process result"); throw NetworkProcessException("Unknown process result"); } } diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 8741231c945..afe7fb3f7b0 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -13,6 +13,7 @@ #include "network/postgres_protocol_interpreter.h" #include "network/peloton_server.h" #include "network/postgres_network_commands.h" +#include "traffic_cop/tcop.h" #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) @@ -20,58 +21,60 @@ namespace peloton { namespace network { -//Transition StartupCommand::Exec(PostgresProtocolInterpreter &protocol_object, -// WriteQueue &out, -// size_t) { -// // Always flush startup response -// out.ForceFlush(); -// int32_t proto_version = input_packet_.buf_->ReadInt(); -// LOG_INFO("protocol version: %d", proto_version); -// if (proto_version == SSL_MESSAGE_VERNO) { -// // SSL Handshake initialization -// // TODO(Tianyu): This static method probably needs to be moved into -// // settings manager -// bool ssl_able = (PelotonServer::GetSSLLevel() != SSLLevel::SSL_DISABLE); -// out.WriteSingleBytePacket(ssl_able -// ? 
NetworkMessageType::SSL_YES -// : NetworkMessageType::SSL_NO); -// return ssl_able ? Transition::NEED_SSL_HANDSHAKE : Transition::PROCEED; -// } else { -// // Normal Initialization -// if (PROTO_MAJOR_VERSION(proto_version) != 3) { -// // Only protocol version 3 is supported -// LOG_ERROR("Protocol error: Only protocol version 3 is supported."); -// PostgresWireUtilities::SendErrorResponse( -// out, {{NetworkMessageType::HUMAN_READABLE_ERROR, -// "Protocol Version Not Support"}}); -// return Transition::TERMINATE; -// } -// -// std::string token, value; -// // TODO(Yuchen): check for more malformed cases -// // Read out startup package info -// while (input_packet_.buf_->HasMore()) { -// token = input_packet_.buf_->ReadString(); -// LOG_TRACE("Option key is %s", token.c_str()); -// // TODO(Tianyu): Why does this commented out line need to be here? -// // if (!input_packet_.buf_->HasMore()) break; -// value = input_packet_.buf_->ReadString(); -// LOG_TRACE("Option value is %s", value.c_str()); -// // TODO(Tianyu): We never seem to use this crap? -// protocol_object.AddCommandLineOption(token, value); -// // TODO(Tianyu): Do this after we are done refactoring traffic cop -//// if (token.compare("database") == 0) { -//// traffic_cop_->SetDefaultDatabaseName(value); -//// } -// } -// -// // Startup Response, for now we do not do any authentication -// PostgresWireUtilities::SendStartupResponse(out); -// protocol_object.FinishStartup(); -// return Transition::PROCEED; -// } -//} +Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + size_t) { + auto proto_version = in_->ReadInt(); + LOG_INFO("protocol version: %d", proto_version); + // SSL initialization + if (proto_version == SSL_MESSAGE_VERNO) { + // TODO(Tianyu): Should this be moved from PelotonServer into settings? 
+ if (PelotonServer::GetSSLLevel() == SSLLevel::SSL_DISABLE) { + out.WriteSingleBytePacket(NetworkMessageType::SSL_NO); + return Transition::PROCEED; + } + out.WriteSingleBytePacket(NetworkMessageType::SSL_YES); + return Transition::NEED_SSL_HANDSHAKE; + } + + // Process startup packet + if (PROTO_MAJOR_VERSION(proto_version) != 3) { + LOG_ERROR("Protocol error: only protocol version 3 is supported"); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + "Protocol Version Not Supported"}}); + return Transition::TERMINATE; + } + + while (in_->HasMore()) { + // TODO(Tianyu): We don't seem to really handle the other flags? + std::string key = in_->ReadString(), value = in_->ReadString(); + LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); + if (key == std::string("database")) + interpreter.ClientProcessState().db_name_ = value; + interpreter.AddCmdlineOption(std::move(key), std::move(value)); + } + + // TODO(Tianyu): Implement authentication. For now we always send AuthOK + out.WriteStartupResponse(); + interpreter.FinishStartup(); + return Transition::PROCEED; +} + +Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + size_t) { + // TODO(Tianyu): Figure out what skipped stmt does and maybe implement + std::string statement_name = in_->ReadString(), query = in_->ReadString(); + std::unique_ptr sql_stmt_list; + try { + sql_stmt_list = tcop::ParseQuery(interpreter.ClientProcessState(), query); + } catch (Exception &e) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); + return Transition::PROCEED; + } + auto statement = +} } // namespace network } // namespace peloton \ No newline at end of file From 1de49c99ff3b080dbe4e196bbeafe32cc5a7670a Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Sun, 24 Jun 2018 17:42:51 -0400 Subject: [PATCH 14/48] Hook I/O layer up with new code. Lay ground work to start porting old code over to the new structure. 
--- src/common/notifiable_task.cpp | 2 +- src/include/common/notifiable_task.h | 6 +- src/include/network/connection_handle.h | 46 +++---- src/include/network/network_io_utils.h | 24 ++-- src/include/network/network_io_wrappers.h | 32 +++-- .../network/postgres_network_commands.h | 6 +- .../network/postgres_protocol_interpreter.h | 99 ++----------- src/include/network/postgres_protocol_utils.h | 5 + src/include/network/protocol_handler.h | 1 - src/include/network/protocol_interpreter.h | 15 +- src/include/traffic_cop/tcop.h | 77 ++++++++--- src/network/connection_handle.cpp | 57 ++------ src/network/network_io_wrapper_factory.cpp | 6 +- src/network/network_io_wrappers.cpp | 130 +++++++----------- src/network/postgres_network_commands.cpp | 49 +++++-- src/network/postgres_protocol_interpreter.cpp | 101 ++++++++++++++ src/traffic_cop/tcop.cpp | 19 +++ 17 files changed, 355 insertions(+), 320 deletions(-) create mode 100644 src/network/postgres_protocol_interpreter.cpp create mode 100644 src/traffic_cop/tcop.cpp diff --git a/src/common/notifiable_task.cpp b/src/common/notifiable_task.cpp index b23d60a0e7d..c208ce69691 100644 --- a/src/common/notifiable_task.cpp +++ b/src/common/notifiable_task.cpp @@ -17,7 +17,7 @@ namespace peloton { -NotifiableTask::NotifiableTask(int task_id) : task_id_(task_id) { +NotifiableTask::NotifiableTask(size_t task_id) : task_id_(task_id) { base_ = EventUtil::EventBaseNew(); // For exiting a loop terminate_ = RegisterManualEvent([](int, short, void *arg) { diff --git a/src/include/common/notifiable_task.h b/src/include/common/notifiable_task.h index e1572ab63b9..ab2b8ae7633 100644 --- a/src/include/common/notifiable_task.h +++ b/src/include/common/notifiable_task.h @@ -49,7 +49,7 @@ class NotifiableTask { * Constructs a new NotifiableTask instance. * @param task_id a unique id assigned to this task */ - explicit NotifiableTask(int task_id); + explicit NotifiableTask(size_t task_id); /** * Destructs this NotifiableTask. 
All events currently registered to its base @@ -60,7 +60,7 @@ class NotifiableTask { /** * @return unique id assigned to this task */ - inline int Id() const { return task_id_; } + inline size_t Id() const { return task_id_; } /** * @brief Register an event with the event base associated with this @@ -183,7 +183,7 @@ class NotifiableTask { inline void ExitLoop(int, short) { ExitLoop(); } private: - const int task_id_; + const size_t task_id_; struct event_base *base_; // struct event and lifecycle management diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 08467b95ab3..e0fab366f6f 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -33,8 +33,9 @@ #include "marshal.h" #include "network/connection_handler_task.h" #include "network/network_io_wrappers.h" -#include "network_types.h" -#include "protocol_handler.h" +#include "network/network_types.h" +#include "network/protocol_interpreter.h" +#include "network/postgres_protocol_interpreter.h" #include #include @@ -71,15 +72,6 @@ class ConnectionHandle { workpool_event_ = conn_handler_->RegisterManualEvent( METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); - // TODO(Tianyi): should put the initialization else where.. check - // correctness first. 
- tcop_.SetTaskCallback( - [](void *arg) { - struct event *event = static_cast(arg); - event_active(event, EV_WRITE, 0); - }, - workpool_event_); - network_event_ = conn_handler_->RegisterEvent( io_wrapper_->GetSocketFd(), EV_READ | EV_PERSIST, METHOD_AS_CALLBACK(ConnectionHandle, HandleEvent), this); @@ -95,8 +87,19 @@ class ConnectionHandle { /* State Machine Actions */ // TODO(Tianyu): Write some documentation when feeling like it inline Transition TryRead() { return io_wrapper_->FillReadBuffer(); } - Transition TryWrite(); - Transition Process(); + + inline Transition TryWrite() { + if (io_wrapper_->ShouldFlush()) + return io_wrapper_->FlushAllWrites(); + } + + inline Transition Process() { + return protocol_interpreter_-> + Process(io_wrapper_->GetReadBuffer(), + io_wrapper_->GetWriteQueue(), + [=] { event_active(workpool_event_, EV_WRITE, 0); }); + } + Transition GetResult(); Transition TrySslHandshake(); Transition TryCloseConnection(); @@ -176,25 +179,12 @@ class ConnectionHandle { friend class StateMachine; friend class NetworkIoWrapperFactory; - /** - * @brief: Determine if there is still responses in the buffer - * @return true if there is still responses to flush out in either wbuf or - * responses - */ - inline bool HasResponse() { - return (protocol_handler_->responses_.size() != 0) || - (io_wrapper_->wbuf_->size_ != 0); - } - ConnectionHandlerTask *conn_handler_; std::shared_ptr io_wrapper_; StateMachine state_machine_; struct event *network_event_ = nullptr, *workpool_event_ = nullptr; - std::unique_ptr protocol_handler_ = nullptr; - // TODO(Tianyu): Remove tcop from here in later refactor - tcop::TrafficCop tcop_; - // TODO(Tianyu): Put this into protocol handler in a later refactor - unsigned int next_response_ = 0; + // TODO(Tianyu): Probably use a factory for this + std::unique_ptr protocol_interpreter_; }; } // namespace network } // namespace peloton diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h 
index f953650c02a..c3703c58f0b 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -13,7 +13,6 @@ #pragma once #include #include - #include #include #include "common/internal_types.h" @@ -182,19 +181,6 @@ class ReadBuffer : public Buffer { } } - /** - * Read a block of bytes off the read buffer as a string. - * @param len Length of the string, inclusive of nul-terminator - * @return string of specified length at head of read buffer - */ - std::string ReadString(size_t len) { - if (len == 0) throw NetworkProcessException("Unexpected string size: 0"); - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + offset_ + (len - 1)); - offset_ += len; - return result; - } - /** * Read a nul-terminated string off the read buffer, or throw an exception * if no nul-terminator is found within packet range. @@ -330,6 +316,7 @@ class WriteQueue { */ inline void Reset() { buffers_.resize(1); + offset_ = 0; flush_ = false; if (buffers_[0] == nullptr) buffers_[0] = std::make_shared(); @@ -337,6 +324,13 @@ class WriteQueue { buffers_[0]->Reset(); } + inline std::shared_ptr FlushHead() { + if (buffers_.size() > offset_) return buffers_[offset_]; + return nullptr; + } + + inline void MarkHeadFlushed() { offset_++; } + /** * Force this WriteQueue to be flushed next time the network layer * is available to do so. 
@@ -388,8 +382,8 @@ class WriteQueue { } private: - friend class PostgresPacketWriter; std::vector> buffers_; + size_t offset_ = 0; bool flush_ = false; }; diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index c091a1d37d5..661002c5979 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -37,25 +37,27 @@ namespace network { */ class NetworkIoWrapper { friend class NetworkIoWrapperFactory; - public: virtual bool SslAble() const = 0; // TODO(Tianyu): Change and document after we refactor protocol handler virtual Transition FillReadBuffer() = 0; - virtual Transition FlushWriteBuffer() = 0; + virtual Transition FlushWriteBuffer(WriteBuffer &wbuf) = 0; virtual Transition Close() = 0; inline int GetSocketFd() { return sock_fd_; } - Transition WritePacket(OutputPacket *pkt); + inline std::shared_ptr GetReadBuffer() { return in_; } + inline std::shared_ptr GetWriteQueue() { return out_; } + Transition FlushAllWrites(); + inline bool ShouldFlush() { return out_->ShouldFlush(); } // TODO(Tianyu): Make these protected when protocol handler refactor is // complete - NetworkIoWrapper(int sock_fd, std::shared_ptr &rbuf, - std::shared_ptr &wbuf) + NetworkIoWrapper(int sock_fd, std::shared_ptr &in, + std::shared_ptr &out) : sock_fd_(sock_fd), - rbuf_(std::move(rbuf)), - wbuf_(std::move(wbuf)) { - rbuf_->Reset(); - wbuf_->Reset(); + in_(std::move(in)), + out_(std::move(out)) { + in_->Reset(); + out_->Reset(); } DISALLOW_COPY(NetworkIoWrapper) @@ -63,8 +65,8 @@ class NetworkIoWrapper { NetworkIoWrapper(NetworkIoWrapper &&other) = default; int sock_fd_; - std::shared_ptr rbuf_; - std::shared_ptr wbuf_; + std::shared_ptr in_; + std::shared_ptr out_; }; /** @@ -72,13 +74,13 @@ class NetworkIoWrapper { */ class PosixSocketIoWrapper : public NetworkIoWrapper { public: - PosixSocketIoWrapper(int sock_fd, std::shared_ptr rbuf, - std::shared_ptr wbuf); + PosixSocketIoWrapper(int sock_fd, 
std::shared_ptr in, + std::shared_ptr out); inline bool SslAble() const override { return false; } Transition FillReadBuffer() override; - Transition FlushWriteBuffer() override; + Transition FlushWriteBuffer(WriteBuffer &wbuf) override; inline Transition Close() override { peloton_close(sock_fd_); return Transition::PROCEED; @@ -97,7 +99,7 @@ class SslSocketIoWrapper : public NetworkIoWrapper { inline bool SslAble() const override { return true; } Transition FillReadBuffer() override; - Transition FlushWriteBuffer() override; + Transition FlushWriteBuffer(WriteBuffer &wbuf) override; Transition Close() override; private: diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 70fcce7bf05..4e1f8f65a1d 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -26,7 +26,7 @@ class name : public PostgresNetworkCommand { \ : PostgresNetworkCommand(std::move(in), flush) {} \ virtual Transition Exec(PostgresProtocolInterpreter &, \ PostgresPacketWriter &, \ - size_t) override; \ + callback_func, size_t) override; \ } namespace peloton { @@ -36,11 +36,13 @@ class PostgresProtocolInterpreter; class PostgresNetworkCommand { public: - virtual Transition Exec(PostgresProtocolInterpreter &protocol_obj, + virtual Transition Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, + callback_func callback, size_t thread_id) = 0; inline bool FlushOnComplete() { return flush_on_complete_; } + protected: explicit PostgresNetworkCommand(std::shared_ptr in, bool flush) : in_(std::move(in)), flush_on_complete_(flush) {} diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 1ca8584f344..6d811f15055 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -2,7 +2,7 @@ // // Peloton // -// postgres_wire_protocol.h 
+// postgres_protocol_interpreter.h // // Identification: src/include/network/postgres_wire_protocol.h // @@ -14,7 +14,6 @@ #include "common/logger.h" #include "network/protocol_interpreter.h" #include "network/postgres_network_commands.h" -#include "network/network_io_utils.h" #include "traffic_cop/tcop.h" #define MAKE_COMMAND(type) \ @@ -26,18 +25,17 @@ namespace network { class PostgresProtocolInterpreter : public ProtocolInterpreter { public: - inline Transition Process(std::shared_ptr &in, - WriteQueue &out, - size_t thread_id) override { - if (!TryBuildPacket(in)) return Transition::NEED_READ; - std::shared_ptr command = PacketToCommand(); - curr_input_packet_.Clear(); - PostgresPacketWriter writer(out); - if (command->FlushOnComplete()) out.ForceFlush(); - return command->Exec(*this, writer, thread_id); - } + // TODO(Tianyu): Is this even the right thread id? It seems that all the + // concurrency code is dependent on this number. + PostgresProtocolInterpreter(size_t thread_id) = default; + + Transition Process(std::shared_ptr in, + std::shared_ptr out, + callback_func callback) override; - inline void AddCmdlineOption(std::string key, std::string value) { + inline void GetResult() override {} + + inline void AddCmdlineOption(const std::string &key, std::string value) { cmdline_options_[key] = std::move(value); } @@ -51,78 +49,9 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { std::unordered_map cmdline_options_; tcop::ClientProcessState state_; - bool TryBuildPacket(std::shared_ptr &in) { - if (!TryReadPacketHeader(in)) return false; - - size_t size_needed = curr_input_packet_.extended_ - ? 
curr_input_packet_.len_ - - curr_input_packet_.buf_->BytesAvailable() - : curr_input_packet_.len_; - if (!in->HasMore(size_needed)) return false; - - // copy bytes only if the packet is longer than the read buffer, - // otherwise we can use the read buffer to save space - if (curr_input_packet_.extended_) - curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); - return true; - } - - bool TryReadPacketHeader(std::shared_ptr &in) { - if (curr_input_packet_.header_parsed_) return true; - - // Header format: 1 byte message type (only if non-startup) - // + 4 byte message size (inclusive of these 4 bytes) - size_t header_size = startup_ ? sizeof(int32_t) : 1 + sizeof(int32_t); - // Make sure the entire header is readable - if (!in->HasMore(header_size)) return false; - - // The header is ready to be read, fill in fields accordingly - if (!startup_) - curr_input_packet_.msg_type_ = in->ReadRawValue(); - curr_input_packet_.len_ = in->ReadInt() - sizeof(uint32_t); - - // Extend the buffer as needed - if (curr_input_packet_.len_ > in->Capacity()) { - LOG_INFO("Extended Buffer size required for packet of size %ld", - curr_input_packet_.len_); - // Allocate a larger buffer and copy bytes off from the I/O layer's buffer - curr_input_packet_.buf_ = - std::make_shared(curr_input_packet_.len_); - curr_input_packet_.extended_ = true; - } else { - curr_input_packet_.buf_ = in; - } - - curr_input_packet_.header_parsed_ = true; - return true; - } - - std::shared_ptr PacketToCommand() { - if (startup_) return MAKE_COMMAND(StartupCommand); - switch (curr_input_packet_.msg_type_) { - case NetworkMessageType::SIMPLE_QUERY_COMMAND: - return MAKE_COMMAND(SimpleQueryCommand); - case NetworkMessageType::PARSE_COMMAND: - return MAKE_COMMAND(ParseCommand); - case NetworkMessageType::BIND_COMMAND: - return MAKE_COMMAND(BindCommand); - case NetworkMessageType::DESCRIBE_COMMAND: - return MAKE_COMMAND(DescribeCommand); - case NetworkMessageType::EXECUTE_COMMAND: - return 
MAKE_COMMAND(ExecuteCommand); - case NetworkMessageType::SYNC_COMMAND: - return MAKE_COMMAND(SyncCommand); - case NetworkMessageType::CLOSE_COMMAND: - return MAKE_COMMAND(CloseCommand); - case NetworkMessageType::TERMINATE_COMMAND: - return MAKE_COMMAND(TerminateCommand); - case NetworkMessageType::NULL_COMMAND: - return MAKE_COMMAND(NullCommand); - default: - throw NetworkProcessException("Unexpected Packet Type: " + - std::to_string(static_cast(curr_input_packet_.msg_type_))); - } - } + bool TryBuildPacket(std::shared_ptr &in); + bool TryReadPacketHeader(std::shared_ptr &in); + std::shared_ptr PacketToCommand(); }; } // namespace network diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index 84ad424fc89..ab26741420c 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -189,6 +189,11 @@ class PostgresPacketWriter { WriteReadyForQuery(NetworkTransactionStateType::IDLE); } + inline void WriteEmptyQueryResponse() { + BeginPacket(NetworkMessageType::EMPTY_QUERY_RESPONSE) + .EndPacket(); + } + /** * End the packet. A packet write must be in progress and said write is not * well-formed until this method is called. 
diff --git a/src/include/network/protocol_handler.h b/src/include/network/protocol_handler.h index 0a7ccef3898..785626e4b16 100644 --- a/src/include/network/protocol_handler.h +++ b/src/include/network/protocol_handler.h @@ -18,7 +18,6 @@ // Packet content macros namespace peloton { - namespace network { typedef std::vector> ResponseBuffer; diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index d7f462624b4..98cdeeaa046 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -11,19 +11,26 @@ //===----------------------------------------------------------------------===// #pragma once #include +#include #include "network/network_types.h" #include "network/network_io_utils.h" namespace peloton { namespace network { +using callback_func = std::function; class ProtocolInterpreter { public: - // TODO(Tianyu): What the hell is this thread_id thingy - virtual Transition Process(std::shared_ptr &in, - WriteQueue &out, - size_t thread_id) = 0; + ProtocolInterpreter(size_t thread_id) : thread_id_(thread_id) {} + virtual Transition Process(std::shared_ptr in, + std::shared_ptr out, + callback_func callback) = 0; + + // TODO(Tianyu): Do we really need this crap? 
+ virtual void GetResult() = 0; + protected: + size_t thread_id_; }; } // namespace network diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index da13be044a7..ef21d0c978f 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,15 +11,21 @@ //===----------------------------------------------------------------------===// #pragma once +#include #include "network/connection_handle.h" #include "parser/postgresparser.h" #include "parser/sql_statement.h" namespace peloton { namespace tcop { +// pair of txn ptr and the result so-far for that txn +// use a stack to support nested-txns +using TcopTxnState = std::pair; + // TODO(Tianyu): Probably need a better name // TODO(Tianyu): We can probably get rid of a bunch of fields from here struct ClientProcessState { + size_t thread_id_; bool is_queuing_; std::string error_message_, db_name_ = DEFAULT_DB_NAME; std::vector param_values_; @@ -33,30 +39,57 @@ struct ClientProcessState { // flag of single statement txn bool single_statement_txn_; std::vector result_; - network::ConnectionHandle conn_handle_; + std::stack tcop_txn_state_; + executor::ExecutionResult p_status_; }; -inline std::unique_ptr ParseQuery(ClientProcessState state, - const std::string &query) { - auto &peloton_parser = parser::PostgresParser::GetInstance(); - // TODO(Tianyu): Parser result seems undocumented and I cannot tell - // at a glance what any of these mean - auto result = peloton_parser.BuildParseTree(query); - if (result != nullptr && !result->is_valid) - throw ParserException("Error parsing SQL statement"); - return std::move(result); -} - -std::shared_ptr PrepareStatement(ClientProcessState state, - const std::string &stmt_name, - const std::string &query_string, - std::unique_ptr sql_stmt_list) { - if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) - return std::make_shared(stmt_name, - QueryType::QUERY_INVALID, - query_string, std::move(sql_stmt_list)); - -} +// 
Execute a statement +ResultType ExecuteStatement( + const std::shared_ptr &statement, + const std::vector ¶ms, const bool unnamed, + std::shared_ptr param_stats, + const std::vector &result_format, std::vector &result, + size_t thread_id = 0); + +// Helper to handle txn-specifics for the plan-tree of a statement. +executor::ExecutionResult ExecuteHelper( + std::shared_ptr plan, + const std::vector ¶ms, std::vector &result, + const std::vector &result_format, size_t thread_id = 0); + +// Prepare a statement using the parse tree +std::shared_ptr PrepareStatement( + const std::string &statement_name, const std::string &query_string, + std::unique_ptr sql_stmt_list, + size_t thread_id = 0); + +bool BindParamsForCachePlan( + const std::vector> &, + const size_t thread_id = 0); + +std::vector GenerateTupleDescriptor( + parser::SQLStatement *select_stmt); + +FieldInfo GetColumnFieldForValueType(std::string column_name, + type::TypeId column_type); + +ResultType CommitQueryHelper(); + +void ExecuteStatementPlanGetResult(); + +ResultType ExecuteStatementGetResult(); + +void ProcessInvalidStatement(ClientProcessState &state); + +ResultType BeginQueryHelper(size_t thread_id); + +ResultType AbortQueryHelper(); + +// Get all data tables from a TableRef. 
+// For multi-way join +// still a HACK +void GetTableColumns(parser::TableRef *from_table, + std::vector &target_tables); } // namespace tcop } // namespace peloton \ No newline at end of file diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index 3a02391e01c..8087d5b085e 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -129,8 +129,8 @@ DEF_TRANSITION_GRAPH ON(WAKEUP) SET_STATE_TO(PROCESS) AND_INVOKE(GetResult) ON(PROCEED) SET_STATE_TO(WRITE) AND_INVOKE(TryWrite) ON(NEED_READ) SET_STATE_TO(READ) AND_INVOKE(TryRead) - // Client connections are ignored while we wait on peloton - // to execute the query + // Client connections are ignored while we wait on peloton + // to execute the query ON(NEED_RESULT) SET_STATE_TO(PROCESS) AND_WAIT_ON_PELOTON ON(NEED_SSL_HANDSHAKE) SET_STATE_TO(SSL_INIT) AND_INVOKE(TrySslHandshake) END_STATE_DEF @@ -166,61 +166,22 @@ void ConnectionHandle::StateMachine::Accept(Transition action, } } +// TODO(Tianyu): Maybe use a factory to initialize protocol_interpreter here ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) : conn_handler_(handler), - io_wrapper_(NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd)) {} - -Transition ConnectionHandle::TryWrite() { - for (; next_response_ < protocol_handler_->responses_.size(); - next_response_++) { - auto result = io_wrapper_->WritePacket( - protocol_handler_->responses_[next_response_].get()); - if (result != Transition::PROCEED) return result; - } - protocol_handler_->responses_.clear(); - next_response_ = 0; - if (protocol_handler_->GetFlushFlag()) return io_wrapper_->FlushWriteBuffer(); - protocol_handler_->SetFlushFlag(false); - return Transition::PROCEED; -} - -Transition ConnectionHandle::Process() { - // TODO(Tianyu): Just use Transition instead of ProcessResult, this looks - // like a 1 - 1 mapping between the two types. 
- if (protocol_handler_ == nullptr) - // TODO(Tianyi) Check the rbuf here before we create one if we have - // another protocol handler - protocol_handler_ = ProtocolHandlerFactory::CreateProtocolHandler( - ProtocolHandlerType::Postgres, &tcop_); - - ProcessResult status = protocol_handler_->Process( - *(io_wrapper_->rbuf_), (size_t) conn_handler_->Id()); - - switch (status) { - case ProcessResult::MORE_DATA_REQUIRED:return Transition::NEED_READ; - case ProcessResult::COMPLETE:return Transition::PROCEED; - case ProcessResult::PROCESSING:return Transition::NEED_RESULT; - case ProcessResult::TERMINATE: - throw NetworkProcessException("Error when processing"); - case ProcessResult::NEED_SSL_HANDSHAKE:return Transition::NEED_SSL_HANDSHAKE; - default:LOG_ERROR("Unknown process result"); - throw NetworkProcessException("Unknown process result"); - } -} + io_wrapper_(NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd)), + protocol_interpreter_{new PostgresProtocolInterpreter(conn_handler_->Id())} {} Transition ConnectionHandle::GetResult() { EventUtil::EventAdd(network_event_, nullptr); - protocol_handler_->GetResult(); - tcop_.SetQueuing(false); + protocol_interpreter_->GetResult(); return Transition::PROCEED; } Transition ConnectionHandle::TrySslHandshake() { - // Flush out all the response first - if (HasResponse()) { - auto write_ret = TryWrite(); - if (write_ret != Transition::PROCEED) return write_ret; - } + // TODO(Tianyu): Do we really need to flush here? 
+ auto ret = io_wrapper_->FlushAllWrites(); + if (ret != Transition::PROCEED) return ret; return NetworkIoWrapperFactory::GetInstance().TryUseSsl( io_wrapper_); } diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp index 88efad09216..59ab149ba64 100644 --- a/src/network/network_io_wrapper_factory.cpp +++ b/src/network/network_io_wrapper_factory.cpp @@ -22,7 +22,7 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( // No reusable wrappers auto wrapper = std::make_shared( conn_fd, std::make_shared(), - std::make_shared()); + std::make_shared()); reusable_wrappers_[conn_fd] = std::static_pointer_cast( wrapper); @@ -35,8 +35,8 @@ std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( // constructor so the flags are set properly on the new file descriptor. auto &reused_wrapper = it->second; reused_wrapper = std::make_shared(conn_fd, - reused_wrapper->rbuf_, - reused_wrapper->wbuf_); + reused_wrapper->in_, + reused_wrapper->out_); return reused_wrapper; } diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index b914fd06051..b4e881dfe3f 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -19,46 +19,22 @@ namespace peloton { namespace network { -Transition NetworkIoWrapper::WritePacket(OutputPacket *pkt) { - // Write Packet Header - if (!pkt->skip_header_write) { - if (!wbuf_->HasSpaceFor(1 + sizeof(int32_t))) { - auto result = FlushWriteBuffer(); - if (FlushWriteBuffer() != Transition::PROCEED) - // Unable to flush buffer, socket presumably not ready for write - return result; - } - - wbuf_->AppendRaw(static_cast(pkt->msg_type)); - if (!pkt->single_type_pkt) - // Need to convert bytes to network order - wbuf_->AppendRaw(htonl(pkt->len + sizeof(int32_t))); - pkt->skip_header_write = true; - } - - // Write Packet Content - for (size_t len = pkt->len; len != 0;) { - if (wbuf_->HasSpaceFor(len)) { - 
wbuf_->AppendRaw(std::begin(pkt->buf) + pkt->write_ptr, len); - break; - } else { - auto write_size = wbuf_->RemainingCapacity(); - wbuf_->AppendRaw(std::begin(pkt->buf) + pkt->write_ptr, write_size); - len -= write_size; - pkt->write_ptr += write_size; - auto result = FlushWriteBuffer(); - if (FlushWriteBuffer() != Transition::PROCEED) - // Unable to flush buffer, socket presumably not ready for write - return result; - } +Transition NetworkIoWrapper::FlushAllWrites() { + for (auto buffer = out_->FlushHead(); + buffer != nullptr; + buffer = out_->FlushHead()) { + auto result = FlushWriteBuffer(*buffer); + if (result != Transition::PROCEED) return result; + out_->MarkHeadFlushed(); } + out_->Reset(); return Transition::PROCEED; } PosixSocketIoWrapper::PosixSocketIoWrapper(int sock_fd, - std::shared_ptr rbuf, - std::shared_ptr wbuf) - : NetworkIoWrapper(sock_fd, rbuf, wbuf) { + std::shared_ptr in, + std::shared_ptr out) + : NetworkIoWrapper(sock_fd, in, out) { // Set Non Blocking auto flags = fcntl(sock_fd_, F_GETFL); flags |= O_NONBLOCK; @@ -71,12 +47,12 @@ PosixSocketIoWrapper::PosixSocketIoWrapper(int sock_fd, } Transition PosixSocketIoWrapper::FillReadBuffer() { - if (!rbuf_->HasMore()) rbuf_->Reset(); - if (rbuf_->HasMore() && rbuf_->Full()) rbuf_->MoveContentToHead(); + if (!in_->HasMore()) in_->Reset(); + if (in_->HasMore() && in_->Full()) in_->MoveContentToHead(); Transition result = Transition::NEED_READ; // Normal mode - while (!rbuf_->Full()) { - auto bytes_read = rbuf_->FillBufferFrom(sock_fd_); + while (!in_->Full()) { + auto bytes_read = in_->FillBufferFrom(sock_fd_); if (bytes_read > 0) result = Transition::PROCEED; else if (bytes_read == 0) @@ -86,52 +62,44 @@ Transition PosixSocketIoWrapper::FillReadBuffer() { case EAGAIN: // Equal to EWOULDBLOCK return result; - case EINTR: - continue; - default: - LOG_ERROR("Error writing: %s", strerror(errno)); + case EINTR:continue; + default:LOG_ERROR("Error writing: %s", strerror(errno)); throw 
NetworkProcessException("Error when filling read buffer " + - std::to_string(errno)); + std::to_string(errno)); } } return result; } -Transition PosixSocketIoWrapper::FlushWriteBuffer() { - while (wbuf_->HasMore()) { - auto bytes_written = wbuf_->WriteOutTo(sock_fd_); - if (bytes_written < 0) switch (errno) { - case EINTR: - continue; - case EAGAIN: - return Transition::NEED_WRITE; - default: - LOG_ERROR("Error writing: %s", strerror(errno)); +Transition PosixSocketIoWrapper::FlushWriteBuffer(WriteBuffer &wbuf) { + while (wbuf.HasMore()) { + auto bytes_written = wbuf.WriteOutTo(sock_fd_); + if (bytes_written < 0) + switch (errno) { + case EINTR:continue; + case EAGAIN:return Transition::NEED_WRITE; + default:LOG_ERROR("Error writing: %s", strerror(errno)); throw NetworkProcessException("Fatal error during write"); } } - wbuf_->Reset(); + wbuf.Reset(); return Transition::PROCEED; } Transition SslSocketIoWrapper::FillReadBuffer() { - if (!rbuf_->HasMore()) rbuf_->Reset(); - if (rbuf_->HasMore() && rbuf_->Full()) rbuf_->MoveContentToHead(); + if (!in_->HasMore()) in_->Reset(); + if (in_->HasMore() && in_->Full()) in_->MoveContentToHead(); Transition result = Transition::NEED_READ; - while (!rbuf_->Full()) { - auto ret = rbuf_->FillBufferFrom(conn_ssl_context_); + while (!in_->Full()) { + auto ret = in_->FillBufferFrom(conn_ssl_context_); switch (ret) { - case SSL_ERROR_NONE: - result = Transition::PROCEED; + case SSL_ERROR_NONE:result = Transition::PROCEED; break; - case SSL_ERROR_ZERO_RETURN: - return Transition::TERMINATE; + case SSL_ERROR_ZERO_RETURN:return Transition::TERMINATE; // The SSL packet is partially loaded to the SSL buffer only, // More data is required in order to decode the wh`ole packet. 
- case SSL_ERROR_WANT_READ: - return result; - case SSL_ERROR_WANT_WRITE: - return Transition::NEED_WRITE; + case SSL_ERROR_WANT_READ:return result; + case SSL_ERROR_WANT_WRITE:return Transition::NEED_WRITE; case SSL_ERROR_SYSCALL: if (errno == EINTR) { LOG_INFO("Error SSL Reading: EINTR"); @@ -145,16 +113,13 @@ Transition SslSocketIoWrapper::FillReadBuffer() { return result; } -Transition SslSocketIoWrapper::FlushWriteBuffer() { - while (wbuf_->HasMore()) { - auto ret = wbuf_->WriteOutTo(conn_ssl_context_); +Transition SslSocketIoWrapper::FlushWriteBuffer(WriteBuffer &wbuf) { + while (wbuf.HasMore()) { + auto ret = wbuf.WriteOutTo(conn_ssl_context_); switch (ret) { - case SSL_ERROR_NONE: - break; - case SSL_ERROR_WANT_WRITE: - return Transition::NEED_WRITE; - case SSL_ERROR_WANT_READ: - return Transition::NEED_READ; + case SSL_ERROR_NONE:break; + case SSL_ERROR_WANT_WRITE:return Transition::NEED_WRITE; + case SSL_ERROR_WANT_READ:return Transition::NEED_READ; case SSL_ERROR_SYSCALL: // If interrupted, try again. 
if (errno == EINTR) { @@ -162,12 +127,13 @@ Transition SslSocketIoWrapper::FlushWriteBuffer() { break; } // Intentional Fallthrough - default: - LOG_ERROR("SSL write error: %d, error code: %lu", ret, ERR_get_error()); + default:LOG_ERROR("SSL write error: %d, error code: %lu", + ret, + ERR_get_error()); throw NetworkProcessException("SSL write error"); } } - wbuf_->Reset(); + wbuf.Reset(); return Transition::PROCEED; } @@ -177,13 +143,11 @@ Transition SslSocketIoWrapper::Close() { if (ret != 0) { int err = SSL_get_error(conn_ssl_context_, ret); switch (err) { - case SSL_ERROR_WANT_WRITE: - return Transition::NEED_WRITE; + case SSL_ERROR_WANT_WRITE:return Transition::NEED_WRITE; case SSL_ERROR_WANT_READ: // More work to do before shutdown return Transition::NEED_READ; - default: - LOG_ERROR("Error shutting down ssl session, err: %d", err); + default:LOG_ERROR("Error shutting down ssl session, err: %d", err); } } // SSL context is explicitly deallocated here because socket wrapper diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index afe7fb3f7b0..529c2297259 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -20,9 +20,9 @@ namespace peloton { namespace network { - Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, + callback_func, size_t) { auto proto_version = in_->ReadInt(); LOG_INFO("protocol version: %d", proto_version); @@ -51,7 +51,7 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); if (key == std::string("database")) interpreter.ClientProcessState().db_name_ = value; - interpreter.AddCmdlineOption(std::move(key), std::move(value)); + interpreter.AddCmdlineOption(key, std::move(value)); } // TODO(Tianyu): Implement authentication. 
For now we always send AuthOK @@ -60,20 +60,49 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } -Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, - PostgresPacketWriter &out, - size_t) { - // TODO(Tianyu): Figure out what skipped stmt does and maybe implement - std::string statement_name = in_->ReadString(), query = in_->ReadString(); +Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + callback_func callback, + size_t tid) { + std::string query = in_->ReadString(); + LOG_TRACE("Execute query: %s", query.c_str()); std::unique_ptr sql_stmt_list; try { - sql_stmt_list = tcop::ParseQuery(interpreter.ClientProcessState(), query); - } catch (Exception &e) { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + sql_stmt_list = peloton_parser.BuildParseTree(query); + + // When the query is empty(such as ";" or ";;", still valid), + // the pare tree is empty, parser will return nullptr. + if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { + throw ParserException("Error Parsing SQL statement"); + } + } // If the statement is invalid or not supported yet + catch (Exception &e) { + tcop::ProcessInvalidStatement(interpreter.ClientProcessState()); out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + if (sql_stmt_list.get() == nullptr || + sql_stmt_list->GetNumStatements() == 0) { + out.WriteEmptyQueryResponse(); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - auto statement = + // TODO(Yuchen): Hack. We only process the first statement in the packet now. + // We should store the rest of statements that will not be processed right + // away. For the hack, in most cases, it works. Because for example in psql, + // one packet contains only one query. 
But when using the pipeline mode in + // Libpqxx, it sends multiple query in one packet. In this case, it's + // incorrect. + auto sql_stmt = sql_stmt_list->PassOutStatement(0); + + QueryType query_type = + StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); + interpreter.ClientProcessState().protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; + } } // namespace network diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp new file mode 100644 index 00000000000..f036334bb46 --- /dev/null +++ b/src/network/postgres_protocol_interpreter.cpp @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// postgres_wire_protocol.h +// +// Identification: src/include/network/postgres_wire_protocol.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once +#include "network/postgres_protocol_interpreter.h" + +#define MAKE_COMMAND(type) \ + std::static_pointer_cast( \ + std::make_shared(std::move(curr_input_packet_.buf_))) + +namespace peloton { +namespace network { +Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, + std::shared_ptr out, + callback_func callback) { + if (!TryBuildPacket(in)) return Transition::NEED_READ; + std::shared_ptr command = PacketToCommand(); + curr_input_packet_.Clear(); + PostgresPacketWriter writer(*out); + if (command->FlushOnComplete()) out->ForceFlush(); + return command->Exec(*this, writer, callback, thread_id_); +} + +bool PostgresProtocolInterpreter::TryBuildPacket(std::shared_ptr &in) { + if (!TryReadPacketHeader(in)) return false; + + size_t size_needed = curr_input_packet_.extended_ + ? 
curr_input_packet_.len_ + - curr_input_packet_.buf_->BytesAvailable() + : curr_input_packet_.len_; + if (!in->HasMore(size_needed)) return false; + + // copy bytes only if the packet is longer than the read buffer, + // otherwise we can use the read buffer to save space + if (curr_input_packet_.extended_) + curr_input_packet_.buf_->FillBufferFrom(*in, size_needed); + return true; +} + +bool PostgresProtocolInterpreter::TryReadPacketHeader(std::shared_ptr &in) { + if (curr_input_packet_.header_parsed_) return true; + + // Header format: 1 byte message type (only if non-startup) + // + 4 byte message size (inclusive of these 4 bytes) + size_t header_size = startup_ ? sizeof(int32_t) : 1 + sizeof(int32_t); + // Make sure the entire header is readable + if (!in->HasMore(header_size)) return false; + + // The header is ready to be read, fill in fields accordingly + if (!startup_) + curr_input_packet_.msg_type_ = in->ReadRawValue(); + curr_input_packet_.len_ = in->ReadInt() - sizeof(uint32_t); + + // Extend the buffer as needed + if (curr_input_packet_.len_ > in->Capacity()) { + LOG_INFO("Extended Buffer size required for packet of size %ld", + curr_input_packet_.len_); + // Allocate a larger buffer and copy bytes off from the I/O layer's buffer + curr_input_packet_.buf_ = + std::make_shared(curr_input_packet_.len_); + curr_input_packet_.extended_ = true; + } else { + curr_input_packet_.buf_ = in; + } + + curr_input_packet_.header_parsed_ = true; + return true; +} + +std::shared_ptr PostgresProtocolInterpreter::PacketToCommand() { + if (startup_) return MAKE_COMMAND(StartupCommand); + switch (curr_input_packet_.msg_type_) { + case NetworkMessageType::SIMPLE_QUERY_COMMAND: + return MAKE_COMMAND(SimpleQueryCommand); + case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); + case NetworkMessageType::BIND_COMMAND:return MAKE_COMMAND(BindCommand); + case NetworkMessageType::DESCRIBE_COMMAND: + return MAKE_COMMAND(DescribeCommand); + case 
NetworkMessageType::EXECUTE_COMMAND:return MAKE_COMMAND(ExecuteCommand); + case NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); + case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); + case NetworkMessageType::TERMINATE_COMMAND: + return MAKE_COMMAND(TerminateCommand); + case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); + default: + throw NetworkProcessException("Unexpected Packet Type: " + + std::to_string(static_cast(curr_input_packet_.msg_type_))); + } +} + +} // namespace network +} // namespace peloton \ No newline at end of file diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp new file mode 100644 index 00000000000..403b0fafd62 --- /dev/null +++ b/src/traffic_cop/tcop.cpp @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// t=cop.h +// +// Identification: src/include/traffic_cop/tcop.h +// +// Copyright (c) 2015-18, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "traffic_cop/tcop.h" + +namespace peloton { +namespace tcop { + +} // namespace tcop +} // namespace peloton \ No newline at end of file From 5e9baf4f5cdf8a91d332714fb0a5094be8feada8 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Mon, 25 Jun 2018 23:12:10 -0400 Subject: [PATCH 15/48] Partially write some marshalling code --- src/include/network/network_io_utils.h | 15 +- .../network/postgres_network_commands.h | 129 +++++++++++ .../network/postgres_protocol_interpreter.h | 12 + src/include/traffic_cop/tcop.h | 10 +- src/network/postgres_network_commands.cpp | 213 +++++++++++++++++- src/network/postgres_protocol_handler.cpp | 2 +- src/network/postgres_protocol_interpreter.cpp | 20 +- 7 files changed, 381 insertions(+), 20 deletions(-) diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index c3703c58f0b..b1afef54661 100644 
--- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -140,6 +140,7 @@ class ReadBuffer : public Buffer { */ inline size_t BytesAvailable() { return size_ - offset_; } + /** * Read the given number of bytes into destination, advancing cursor by that * number. It is up to the caller to ensure that there are enough bytes @@ -163,7 +164,7 @@ class ReadBuffer : public Buffer { * @return value of integer switched from network byte order */ template - inline T ReadInt() { + inline T ReadValue() { // We only want to allow for certain type sizes to be used // After the static assert, the compiler should be smart enough to throw // away the other cases and only leave the relevant return statement. @@ -201,6 +202,16 @@ class ReadBuffer : public Buffer { throw NetworkProcessException("Expected nil in read buffer, none found"); } + /** + * Read a not nul-terminated string off the read buffer of specified length + * @return string at head of read buffer + */ + inline std::string ReadString(size_t len) { + std::string result(buf_.begin() + offset_, buf_.begin() + offset_ + len); + offset_ += len; + return result; + } + /** * Read a value of type T off of the buffer, advancing cursor by appropriate * amount. Does NOT convert from network bytes order. 
It is the caller's @@ -382,8 +393,8 @@ class WriteQueue { } private: + friend class PostgresPacketWriter; std::vector> buffers_; - size_t offset_ = 0; bool flush_ = false; }; diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 4e1f8f65a1d..56dbc8fed58 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #pragma once #include +#include "type/value_factory.h" #include "common/internal_types.h" #include "common/logger.h" #include "common/macros.h" @@ -47,6 +48,134 @@ class PostgresNetworkCommand { explicit PostgresNetworkCommand(std::shared_ptr in, bool flush) : in_(std::move(in)), flush_on_complete_(flush) {} + inline std::vector ReadParamTypes() { + std::vector result; + auto num_params = in_->ReadValue(); + for (uint16_t i = 0; i < num_params; i++) + result.push_back(in_->ReadValue()); + return result; + } + + inline std::vector ReadParamFormats() { + std::vector result; + auto num_formats = in_->ReadValue(); + for (uint16_t i = 0; i < num_formats; i++) + result.push_back(in_->ReadValue()); + return result; + } + + // Why are bind parameter and param values different? + void ReadParamValues(std::vector> &bind_parameters, + std::vector ¶m_values, + const std::vector ¶m_types, + const std::vector &formats) { + auto num_params = in_->ReadValue(); + for (uint16_t i = 0; i < num_params; i++) { + auto param_len = in_->ReadValue(); + if (param_len == -1) { + // NULL + auto peloton_type = PostgresValueTypeToPelotonValueType(param_types[i]); + bind_parameters.push_back(std::make_pair(peloton_type, + std::string(""))); + param_values.push_back(type::ValueFactory::GetNullValueByType( + peloton_type)); + } else { + (formats[i] == 0) + ? 
ProcessTextParamValue(bind_parameters, + param_values, + param_types[i], + param_len) + : ProcessBinaryParamValue(bind_parameters, + param_values, + param_types[i], + param_len); + } + } + } + + void ProcessTextParamValue(std::vector> &bind_parameters, + std::vector ¶m_values, + PostgresValueType type, + int32_t len) { + std::string val = in_->ReadString((size_t) len); + bind_parameters.push_back(std::make_pair(type::TypeId::VARCHAR, val)); + param_values.push_back( + PostgresValueTypeToPelotonValueType(type) == type::TypeId::VARCHAR + ? type::ValueFactory::GetVarcharValue(val) + : type::ValueFactory::GetVarcharValue(val).CastAs( + PostgresValueTypeToPelotonValueType(type))); + } + + void ProcessBinaryParamValue(std::vector> &bind_parameters, + std::vector ¶m_values, + PostgresValueType type, + int32_t len) { + switch (type) { + case PostgresValueType::TINYINT: { + PELOTON_ASSERT(len == sizeof(int8_t)); + auto val = in_->ReadValue(); + bind_parameters.push_back( + std::make_pair(type::TypeId::TINYINT, std::to_string(val))); + param_values.push_back( + type::ValueFactory::GetTinyIntValue(val).Copy()); + break; + } + case PostgresValueType::SMALLINT: { + PELOTON_ASSERT(len == sizeof(int16_t)); + auto int_val = in_->ReadValue(); + bind_parameters.push_back( + std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val))); + param_values.push_back( + type::ValueFactory::GetSmallIntValue(int_val).Copy()); + break; + } + case PostgresValueType::INTEGER: { + PELOTON_ASSERT(len == sizeof(int32_t)); + auto val = in_->ReadValue(); + bind_parameters.push_back( + std::make_pair(type::TypeId::INTEGER, std::to_string(val))); + param_values.push_back( + type::ValueFactory::GetIntegerValue(val).Copy()); + break; + } + case PostgresValueType::BIGINT: { + PELOTON_ASSERT(len == sizeof(int64_t)); + auto val = in_->ReadValue(); + bind_parameters.push_back( + std::make_pair(type::TypeId::BIGINT, std::to_string(val))); + param_values.push_back( + 
type::ValueFactory::GetBigIntValue(val).Copy()); + break; + } + case PostgresValueType::DOUBLE: { + PELOTON_ASSERT(len == sizeof(double)); + auto val = in_->ReadValue(); + bind_parameters.push_back( + std::make_pair(type::TypeId::DECIMAL, std::to_string(val))); + param_values.push_back( + type::ValueFactory::GetDecimalValue(val).Copy()); + break; + } + case PostgresValueType::VARBINARY: { + auto val = in_->ReadString((size_t) len); + bind_parameters.push_back( + std::make_pair(type::TypeId::VARBINARY, val)); + param_values.push_back( + type::ValueFactory::GetVarbinaryValue( + reinterpret_cast(val.c_str()), + len, + true)); + break; + } + default: + throw NetworkProcessException("Binary Postgres protocol does not support data type " + + PostgresValueTypeToString(type)); + } + } + std::shared_ptr in_; private: bool flush_on_complete_; diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 6d811f15055..c335590983c 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -43,6 +43,18 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { inline tcop::ClientProcessState &ClientProcessState() { return state_; } + // TODO(Tianyu): What the hell does this thing do? + void CompleteCommand(const QueryType &query_type, int rows, PostgresPacketWriter &out); + + // TODO(Tianyu): Remove these later. Legacy shit code. 
+ void ExecQueryMessageGetResult(ResultType status); + ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); + bool HardcodedExecuteFilter(QueryType query_type); + NetworkProtocolType protocol_type_; + std::vector result_format_; + bool skipped_stmt_ = false; + std::string skipped_query_string_; + QueryType skipped_query_type_; private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index ef21d0c978f..0e288d7039e 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -15,6 +15,7 @@ #include "network/connection_handle.h" #include "parser/postgresparser.h" #include "parser/sql_statement.h" +#include "common/statement_cache.h" namespace peloton { namespace tcop { @@ -29,7 +30,6 @@ struct ClientProcessState { bool is_queuing_; std::string error_message_, db_name_ = DEFAULT_DB_NAME; std::vector param_values_; - std::vector results_; // This save currnet statement in the traffic cop std::shared_ptr statement_; // Default database name @@ -41,12 +41,14 @@ struct ClientProcessState { std::vector result_; std::stack tcop_txn_state_; executor::ExecutionResult p_status_; + StatementCache statement_cache_; }; // Execute a statement ResultType ExecuteStatement( + ClientProcessState &state, const std::shared_ptr &statement, - const std::vector ¶ms, const bool unnamed, + const std::vector ¶ms, bool unnamed, std::shared_ptr param_stats, const std::vector &result_format, std::vector &result, size_t thread_id = 0); @@ -59,13 +61,15 @@ executor::ExecutionResult ExecuteHelper( // Prepare a statement using the parse tree std::shared_ptr PrepareStatement( + ClientProcessState &state, const std::string &statement_name, const std::string &query_string, std::unique_ptr sql_stmt_list, size_t thread_id = 0); bool BindParamsForCachePlan( + ClientProcessState &state, const std::vector> &, - const size_t thread_id = 0); + size_t 
thread_id = 0); std::vector GenerateTupleDescriptor( parser::SQLStatement *select_stmt); diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 529c2297259..3cff8dd87cc 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -20,11 +20,17 @@ namespace peloton { namespace network { + +// TODO(Tianyu): This is a refactor in progress. +// A lot of the code here should really be moved to traffic cop, and a lot of +// the code here can honestly just be deleted. This is going to be a larger +// project though, so I want to do the architectural refactor first. Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, callback_func, size_t) { - auto proto_version = in_->ReadInt(); + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + auto proto_version = in_->ReadValue(); LOG_INFO("protocol version: %d", proto_version); // SSL initialization if (proto_version == SSL_MESSAGE_VERNO) { @@ -50,7 +56,7 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, std::string key = in_->ReadString(), value = in_->ReadString(); LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); if (key == std::string("database")) - interpreter.ClientProcessState().db_name_ = value; + state.db_name_ = value; interpreter.AddCmdlineOption(key, std::move(value)); } @@ -64,6 +70,7 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, callback_func callback, size_t tid) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); std::unique_ptr sql_stmt_list; @@ -73,18 +80,19 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, // When the query is empty(such as ";" or ";;", still valid), // the pare tree is empty, parser will return 
nullptr. - if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) { throw ParserException("Error Parsing SQL statement"); } } // If the statement is invalid or not supported yet catch (Exception &e) { - tcop::ProcessInvalidStatement(interpreter.ClientProcessState()); - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); + tcop::ProcessInvalidStatement(state); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + e.what()}}); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - if (sql_stmt_list.get() == nullptr || + if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) { out.WriteEmptyQueryResponse(); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); @@ -101,8 +109,199 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, QueryType query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); - interpreter.ClientProcessState().protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; + + switch (query_type) { + case QueryType::QUERY_PREPARE: { + std::shared_ptr statement(nullptr); + auto prep_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = prep_stmt->name; + statement = tcop::PrepareStatement(state, + stmt_name, + query, + std::move(prep_stmt->query)); + if (statement == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + state.statement_cache_.AddStatement(statement); + + interpreter.CompleteCommand(query_type, 0, out); + + // PAVLO: 2017-01-15 + // There used to be code here that would invoke this method passing + // in NetworkMessageType::READY_FOR_QUERY as the argument. 
But when + // I switched to strong types, this obviously doesn't work. So I + // switched it to be NetworkTransactionStateType::IDLE. I don't know + // we just don't always send back the internal txn state? + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXECUTE: { + std::vector param_values; + auto *exec_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = exec_stmt->name; + + auto cached_statement = state + .statement_cache_.GetStatement(stmt_name); + if (cached_statement != nullptr) { + state.statement_ = std::move(cached_statement); + } else { + out.WriteErrorResponse( + {{NetworkMessageType::HUMAN_READABLE_ERROR, + "The prepared statement does not exist"}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + interpreter.result_format_ = + std::vector(state + .statement_->GetTupleDescriptor().size(), 0); + + if (!tcop::BindParamsForCachePlan(state, + exec_stmt->parameters)) { + tcop::ProcessInvalidStatement(state); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + auto status = tcop::ExecuteStatement(state, + state.statement_, + state.param_values_, false, + nullptr, interpreter.result_format_, state.result_, tid); + if (state.is_queuing_) return Transition::PROCEED; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXPLAIN: { + auto status = interpreter.ExecQueryExplain(query, + static_cast(*sql_stmt)); + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + default: { + std::string stmt_name = "unamed"; + std::unique_ptr unnamed_sql_stmt_list( + new parser::SQLStatementList()); + unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); + state.statement_ = std::move(tcop::PrepareStatement(state, 
stmt_name, + query, std::move(unnamed_sql_stmt_list), tid)); + if (state.statement_ == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.param_values_ = std::vector(); + interpreter.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), 0); + auto status = tcop::ExecuteStatement(state, + state.statement_, + state.param_values_, false, nullptr, + interpreter.result_format_, state.result_, tid); + if (state.is_queuing_) return Transition::PROCEED; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + } +} + +Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + callback_func callback, + size_t tid) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + std::string statement_name = in_->ReadString(), query = in_->ReadString(); + + // In JDBC, one query starts with parsing stage. + // Reset skipped_stmt_ to false for the new query. + + interpreter.skipped_stmt_ = false; + std::unique_ptr sql_stmt_list; + QueryType query_type = QueryType::QUERY_OTHER; + try { + LOG_TRACE("%s, %s", statement_name.c_str(), query.c_str()); + auto &peloton_parser = parser::PostgresParser::GetInstance(); + sql_stmt_list = peloton_parser.BuildParseTree(query); + if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { + throw ParserException("Error parsing SQL statement"); + } + } catch (Exception &e) { + tcop::ProcessInvalidStatement(state); + interpreter.skipped_stmt_ = true; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); + return Transition::PROCEED; + } + + // If the query is not supported yet, + // we will skip the rest commands (B,E,..) 
for this query + // For empty query, we still want to get it constructed + // TODO (Tianyi) Consider handle more statement + bool empty = (sql_stmt_list == nullptr || + sql_stmt_list->GetNumStatements() == 0); + if (!empty) { + parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); + query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); + } + bool skip = !interpreter.HardcodedExecuteFilter(query_type); + if (skip) { + interpreter.skipped_stmt_ = true; + interpreter.skipped_query_string_ = query; + interpreter.skipped_query_type_ = query_type; + out.BeginPacket(NetworkMessageType::PARSE_COMPLETE).EndPacket(); + return Transition::PROCEED; + } + + // Prepare statement + std::shared_ptr statement(nullptr); + + statement = tcop::PrepareStatement(state, statement_name, query, + std::move(sql_stmt_list)); + if (statement == nullptr) { + tcop::ProcessInvalidStatement(state); + interpreter.skipped_stmt_ = true; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + return Transition::PROCEED; + } + LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), + query.c_str()); + auto num_params = in_->ReadValue(); + // Read param types + std::vector param_types = ReadParamTypes(); + + // Cache the received query + bool unnamed_query = statement_name.empty(); + statement->SetParamTypes(param_types); + + // Stat + if (static_cast(settings::SettingsManager::GetInt( + settings::SettingId::stats_mode)) != StatsType::INVALID) { + // Make a copy of param types for stat collection + stats::QueryMetric::QueryParamBuf query_type_buf; + query_type_buf.len = type_buf_len; + query_type_buf.buf = PacketCopyBytes(type_buf_begin, type_buf_len); + + // Unnamed statement + if (unnamed_query) { + unnamed_stmt_param_types_ = query_type_buf; + } else { + statement_param_types_[statement_name] = query_type_buf; + } + } + + // Cache the statement + statement_cache_.AddStatement(statement); + // Send Parse complete response + 
std::unique_ptr response(new OutputPacket()); + response->msg_type = NetworkMessageType::PARSE_COMPLETE; + responses_.push_back(std::move(response)); } } // namespace network diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index a4dde2a4468..b2d889ac709 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -934,7 +934,7 @@ bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, // get packet size from the header // extract packet contents size // content lengths should exclude the length bytes - rpkt.len = rbuf.ReadInt() - sizeof(uint32_t); + rpkt.len = rbuf.ReadValue() - sizeof(uint32_t); // do we need to use the extended buffer for this packet? rpkt.is_extended = (rpkt.len > rbuf.Capacity()); diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index f036334bb46..402fa644685 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -58,7 +58,7 @@ bool PostgresProtocolInterpreter::TryReadPacketHeader(std::shared_ptrReadRawValue(); - curr_input_packet_.len_ = in->ReadInt() - sizeof(uint32_t); + curr_input_packet_.len_ = in->ReadValue() - sizeof(uint32_t); // Extend the buffer as needed if (curr_input_packet_.len_ > in->Capacity()) { @@ -81,16 +81,22 @@ std::shared_ptr PostgresProtocolInterpreter::PacketToCom switch (curr_input_packet_.msg_type_) { case NetworkMessageType::SIMPLE_QUERY_COMMAND: return MAKE_COMMAND(SimpleQueryCommand); - case NetworkMessageType::PARSE_COMMAND:return MAKE_COMMAND(ParseCommand); - case NetworkMessageType::BIND_COMMAND:return MAKE_COMMAND(BindCommand); + case NetworkMessageType::PARSE_COMMAND: + return MAKE_COMMAND(ParseCommand); + case NetworkMessageType::BIND_COMMAND + :return MAKE_COMMAND(BindCommand); case NetworkMessageType::DESCRIBE_COMMAND: return MAKE_COMMAND(DescribeCommand); - case 
NetworkMessageType::EXECUTE_COMMAND:return MAKE_COMMAND(ExecuteCommand); - case NetworkMessageType::SYNC_COMMAND:return MAKE_COMMAND(SyncCommand); - case NetworkMessageType::CLOSE_COMMAND:return MAKE_COMMAND(CloseCommand); + case NetworkMessageType::EXECUTE_COMMAND: + return MAKE_COMMAND(ExecuteCommand); + case NetworkMessageType::SYNC_COMMAND + :return MAKE_COMMAND(SyncCommand); + case NetworkMessageType::CLOSE_COMMAND: + return MAKE_COMMAND(CloseCommand); case NetworkMessageType::TERMINATE_COMMAND: return MAKE_COMMAND(TerminateCommand); - case NetworkMessageType::NULL_COMMAND:return MAKE_COMMAND(NullCommand); + case NetworkMessageType::NULL_COMMAND: + return MAKE_COMMAND(NullCommand); default: throw NetworkProcessException("Unexpected Packet Type: " + std::to_string(static_cast(curr_input_packet_.msg_type_))); From dac459777c3d873ce8075f47176a2425bd98bd42 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Mon, 25 Jun 2018 23:54:30 -0400 Subject: [PATCH 16/48] remove hacky EXPLAIN and PREPARE/EXEC --- src/network/postgres_network_commands.cpp | 124 ++++------------------ 1 file changed, 22 insertions(+), 102 deletions(-) diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 3cff8dd87cc..83001303ad3 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -105,110 +105,30 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, // one packet contains only one query. But when using the pipeline mode in // Libpqxx, it sends multiple query in one packet. In this case, it's // incorrect. 
- auto sql_stmt = sql_stmt_list->PassOutStatement(0); - - QueryType query_type = - StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); - interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; - - switch (query_type) { - case QueryType::QUERY_PREPARE: { - std::shared_ptr statement(nullptr); - auto prep_stmt = dynamic_cast(sql_stmt.get()); - std::string stmt_name = prep_stmt->name; - statement = tcop::PrepareStatement(state, - stmt_name, - query, - std::move(prep_stmt->query)); - if (statement == nullptr) { - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - state.error_message_}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - state.statement_cache_.AddStatement(statement); - - interpreter.CompleteCommand(query_type, 0, out); - - // PAVLO: 2017-01-15 - // There used to be code here that would invoke this method passing - // in NetworkMessageType::READY_FOR_QUERY as the argument. But when - // I switched to strong types, this obviously doesn't work. So I - // switched it to be NetworkTransactionStateType::IDLE. I don't know - // we just don't always send back the internal txn state? 
- out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - }; - case QueryType::QUERY_EXECUTE: { - std::vector param_values; - auto *exec_stmt = dynamic_cast(sql_stmt.get()); - std::string stmt_name = exec_stmt->name; - - auto cached_statement = state - .statement_cache_.GetStatement(stmt_name); - if (cached_statement != nullptr) { - state.statement_ = std::move(cached_statement); - } else { - out.WriteErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, - "The prepared statement does not exist"}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - - interpreter.result_format_ = - std::vector(state - .statement_->GetTupleDescriptor().size(), 0); - - if (!tcop::BindParamsForCachePlan(state, - exec_stmt->parameters)) { - tcop::ProcessInvalidStatement(state); - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - state.error_message_}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - - auto status = tcop::ExecuteStatement(state, - state.statement_, - state.param_values_, false, - nullptr, interpreter.result_format_, state.result_, tid); - if (state.is_queuing_) return Transition::PROCEED; - interpreter.ExecQueryMessageGetResult(status); - return Transition::PROCEED; - }; - case QueryType::QUERY_EXPLAIN: { - auto status = interpreter.ExecQueryExplain(query, - static_cast(*sql_stmt)); - interpreter.ExecQueryMessageGetResult(status); - return Transition::PROCEED; - } - default: { - std::string stmt_name = "unamed"; - std::unique_ptr unnamed_sql_stmt_list( - new parser::SQLStatementList()); - unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); - state.statement_ = std::move(tcop::PrepareStatement(state, stmt_name, - query, std::move(unnamed_sql_stmt_list), tid)); - if (state.statement_ == nullptr) { - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - state.error_message_}}); - 
out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - state.param_values_ = std::vector(); - interpreter.result_format_ = - std::vector(state.statement_->GetTupleDescriptor().size(), 0); - auto status = tcop::ExecuteStatement(state, - state.statement_, - state.param_values_, false, nullptr, - interpreter.result_format_, state.result_, tid); - if (state.is_queuing_) return Transition::PROCEED; - interpreter.ExecQueryMessageGetResult(status); - return Transition::PROCEED; - } + auto sql_stmt = sql_stmt_list->PassOutStatement(0); + std::string stmt_name = "unamed"; + std::unique_ptr unnamed_sql_stmt_list( + new parser::SQLStatementList()); + unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); + state.statement_ = std::move(tcop::PrepareStatement(state, stmt_name, + query, std::move(unnamed_sql_stmt_list), tid)); + if (state.statement_ == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; } + state.param_values_ = std::vector(); + interpreter.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), 0); + auto status = tcop::ExecuteStatement(state, + state.statement_, + state.param_values_, false, nullptr, + interpreter.result_format_, state.result_, tid); + if (state.is_queuing_) return Transition::PROCEED; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; } Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, From 7a82bc1dd7662b964433f3d83f224221112a6efd Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 26 Jun 2018 03:54:06 -0400 Subject: [PATCH 17/48] parser and statement cache push down --- src/include/traffic_cop/tcop.h | 10 +- src/network/postgres_network_commands.cpp | 131 +++------------------- src/traffic_cop/tcop.cpp | 78 +++++++++++++ 3 files changed, 98 insertions(+), 121 deletions(-) diff 
--git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 0e288d7039e..ed23ff8bd69 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -59,12 +59,10 @@ executor::ExecutionResult ExecuteHelper( const std::vector ¶ms, std::vector &result, const std::vector &result_format, size_t thread_id = 0); -// Prepare a statement using the parse tree -std::shared_ptr PrepareStatement( - ClientProcessState &state, - const std::string &statement_name, const std::string &query_string, - std::unique_ptr sql_stmt_list, - size_t thread_id = 0); +// Prepare a statement +bool PrepareStatement( + ClientProcessState &state, const std::string &query_string, + const std::string &statement_name = "unnamed"); bool BindParamsForCachePlan( ClientProcessState &state, diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 83001303ad3..e2cdd961085 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -73,61 +73,27 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); - std::unique_ptr sql_stmt_list; - try { - auto &peloton_parser = parser::PostgresParser::GetInstance(); - sql_stmt_list = peloton_parser.BuildParseTree(query); - - // When the query is empty(such as ";" or ";;", still valid), - // the pare tree is empty, parser will return nullptr. 
- if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) { - throw ParserException("Error Parsing SQL statement"); - } - } // If the statement is invalid or not supported yet - catch (Exception &e) { - tcop::ProcessInvalidStatement(state); - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - e.what()}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - - if (sql_stmt_list == nullptr || - sql_stmt_list->GetNumStatements() == 0) { - out.WriteEmptyQueryResponse(); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); - return Transition::PROCEED; - } - - // TODO(Yuchen): Hack. We only process the first statement in the packet now. - // We should store the rest of statements that will not be processed right - // away. For the hack, in most cases, it works. Because for example in psql, - // one packet contains only one query. But when using the pipeline mode in - // Libpqxx, it sends multiple query in one packet. In this case, it's - // incorrect. 
- - auto sql_stmt = sql_stmt_list->PassOutStatement(0); - std::string stmt_name = "unamed"; - std::unique_ptr unnamed_sql_stmt_list( - new parser::SQLStatementList()); - unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); - state.statement_ = std::move(tcop::PrepareStatement(state, stmt_name, - query, std::move(unnamed_sql_stmt_list), tid)); - if (state.statement_ == nullptr) { + if(!tcop::PrepareStatement(state, query)) { out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } + state.param_values_ = std::vector(); + interpreter.result_format_ = std::vector(state.statement_->GetTupleDescriptor().size(), 0); + auto status = tcop::ExecuteStatement(state, state.statement_, state.param_values_, false, nullptr, interpreter.result_format_, state.result_, tid); - if (state.is_queuing_) return Transition::PROCEED; + + if (state.is_queuing_) return Transition::NEED_RESULT; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; } @@ -137,91 +103,26 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, size_t tid) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string statement_name = in_->ReadString(), query = in_->ReadString(); + LOG_TRACE("Execute query: %s", query.c_str()); - // In JDBC, one query starts with parsing stage. - // Reset skipped_stmt_ to false for the new query. 
- - interpreter.skipped_stmt_ = false; - std::unique_ptr sql_stmt_list; - QueryType query_type = QueryType::QUERY_OTHER; - try { - LOG_TRACE("%s, %s", statement_name.c_str(), query.c_str()); - auto &peloton_parser = parser::PostgresParser::GetInstance(); - sql_stmt_list = peloton_parser.BuildParseTree(query); - if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { - throw ParserException("Error parsing SQL statement"); - } - } catch (Exception &e) { - tcop::ProcessInvalidStatement(state); - interpreter.skipped_stmt_ = true; - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); - return Transition::PROCEED; - } - - // If the query is not supported yet, - // we will skip the rest commands (B,E,..) for this query - // For empty query, we still want to get it constructed - // TODO (Tianyi) Consider handle more statement - bool empty = (sql_stmt_list == nullptr || - sql_stmt_list->GetNumStatements() == 0); - if (!empty) { - parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); - query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); - } - bool skip = !interpreter.HardcodedExecuteFilter(query_type); - if (skip) { - interpreter.skipped_stmt_ = true; - interpreter.skipped_query_string_ = query; - interpreter.skipped_query_type_ = query_type; - out.BeginPacket(NetworkMessageType::PARSE_COMPLETE).EndPacket(); - return Transition::PROCEED; - } - - // Prepare statement - std::shared_ptr statement(nullptr); - - statement = tcop::PrepareStatement(state, statement_name, query, - std::move(sql_stmt_list)); - if (statement == nullptr) { - tcop::ProcessInvalidStatement(state); - interpreter.skipped_stmt_ = true; + if(!tcop::PrepareStatement(state, query, statement_name)) { out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } + LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), 
query.c_str()); - auto num_params = in_->ReadValue(); + // Read param types std::vector param_types = ReadParamTypes(); - - // Cache the received query - bool unnamed_query = statement_name.empty(); - statement->SetParamTypes(param_types); - - // Stat - if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { - // Make a copy of param types for stat collection - stats::QueryMetric::QueryParamBuf query_type_buf; - query_type_buf.len = type_buf_len; - query_type_buf.buf = PacketCopyBytes(type_buf_begin, type_buf_len); - - // Unnamed statement - if (unnamed_query) { - unnamed_stmt_param_types_ = query_type_buf; - } else { - statement_param_types_[statement_name] = query_type_buf; - } - } - - // Cache the statement - statement_cache_.AddStatement(statement); + state.statement_->SetParamTypes(param_types); // Send Parse complete response std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::PARSE_COMPLETE; - responses_.push_back(std::move(response)); + out.BeginPacket(NetworkMessageType::PARSE_COMPLETE); + return Transition::PROCEED; } } // namespace network diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index 403b0fafd62..b4a63270b59 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -10,10 +10,88 @@ // //===----------------------------------------------------------------------===// +#include "planner/plan_util.h" +#include "binder/bind_node_visitor.h" #include "traffic_cop/tcop.h" namespace peloton { namespace tcop { +// Prepare a statement +bool tcop::PrepareStatement( + ClientProcessState &state, const std::string &query_string, + const std::string &statement_name) { + try { + // TODO(Tianyi) Implicitly start a txn + + // parse the query + auto &peloton_parser = parser::PostgresParser::GetInstance(); + auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); + + // When the query is empty(such as ";" or ";;", still valid), + // 
the parse tree is empty, parser will return nullptr. + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) { + throw ParserException("Error Parsing SQL statement"); + } + + // TODO(Yuchen): Hack. We only process the first statement in the packet now. + // We should store the rest of statements that will not be processed right + // away. For the hack, in most cases, it works. Because for example in psql, + // one packet contains only one query. But when using the pipeline mode in + // Libpqxx, it sends multiple query in one packet. In this case, it's + // incorrect. + StatementType stmt_type = sql_stmt_list->GetStatement(0)->GetType(); + QueryType query_type = + StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); + + std::shared_ptr statement = std::make_shared( + statement_name, query_type, query_string, std::move(sql_stmt_list)); + + // Empty statement edge case + if (sql_stmt_list == nullptr || + sql_stmt_list->GetNumStatements() == 0) { + std::shared_ptr statement = + std::make_shared(statement_name, QueryType::QUERY_INVALID, + query_string, std::move(sql_stmt_list)); + state.statement_cache_.AddStatement(statement); + return true; + } + + // Run binder + auto bind_node_visitor = binder::BindNodeVisitor( + tcop_txn_state_.top().first, state.db_name_); + bind_node_visitor.BindNameToNode( + statement->GetStmtParseTreeList()->GetStatement(0)); + auto plan = state.optimizer_->BuildPelotonPlanTree( + statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); + statement->SetPlanTree(plan); + // Get the tables that our plan references so that we know how to + // invalidate it at a later point when the catalog changes + const std::set table_oids = + planner::PlanUtil::GetTablesReferenced(plan.get()); + statement->SetReferencedTables(table_oids); + + if (query_type == QueryType::QUERY_SELECT) { + auto tuple_descriptor = GenerateTupleDescriptor( + statement->GetStmtParseTreeList()->GetStatement(0)); + 
statement->SetTupleDescriptor(tuple_descriptor); + LOG_TRACE("select query, finish setting"); + } + + state.statement_cache_.AddStatement(statement); + + } // If the statement is invalid or not supported yet + catch (Exception &e) { + // TODO implicit end the txn here + state.error_message_ = e.what(); + return false; + } + + return true; +} + + + + } // namespace tcop } // namespace peloton \ No newline at end of file From 8307f5f0cf655380a8eb63e72497a521e41e75ad Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 26 Jun 2018 06:01:11 -0400 Subject: [PATCH 18/48] add in execution helper --- src/common/internal_types.cpp | 3 - src/include/traffic_cop/tcop.h | 19 +++--- src/network/postgres_network_commands.cpp | 7 +-- src/network/postgres_protocol_handler.cpp | 2 +- src/traffic_cop/tcop.cpp | 76 +++++++++++++++++++++++ 5 files changed, 88 insertions(+), 19 deletions(-) diff --git a/src/common/internal_types.cpp b/src/common/internal_types.cpp index 855f7ef2d9b..b6e52105ae6 100644 --- a/src/common/internal_types.cpp +++ b/src/common/internal_types.cpp @@ -2030,9 +2030,6 @@ std::string ResultTypeToString(ResultType type) { case ResultType::UNKNOWN: { return ("UNKNOWN"); } - case ResultType::QUEUING: { - return ("QUEUING"); - } case ResultType::TO_ABORT: { return ("TO_ABORT"); } diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index ed23ff8bd69..d8a909fa61a 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -32,32 +32,29 @@ struct ClientProcessState { std::vector param_values_; // This save currnet statement in the traffic cop std::shared_ptr statement_; - // Default database name - int rows_affected_; // The optimizer used for this connection std::unique_ptr optimizer_; // flag of single statement txn - bool single_statement_txn_; std::vector result_; - std::stack tcop_txn_state_; executor::ExecutionResult p_status_; StatementCache statement_cache_; + + // The current callback to be invoked after 
execution completes. + void (*task_callback_)(void *); + void *task_callback_arg_; }; // Execute a statement ResultType ExecuteStatement( ClientProcessState &state, - const std::shared_ptr &statement, - const std::vector ¶ms, bool unnamed, - std::shared_ptr param_stats, - const std::vector &result_format, std::vector &result, - size_t thread_id = 0); + const std::vector &result_format, std::vector &result); // Helper to handle txn-specifics for the plan-tree of a statement. -executor::ExecutionResult ExecuteHelper( +void ExecuteHelper( + ClientProcessState &state, std::shared_ptr plan, const std::vector ¶ms, std::vector &result, - const std::vector &result_format, size_t thread_id = 0); + const std::vector &result_format); // Prepare a statement bool PrepareStatement( diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index e2cdd961085..7df873fa16b 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -86,11 +86,9 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, std::vector(state.statement_->GetTupleDescriptor().size(), 0); auto status = tcop::ExecuteStatement(state, - state.statement_, - state.param_values_, false, nullptr, - interpreter.result_format_, state.result_, tid); + interpreter.result_format_, state.result_); - if (state.is_queuing_) return Transition::NEED_RESULT; + if (status == ResultType::QUEUING) return Transition::NEED_RESULT; interpreter.ExecQueryMessageGetResult(status); @@ -125,5 +123,6 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } + } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp index b2d889ac709..a76833609b2 100644 --- a/src/network/postgres_protocol_handler.cpp +++ b/src/network/postgres_protocol_handler.cpp @@ -260,7 +260,7 @@ 
ResultType PostgresProtocolHandler::ExecQueryExplain( unnamed_sql_stmt_list->PassInStatement(std::move(explain_stmt.real_sql_stmt)); auto stmt = traffic_cop_->PrepareStatement("explain", query, std::move(unnamed_sql_stmt_list)); - ResultType status = ResultType::UNKNOWN; + ResultType status; if (stmt != nullptr) { traffic_cop_->SetStatement(stmt); std::vector plan_info = StringUtil::Split( diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index b4a63270b59..a73b93955a4 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "threadpool/mono_queue_pool.h" #include "planner/plan_util.h" #include "binder/bind_node_visitor.h" #include "traffic_cop/tcop.h" @@ -90,8 +91,83 @@ bool tcop::PrepareStatement( return true; } +ResultType tcop::ExecuteStatement( + ClientProcessState &state, + const std::vector &result_format, std::vector &result) { + + LOG_TRACE("Execute Statement of name: %s", + state.statement_->GetStatementName().c_str()); + LOG_TRACE("Execute Statement of query: %s", + state.statement_->GetQueryString().c_str()); + LOG_TRACE("Execute Statement Plan:\n%s", + planner::PlanUtil::GetInfo(state.statement_->GetPlanTree().get()).c_str()); + LOG_TRACE("Execute Statement Query Type: %s", + state.statement_->GetQueryTypeString().c_str()); + LOG_TRACE("----QueryType: %d--------", + static_cast(state.statement_->GetQueryType())); + try { + switch (state.statement_->GetQueryType()) { + case QueryType::QUERY_BEGIN: { + return BeginQueryHelper(state.thread_id_); + } + case QueryType::QUERY_COMMIT: { + return CommitQueryHelper(); + } + case QueryType::QUERY_ROLLBACK: { + return AbortQueryHelper(); + } + default: { + // The statement may be out of date + // It needs to be replan + if (state.statement_->GetNeedsReplan()) { + // TODO(Tianyi) Move Statement Replan into Statement's method + // to increase coherence + auto bind_node_visitor = 
binder::BindNodeVisitor( + tcop_txn_state_.top().first, state.db_name_); + bind_node_visitor.BindNameToNode( + state.statement_->GetStmtParseTreeList()->GetStatement(0)); + auto plan = state.optimizer_->BuildPelotonPlanTree( + state.statement_->GetStmtParseTreeList(), tcop_txn_state_.top().first); + state.statement_->SetPlanTree(plan); + state.statement_->SetNeedsReplan(true); + } + + ExecuteHelper(state, result, result_format); + return ResultType::QUEUING; + } + } + } catch (Exception &e) { + state.error_message_ = e.what(); + return ResultType::FAILURE; + } +} + +void tcop::ExecuteHelper( + ClientProcessState &state, + std::vector &result, + const std::vector &result_format) { + auto plan = state.statement_->GetPlanTree(); + auto params = state.param_values_, + + auto on_complete = [&result, &state](executor::ExecutionResult p_status, + std::vector &&values) { + state.p_status_ = p_status; + state.error_message_ = std::move(p_status.m_error_message); + result = std::move(values); + state.task_callback_(state.task_callback_arg_); + }; + + auto &pool = threadpool::MonoQueuePool::GetInstance(); + pool.SubmitTask([plan, txn, ¶ms, &result_format, on_complete] { + executor::PlanExecutor::ExecutePlan(plan, txn, params, result_format, + on_complete); + }); + + LOG_TRACE("Check Tcop_txn_state Size After ExecuteHelper %lu", + tcop_txn_state_.size()); +} } // namespace tcop } // namespace peloton \ No newline at end of file From f2dd0b509d6c163fa55c27aff134254174ca8aaa Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 26 Jun 2018 10:52:14 -0400 Subject: [PATCH 19/48] Implement txn handle wrapper class --- .../traffic_cop/client_transaction_handle.h | 221 ++++++++++++++++++ src/traffic_cop/client_transaction_handle.cpp | 107 +++++++++ 2 files changed, 328 insertions(+) create mode 100644 src/include/traffic_cop/client_transaction_handle.h create mode 100644 src/traffic_cop/client_transaction_handle.cpp diff --git a/src/include/traffic_cop/client_transaction_handle.h 
b/src/include/traffic_cop/client_transaction_handle.h new file mode 100644 index 00000000000..42eb39ef300 --- /dev/null +++ b/src/include/traffic_cop/client_transaction_handle.h @@ -0,0 +1,221 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// client_transaction_handle.h +// +// Identification: src/include/traffic_cop/client_transaction_handle.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "concurrency/transaction_manager_factory.h" +namespace peloton { +namespace tcop { + +using TxnContext = concurrency::TransactionContext; +using TxnManagerFactory = concurrency::TransactionManagerFactory; + +enum class TransactionState{ + IDLE = 0, + STARTED, + FAILING, + ABORTING, +}; + +class ClientTxnHandle; + +/** + * abtract class to provide a unified interface of txn handling + */ +class AbtractClientTxnHandler { + public: + /** + * @brief Start a txn if there is no txn at the moment this function is called. + * @param thread_id number to generate epoch id in a distributed manner + * @param handle Client transaction context + * @return Current trancation that is started + * @throw TransactionException when no txn can be started (eg. current txn is failed already) + */ + virtual TxnContext *ImplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle) = 0; + + /** + * @brief Force starting a txn + * @param thread_id number to generate epoch id in a distributed manner + * @param handle Client transaction context + * @return Current trancation that is started + * @throw TransactionException when no txn can be started (eg. 
there is already an txn) + */ + virtual TxnContext *ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle) = 0; + + /** + * @brief Implicitly end a txn + * @param handle + * @param handle Client transaction context + */ + virtual void End(ClientTxnHandle &handle) = 0; + + /** + * @brief Explicitly commit a txn + * @param handle Client transaction context + * @throw TransactionException when there is no txn started + */ + virtual bool Commit(ClientTxnHandle &handle) = 0; + + /** + * @brief Explicitly abort a txn + * @param handle Client transaction context + */ + virtual void Abort(ClientTxnHandle &handle) = 0; + +}; + +/** + * Client Transaction handler for Transaction Handler when in Single-Statement Mode + */ +class SingleStmtClientTxnHandler : AbtractClientTxnHandler{ + + /** + * @see AbstractClientTxnHandler + */ + TxnContext *ImplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle); + + /** + * @brief This function should never be called in this mode + */ + inline TxnContext *ExplicitBegin(const size_t, ClientTxnHandle &) { + throw TransactionException("Should not be called"); + } + + /** + * @see AbstractClientTxnHandler + */ + void End(ClientTxnHandle &handle); + + /** + * @brief This function should never be called in this mode + */ + inline bool Commit(ClientTxnHandle &handle) { + throw TransactionException("Should not be called"); + } + + /** + * @see AbstractClientTxnHandler + */ + void Abort(ClientTxnHandle &handle); + +}; + +class MultiStmtsClientTxnHandler : AbtractClientTxnHandler { + + /** + * @see AbstractClientTxnHandler + */ + TxnContext *ImplicitBegin(const size_t, ClientTxnHandle &handle_); + + /** + * @see AbstractClientTxnHandler + */ + TxnContext *ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle & handle); + + /** + * @see AbstractClientTxnHandler + */ + inline void End(ClientTxnHandle &handle) {} + + /** + * @see AbstractClientTxnHandler + */ + bool Commit(ClientTxnHandle &handle); + + /** + * @see 
AbstractClientTxnHandler + */ + void Abort(ClientTxnHandle &handle); +}; + +/** + * @brief Wrapper class that could provide functions to properly start and end a transaction. + * + * It would operate in either Single-Statement or Multi-Statements mode, using different handler + * + */ +class ClientTxnHandle { + friend class AbtractClientTxnHandler; + friend class SingleStmtClientTxnHandler; + friend class MultiStmtsClientTxnHandler; + + public: + + /** + * Start a transaction if there is no transaction + * @param thread_id number to generate epoch id in a distributed manner + * @return transaction context + */ + TxnContext *ImplicitBegin(const size_t thread_id = 0); + + /** + * Force starting a transaction if there is no transaction + * @param thread_id number to generate epoch id in a distributed manner + * @return transaction context + */ + TxnContext *ExplicitBegin(const size_t thread_id = 0); + + /** + * Commit/Abort a transaction and do the necessary cleanup + */ + void ImplicitEnd(); + + /** + * Explicitly commit a transaction + * @return if the commit is successful + */ + bool ExplicitCommit(); + + /** + * Explicitly abort a transaction + */ + void ExplicitAbort(); + + /** + * @brief Getter function of txn state + * @return current trancation state + */ + inline TransactionState GetTxnState() { + return txn_state_; + } + + /** + * @brief Getter function of current transaction context + * @return current transaction context + */ + inline TxnContext *GetTxn() { + return txn_; + } + + private: + + TransactionState txn_state_; + + TxnContext *txn_; + + bool single_stmt_handler_ = true; + + std::unique_ptr handler_; + + inline void ChangeToSingleStmtHandler() { + handler_ = std::unique_ptr(new SingleStmtClientTxnHandler()); + single_stmt_handler_ = true; + } + + inline void ChangeToMultiStmtsHandler() { + handler_ = std::unique_ptr(new MultiStmtsClientTxnHandler()); + single_stmt_handler_ = false; + } +}; + +} +} \ No newline at end of file diff --git 
a/src/traffic_cop/client_transaction_handle.cpp b/src/traffic_cop/client_transaction_handle.cpp new file mode 100644 index 00000000000..61513455cc8 --- /dev/null +++ b/src/traffic_cop/client_transaction_handle.cpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// client_transaction_handle.cpp +// +// Identification: src/traffic_cop/client_transaction_handle.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "traffic_cop/client_transaction_handle.h" +#include +namespace peloton { +namespace tcop { + + /* Function implementations of SingleStmtClientTxnHandler */ + TxnContext *SingleStmtClientTxnHandler::ImplicitBegin(const size_t thread_id, ClientTxnHandle &handle) { + switch (handle.txn_state_) { + case TransactionState::IDLE: { + handle.txn_ = TxnManagerFactory::GetInstance().BeginTransaction(thread_id); + handle.txn_state_ =TransactionState::STARTED; + } + case TransactionState::STARTED: + case TransactionState::FAILING: + case TransactionState::ABORTING: + break; + } + return handle.txn_; + } + + void SingleStmtClientTxnHandler::End(ClientTxnHandle &handle) { + // TODO Implement this function + } + + void SingleStmtClientTxnHandler::Abort(ClientTxnHandle &handle) { + // TODO Implement this function + } + + + /* Function implementations of MultiStmtsClientTxnHandler */ + TxnContext *MultiStmtsClientTxnHandler::ImplicitBegin(const size_t, ClientTxnHandle &handle_) { + return handle_.GetTxn(); + } + + TxnContext *MultiStmtsClientTxnHandler::ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle & handle){ + switch (handle.txn_state_) { + case TransactionState::IDLE: { + handle.txn_ = TxnManagerFactory::GetInstance().BeginTransaction(thread_id); + handle.txn_state_ = TransactionState::STARTED; + } + case TransactionState::STARTED: + 
TxnManagerFactory::GetInstance().AbortTransaction(handle.txn_); + handle.txn_state_ = TransactionState::ABORTING; + throw TransactionException("Current Transaction started already"); + case TransactionState::FAILING: + case TransactionState::ABORTING: + break; + } + return handle.txn_; + } + + bool MultiStmtsClientTxnHandler::Commit(ClientTxnHandle &handler) { + // TODO implement this function + return false; + } + + void MultiStmtsClientTxnHandler::Abort(ClientTxnHandle &handler) { + // TODO implement this function + } + + /* Function implementations of ClientTxnHandle */ + TxnContext *ClientTxnHandle::ImplicitBegin(const size_t thread_id) { + return handler_->ImplicitBegin(thread_id, *this); + } + + TxnContext *ClientTxnHandle::ExplicitBegin(const size_t thread_id) { + if (single_stmt_handler_) { + ChangeToMultiStmtsHandler(); + } + return handler_->ExplicitBegin(thread_id, *this); + } + + void ClientTxnHandle::ImplicitEnd() { + handler_->End(*this); + if (txn_state_ == TransactionState::IDLE && !single_stmt_handler_) { + ChangeToSingleStmtHandler(); + } + } + + void ClientTxnHandle::ExplicitAbort() { + handler_->Abort(*this); + if (!single_stmt_handler_) + ChangeToSingleStmtHandler(); + } + + bool ClientTxnHandle::ExplicitCommit() { + bool success = handler_->Commit(*this); + if (success && !single_stmt_handler_) { + ChangeToSingleStmtHandler(); + } + return success; + } + +} +} \ No newline at end of file From 7d55f681f4e5eccdcffedcfb085de5bb4e4dcb6b Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 26 Jun 2018 11:13:48 -0400 Subject: [PATCH 20/48] Add use cases in comments --- src/include/traffic_cop/client_transaction_handle.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/include/traffic_cop/client_transaction_handle.h b/src/include/traffic_cop/client_transaction_handle.h index 42eb39ef300..2079494b37d 100644 --- a/src/include/traffic_cop/client_transaction_handle.h +++ b/src/include/traffic_cop/client_transaction_handle.h @@ -142,6 
+142,15 @@ class MultiStmtsClientTxnHandler : AbtractClientTxnHandler { * * It would operate in either Single-Statement or Multi-Statements mode, using different handler * + * Serveral general patterns of function calls is : + * 1. + * ImplicitBegin()[By Prepare] -> ImplicitBegin() [By Execute]-> ImplicitEnd(); + * 2. + * ImplicitBegin() -> ExplicitBegin() -> ImplictBegin()[second query] + * ->ImplicitEnd()[second query] -> ExplicitCommit() -> ImplicitEnd(); + * 3. + * ImplicitBegin() -> ExplicitBegin() -> ImplictBegin()[second query] + * ->ImplicitENd()[second query]-> ExplicitAbort() -> ImplicitEnd(); */ class ClientTxnHandle { friend class AbtractClientTxnHandler; From f91d0c070ff5da03eadc83cf50ebab142d8721f8 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 26 Jun 2018 11:14:17 -0400 Subject: [PATCH 21/48] Add txn state handle into tcop --- src/include/traffic_cop/tcop.h | 17 +++++++---------- src/traffic_cop/tcop.cpp | 28 ++++++++++++++++++---------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index d8a909fa61a..69c958ce377 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,11 +11,11 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include "network/connection_handle.h" #include "parser/postgresparser.h" #include "parser/sql_statement.h" #include "common/statement_cache.h" +#include "client_transaction_handle.h" namespace peloton { namespace tcop { @@ -39,6 +39,9 @@ struct ClientProcessState { executor::ExecutionResult p_status_; StatementCache statement_cache_; + // Transaction Handling Wrapper + ClientTxnHandle txn_handle_; + // The current callback to be invoked after execution completes. void (*task_callback_)(void *); void *task_callback_arg_; @@ -52,9 +55,9 @@ ResultType ExecuteStatement( // Helper to handle txn-specifics for the plan-tree of a statement. 
void ExecuteHelper( ClientProcessState &state, - std::shared_ptr plan, - const std::vector ¶ms, std::vector &result, - const std::vector &result_format); + std::vector &result, + const std::vector &result_format, + concurrency::TransactionContext *txn); // Prepare a statement bool PrepareStatement( @@ -72,18 +75,12 @@ std::vector GenerateTupleDescriptor( FieldInfo GetColumnFieldForValueType(std::string column_name, type::TypeId column_type); -ResultType CommitQueryHelper(); - void ExecuteStatementPlanGetResult(); ResultType ExecuteStatementGetResult(); void ProcessInvalidStatement(ClientProcessState &state); -ResultType BeginQueryHelper(size_t thread_id); - -ResultType AbortQueryHelper(); - // Get all data tables from a TableRef. // For multi-way join // still a HACK diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index a73b93955a4..ee9586d150d 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -23,8 +23,7 @@ bool tcop::PrepareStatement( ClientProcessState &state, const std::string &query_string, const std::string &statement_name) { try { - // TODO(Tianyi) Implicitly start a txn - + state.txn_handle_.ImplicitBegin(state.thread_id_); // parse the query auto &peloton_parser = parser::PostgresParser::GetInstance(); auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); @@ -87,6 +86,7 @@ bool tcop::PrepareStatement( state.error_message_ = e.what(); return false; } + // TODO catch txn exception return true; } @@ -109,31 +109,38 @@ ResultType tcop::ExecuteStatement( try { switch (state.statement_->GetQueryType()) { case QueryType::QUERY_BEGIN: { - return BeginQueryHelper(state.thread_id_); + state.txn_handle_.ExplicitBegin(state.thread_id_); + return ResultType::SUCCESS; } case QueryType::QUERY_COMMIT: { - return CommitQueryHelper(); + if (!state.txn_handle_.ExplicitCommit()) { + // TODO Check which result type we should return + return ResultType::FAILURE; + } + return ResultType::SUCCESS; } case 
QueryType::QUERY_ROLLBACK: { - return AbortQueryHelper(); + state.txn_handle_.ExplicitAbort(); + return ResultType::SUCCESS; } default: { // The statement may be out of date // It needs to be replan + auto txn = state.txn_handle_.ImplicitBegin(state.thread_id_); if (state.statement_->GetNeedsReplan()) { // TODO(Tianyi) Move Statement Replan into Statement's method // to increase coherence auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, state.db_name_); + txn, state.db_name_); bind_node_visitor.BindNameToNode( state.statement_->GetStmtParseTreeList()->GetStatement(0)); auto plan = state.optimizer_->BuildPelotonPlanTree( - state.statement_->GetStmtParseTreeList(), tcop_txn_state_.top().first); + state.statement_->GetStmtParseTreeList(), txn); state.statement_->SetPlanTree(plan); state.statement_->SetNeedsReplan(true); } - ExecuteHelper(state, result, result_format); + ExecuteHelper(state, result, result_format, txn); return ResultType::QUEUING; } } @@ -147,9 +154,10 @@ ResultType tcop::ExecuteStatement( void tcop::ExecuteHelper( ClientProcessState &state, std::vector &result, - const std::vector &result_format) { + const std::vector &result_format, + concurrency::TransactionContext *txn) { auto plan = state.statement_->GetPlanTree(); - auto params = state.param_values_, + auto params = state.param_values_; auto on_complete = [&result, &state](executor::ExecutionResult p_status, std::vector &&values) { From 43e59e055c7af6daa7a2097a9140a92da966c400 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 26 Jun 2018 13:49:11 -0400 Subject: [PATCH 22/48] Put types in order --- src/common/statement.cpp | 4 +- src/include/common/internal_types.h | 2 + src/include/common/statement.h | 10 +-- .../network/postgres_network_commands.h | 13 ++-- .../network/postgres_protocol_interpreter.h | 19 ++--- src/include/network/protocol_interpreter.h | 1 - src/include/traffic_cop/tcop.h | 17 +++-- src/network/postgres_network_commands.cpp | 22 ++---- 
src/traffic_cop/tcop.cpp | 70 ++++++++++--------- src/traffic_cop/traffic_cop.cpp | 1 - 10 files changed, 70 insertions(+), 89 deletions(-) diff --git a/src/common/statement.cpp b/src/common/statement.cpp index c4285852ced..5d7f7652a04 100644 --- a/src/common/statement.cpp +++ b/src/common/statement.cpp @@ -70,11 +70,11 @@ std::string Statement::GetQueryTypeString() const { return query_type_string_; } QueryType Statement::GetQueryType() const { return query_type_; } -void Statement::SetParamTypes(const std::vector& param_types) { +void Statement::SetParamTypes(const std::vector& param_types) { param_types_ = param_types; } -std::vector Statement::GetParamTypes() const { return param_types_; } +std::vector Statement::GetParamTypes() const { return param_types_; } void Statement::SetTupleDescriptor( const std::vector& tuple_descriptor) { diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 449bff3e373..1d642c051c6 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1449,6 +1449,8 @@ enum class SSLLevel { SSL_VERIIFY = 2, }; +using callback_func = std::function; + // Eigen/Matrix types used in brain // TODO(saatvik): Generalize Eigen utilities across all types typedef std::vector> matrix_t; diff --git a/src/include/common/statement.h b/src/include/common/statement.h index ff9b4620c87..7d4a3eabfdc 100644 --- a/src/include/common/statement.h +++ b/src/include/common/statement.h @@ -65,13 +65,13 @@ class Statement : public Printable { QueryType GetQueryType() const; - void SetParamTypes(const std::vector ¶m_types); + void SetParamTypes(const std::vector ¶m_types); - std::vector GetParamTypes() const; + std::vector GetParamTypes() const; void SetTupleDescriptor(const std::vector &tuple_descriptor); - void SetReferencedTables(const std::set table_ids); + void SetReferencedTables(std::set table_ids); const std::set GetReferencedTables() const; @@ -79,7 +79,7 @@ class Statement : public 
Printable { const std::shared_ptr &GetPlanTree() const; - std::unique_ptr const &GetStmtParseTreeList() { + const std::unique_ptr &GetStmtParseTreeList() { return sql_stmt_list_; } @@ -113,7 +113,7 @@ class Statement : public Printable { std::string query_type_string_; // format codes of the parameters - std::vector param_types_; + std::vector param_types_; // schema of result tuple std::vector tuple_descriptor_; diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 56dbc8fed58..67a5551f285 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -27,7 +27,7 @@ class name : public PostgresNetworkCommand { \ : PostgresNetworkCommand(std::move(in), flush) {} \ virtual Transition Exec(PostgresProtocolInterpreter &, \ PostgresPacketWriter &, \ - callback_func, size_t) override; \ + callback_func) override; \ } namespace peloton { @@ -39,8 +39,7 @@ class PostgresNetworkCommand { public: virtual Transition Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func callback, - size_t thread_id) = 0; + callback_func callback) = 0; inline bool FlushOnComplete() { return flush_on_complete_; } @@ -65,8 +64,7 @@ class PostgresNetworkCommand { } // Why are bind parameter and param values different? - void ReadParamValues(std::vector> &bind_parameters, + void ReadParamValues(std::vector> &bind_parameters, std::vector ¶m_values, const std::vector ¶m_types, const std::vector &formats) { @@ -100,7 +98,7 @@ class PostgresNetworkCommand { PostgresValueType type, int32_t len) { std::string val = in_->ReadString((size_t) len); - bind_parameters.push_back(std::make_pair(type::TypeId::VARCHAR, val)); + bind_parameters.emplace_back(type::TypeId::VARCHAR, val); param_values.push_back( PostgresValueTypeToPelotonValueType(type) == type::TypeId::VARCHAR ? 
type::ValueFactory::GetVarcharValue(val) @@ -117,8 +115,7 @@ class PostgresNetworkCommand { case PostgresValueType::TINYINT: { PELOTON_ASSERT(len == sizeof(int8_t)); auto val = in_->ReadValue(); - bind_parameters.push_back( - std::make_pair(type::TypeId::TINYINT, std::to_string(val))); + bind_parameters.emplace_back(type::TypeId::TINYINT, std::to_string(val)); param_values.push_back( type::ValueFactory::GetTinyIntValue(val).Copy()); break; diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index c335590983c..a1e661251ad 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -16,10 +16,6 @@ #include "network/postgres_network_commands.h" #include "traffic_cop/tcop.h" -#define MAKE_COMMAND(type) \ - std::static_pointer_cast( \ - std::make_shared(std::move(curr_input_packet_.buf_))) - namespace peloton { namespace network { @@ -27,7 +23,7 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { public: // TODO(Tianyu): Is this even the right thread id? It seems that all the // concurrency code is dependent on this number. - PostgresProtocolInterpreter(size_t thread_id) = default; + explicit PostgresProtocolInterpreter(size_t thread_id) = default; Transition Process(std::shared_ptr in, std::shared_ptr out, @@ -43,18 +39,11 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { inline tcop::ClientProcessState &ClientProcessState() { return state_; } - // TODO(Tianyu): What the hell does this thing do? - void CompleteCommand(const QueryType &query_type, int rows, PostgresPacketWriter &out); - - // TODO(Tianyu): Remove these later. Legacy shit code. + // TODO(Tianyu): WTF is this? 
void ExecQueryMessageGetResult(ResultType status); - ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); - bool HardcodedExecuteFilter(QueryType query_type); - NetworkProtocolType protocol_type_; + void ExecExecuteMessageGetResult(ResultType status); + std::vector result_format_; - bool skipped_stmt_ = false; - std::string skipped_query_string_; - QueryType skipped_query_type_; private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 98cdeeaa046..3d100375d5b 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -17,7 +17,6 @@ namespace peloton { namespace network { -using callback_func = std::function; class ProtocolInterpreter { public: diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 69c958ce377..b3abbacab6c 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,7 +11,8 @@ //===----------------------------------------------------------------------===// #pragma once -#include "network/connection_handle.h" +#include "executor/plan_executor.h" +#include "optimizer/abstract_optimizer.h" #include "parser/postgresparser.h" #include "parser/sql_statement.h" #include "common/statement_cache.h" @@ -27,7 +28,7 @@ using TcopTxnState = std::pair; // TODO(Tianyu): We can probably get rid of a bunch of fields from here struct ClientProcessState { size_t thread_id_; - bool is_queuing_; + bool is_queuing_ = false; std::string error_message_, db_name_ = DEFAULT_DB_NAME; std::vector param_values_; // This save currnet statement in the traffic cop @@ -38,26 +39,24 @@ struct ClientProcessState { std::vector result_; executor::ExecutionResult p_status_; StatementCache statement_cache_; - // Transaction Handling Wrapper ClientTxnHandle txn_handle_; - - // The current callback to be invoked after 
execution completes. - void (*task_callback_)(void *); - void *task_callback_arg_; }; // Execute a statement ResultType ExecuteStatement( ClientProcessState &state, - const std::vector &result_format, std::vector &result); + const std::vector &result_format, + std::vector &result, + callback_func callback); // Helper to handle txn-specifics for the plan-tree of a statement. void ExecuteHelper( ClientProcessState &state, std::vector &result, const std::vector &result_format, - concurrency::TransactionContext *txn); + concurrency::TransactionContext *txn, + callback_func callback); // Prepare a statement bool PrepareStatement( diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 7df873fa16b..c0c77082a85 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -27,8 +27,7 @@ namespace network { // project though, so I want to do the architectural refactor first. Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func, - size_t) { + callback_func) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); auto proto_version = in_->ReadValue(); LOG_INFO("protocol version: %d", proto_version); @@ -68,8 +67,7 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func callback, - size_t tid) { + callback_func callback) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); @@ -79,14 +77,11 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - state.param_values_ = std::vector(); - interpreter.result_format_ = 
std::vector(state.statement_->GetTupleDescriptor().size(), 0); - auto status = tcop::ExecuteStatement(state, - interpreter.result_format_, state.result_); + interpreter.result_format_, state.result_, callback); if (status == ResultType::QUEUING) return Transition::NEED_RESULT; @@ -97,8 +92,7 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func callback, - size_t tid) { + callback_func) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string statement_name = in_->ReadString(), query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); @@ -109,20 +103,16 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), query.c_str()); // Read param types - std::vector param_types = ReadParamTypes(); - state.statement_->SetParamTypes(param_types); + state.statement_->SetParamTypes(ReadParamTypes()); // Send Parse complete response - std::unique_ptr response(new OutputPacket()); - out.BeginPacket(NetworkMessageType::PARSE_COMPLETE); + out.BeginPacket(NetworkMessageType::PARSE_COMPLETE).EndPacket(); return Transition::PROCEED; } - } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index ee9586d150d..fccf1687769 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -2,7 +2,7 @@ // // Peloton // -// t=cop.h +// tcop.h // // Identification: src/include/traffic_cop/tcop.h // @@ -19,9 +19,9 @@ namespace peloton { namespace tcop { // Prepare a statement -bool tcop::PrepareStatement( - ClientProcessState &state, const std::string &query_string, - const std::string &statement_name) { +bool tcop::PrepareStatement(ClientProcessState 
&state, + const std::string &query_string, + const std::string &statement_name) { try { state.txn_handle_.ImplicitBegin(state.thread_id_); // parse the query @@ -30,9 +30,8 @@ bool tcop::PrepareStatement( // When the query is empty(such as ";" or ";;", still valid), // the parse tree is empty, parser will return nullptr. - if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) { + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) throw ParserException("Error Parsing SQL statement"); - } // TODO(Yuchen): Hack. We only process the first statement in the packet now. // We should store the rest of statements that will not be processed right @@ -44,26 +43,30 @@ bool tcop::PrepareStatement( QueryType query_type = StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); - std::shared_ptr statement = std::make_shared( - statement_name, query_type, query_string, std::move(sql_stmt_list)); + auto statement = std::make_shared(statement_name, + query_type, + query_string, + std::move(sql_stmt_list)); // Empty statement edge case - if (sql_stmt_list == nullptr || - sql_stmt_list->GetNumStatements() == 0) { - std::shared_ptr statement = - std::make_shared(statement_name, QueryType::QUERY_INVALID, - query_string, std::move(sql_stmt_list)); - state.statement_cache_.AddStatement(statement); + if (statement->GetStmtParseTreeList() == nullptr || + statement->GetStmtParseTreeList()->GetNumStatements() == 0) { + state.statement_cache_.AddStatement( + std::make_shared(statement_name, + QueryType::QUERY_INVALID, + query_string, + std::move(statement->PassStmtParseTreeList()))); return true; } // Run binder - auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, state.db_name_); + auto bind_node_visitor = binder::BindNodeVisitor(state.txn_handle_.GetTxn(), + state.db_name_); bind_node_visitor.BindNameToNode( statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = state.optimizer_->BuildPelotonPlanTree( - 
statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); + auto plan = state.optimizer_-> + BuildPelotonPlanTree(statement->GetStmtParseTreeList(), + state.txn_handle_.GetTxn()); statement->SetPlanTree(plan); // Get the tables that our plan references so that we know how to // invalidate it at a later point when the catalog changes @@ -80,20 +83,20 @@ bool tcop::PrepareStatement( state.statement_cache_.AddStatement(statement); - } // If the statement is invalid or not supported yet - catch (Exception &e) { - // TODO implicit end the txn here + } catch (Exception &e) { + // TODO(Tianyi) implicit end the txn here state.error_message_ = e.what(); return false; } - // TODO catch txn exception - + // TODO(Tianyi) catch txn exception return true; } ResultType tcop::ExecuteStatement( ClientProcessState &state, - const std::vector &result_format, std::vector &result) { + const std::vector &result_format, + std::vector &result, + const callback_func &callback) { LOG_TRACE("Execute Statement of name: %s", state.statement_->GetStatementName().c_str()); @@ -140,11 +143,10 @@ ResultType tcop::ExecuteStatement( state.statement_->SetNeedsReplan(true); } - ExecuteHelper(state, result, result_format, txn); + ExecuteHelper(state, result, result_format, txn, callback); return ResultType::QUEUING; } } - } catch (Exception &e) { state.error_message_ = e.what(); return ResultType::FAILURE; @@ -155,21 +157,25 @@ void tcop::ExecuteHelper( ClientProcessState &state, std::vector &result, const std::vector &result_format, - concurrency::TransactionContext *txn) { + concurrency::TransactionContext *txn, + const callback_func &callback) { auto plan = state.statement_->GetPlanTree(); auto params = state.param_values_; - auto on_complete = [&result, &state](executor::ExecutionResult p_status, - std::vector &&values) { + auto on_complete = [callback, &](executor::ExecutionResult p_status, + std::vector &&values) { state.p_status_ = p_status; state.error_message_ = 
std::move(p_status.m_error_message); result = std::move(values); - state.task_callback_(state.task_callback_arg_); + callback(); }; auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([plan, txn, &params, &result_format, on_complete] { - executor::PlanExecutor::ExecutePlan(plan, txn, params, result_format, + pool.SubmitTask([txn, on_complete, &] { + executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), + txn, + state.param_values_, + result_format, on_complete); }); diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 7bfffebb4c0..a8794d9cfb0 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -607,7 +607,6 @@ ResultType TrafficCop::ExecuteStatement( return ExecuteStatementGetResult(); } } - } catch (Exception &e) { error_message_ = e.what(); return ResultType::FAILURE; From 87536fc0e6e28d9ca9d2e82f18f4836c72d08f88 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Wed, 27 Jun 2018 04:46:59 -0400 Subject: [PATCH 23/48] Refactor function signature of ExecuteStatement() --- src/include/common/internal_types.h | 8 +++---- src/include/traffic_cop/tcop.h | 36 ++++++++++++++++++++++------- src/traffic_cop/tcop.cpp | 26 ++++++++------------- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 1d642c051c6..9b5547d0198 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -862,10 +862,10 @@ enum class ResultType { SUCCESS = 1, FAILURE = 2, ABORTED = 3, // aborted - NOOP = 4, // no op - UNKNOWN = 5, - QUEUING = 6, - TO_ABORT = 7, + NOOP = 4, // no op // TODO Remove this type + UNKNOWN = 5, // TODO Remove this type + QUEUING = 6, // TODO Remove this type + TO_ABORT = 7, // TODO Remove this type }; std::string ResultTypeToString(ResultType type); ResultType StringToResultType(const std::string &str); diff --git a/src/include/traffic_cop/tcop.h
b/src/include/traffic_cop/tcop.h index b3abbacab6c..4cb44bd261b 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -43,14 +43,39 @@ struct ClientProcessState { ClientTxnHandle txn_handle_; }; -// Execute a statement -ResultType ExecuteStatement( +/** + * Build and optimized a statement from raw query string and store it in state + * @param state Client state context + * @param query_string raw query in string format + * @param statement_name Name of the statement stored in the statement cache + * @return true if the statement is created + */ +bool PrepareStatement( + ClientProcessState &state, const std::string &query_string, + const std::string &statement_name = "unnamed"); + +/** + * Execute the statemnet attached in state + * @param state Client state context + * @param result_format expected result's format specified by the protocol + * @param result the vector to store the result being returned + * @param callback callback function to be invoke after the execution is finished. Only useful if the return type is false + * @return true if the execution is finished + */ +bool ExecuteStatement( ClientProcessState &state, const std::vector &result_format, std::vector &result, callback_func callback); -// Helper to handle txn-specifics for the plan-tree of a statement. +/** + * Helper function to submit the executable plan to worker pool + * @param state Client state context + * @param result_format expected result's format specified by the protocol + * @param result the vector to store the result being returned + * @param txn Transaction context + * @param callback callback function to be invoke after the execution is finished. 
Only useful if the return type is false + */ void ExecuteHelper( ClientProcessState &state, std::vector &result, @@ -58,11 +83,6 @@ void ExecuteHelper( concurrency::TransactionContext *txn, callback_func callback); -// Prepare a statement -bool PrepareStatement( - ClientProcessState &state, const std::string &query_string, - const std::string &statement_name = "unnamed"); - bool BindParamsForCachePlan( ClientProcessState &state, const std::vector> &, diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index fccf1687769..77def9bfb9b 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -92,7 +92,7 @@ bool tcop::PrepareStatement(ClientProcessState &state, return true; } -ResultType tcop::ExecuteStatement( +bool tcop::ExecuteStatement( ClientProcessState &state, const std::vector &result_format, std::vector &result, @@ -113,18 +113,18 @@ ResultType tcop::ExecuteStatement( switch (state.statement_->GetQueryType()) { case QueryType::QUERY_BEGIN: { state.txn_handle_.ExplicitBegin(state.thread_id_); - return ResultType::SUCCESS; + return true; } case QueryType::QUERY_COMMIT: { if (!state.txn_handle_.ExplicitCommit()) { - // TODO Check which result type we should return - return ResultType::FAILURE; + state.p_status_.m_result = ResultType::FAILURE; + //TODO set error message } - return ResultType::SUCCESS; + return true; } case QueryType::QUERY_ROLLBACK: { state.txn_handle_.ExplicitAbort(); - return ResultType::SUCCESS; + return true; } default: { // The statement may be out of date @@ -133,23 +133,20 @@ ResultType tcop::ExecuteStatement( if (state.statement_->GetNeedsReplan()) { // TODO(Tianyi) Move Statement Replan into Statement's method // to increase coherence - auto bind_node_visitor = binder::BindNodeVisitor( - txn, state.db_name_); - bind_node_visitor.BindNameToNode( - state.statement_->GetStmtParseTreeList()->GetStatement(0)); auto plan = state.optimizer_->BuildPelotonPlanTree( state.statement_->GetStmtParseTreeList(), txn); 
state.statement_->SetPlanTree(plan); - state.statement_->SetNeedsReplan(true); + state.statement_->SetNeedsReplan(false); } ExecuteHelper(state, result, result_format, txn, callback); - return ResultType::QUEUING; + return false; } } } catch (Exception &e) { + state.p_status_.m_result = ResultType::FAILURE; state.error_message_ = e.what(); - return ResultType::FAILURE; + return true; } } @@ -178,9 +175,6 @@ void tcop::ExecuteHelper( result_format, on_complete); }); - - LOG_TRACE("Check Tcop_txn_state Size After ExecuteHelper %lu", - tcop_txn_state_.size()); } } // namespace tcop From 67d8016bf3acdac97c8684083d22fe65e666fb48 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 27 Jun 2018 13:16:58 -0400 Subject: [PATCH 24/48] minor code clean up --- src/include/common/internal_types.h | 3 +- .../network/postgres_network_commands.h | 126 ++---------------- .../network/postgres_protocol_interpreter.h | 2 +- src/include/network/protocol_interpreter.h | 2 +- src/include/traffic_cop/tcop.h | 4 +- src/network/postgres_network_commands.cpp | 125 ++++++++++++++++- src/network/postgres_protocol_interpreter.cpp | 2 +- src/traffic_cop/tcop.cpp | 4 +- 8 files changed, 141 insertions(+), 127 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 9b5547d0198..70495dc44fd 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1449,7 +1449,8 @@ enum class SSLLevel { SSL_VERIIFY = 2, }; -using callback_func = std::function; +using CallbackFunc = std::function; +using BindParameter = std::pair; // Eigen/Matrix types used in brain // TODO(saatvik): Generalize Eigen utilities across all types diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 67a5551f285..0ae64066e10 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -27,7 +27,7 @@ class name : public 
PostgresNetworkCommand { \ : PostgresNetworkCommand(std::move(in), flush) {} \ virtual Transition Exec(PostgresProtocolInterpreter &, \ PostgresPacketWriter &, \ - callback_func) override; \ + CallbackFunc) override; \ } namespace peloton { @@ -39,7 +39,7 @@ class PostgresNetworkCommand { public: virtual Transition Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func callback) = 0; + CallbackFunc callback) = 0; inline bool FlushOnComplete() { return flush_on_complete_; } @@ -47,131 +47,25 @@ class PostgresNetworkCommand { explicit PostgresNetworkCommand(std::shared_ptr in, bool flush) : in_(std::move(in)), flush_on_complete_(flush) {} - inline std::vector ReadParamTypes() { - std::vector result; - auto num_params = in_->ReadValue(); - for (uint16_t i = 0; i < num_params; i++) - result.push_back(in_->ReadValue()); - return result; - } + std::vector ReadParamTypes(); - inline std::vector ReadParamFormats() { - std::vector result; - auto num_formats = in_->ReadValue(); - for (uint16_t i = 0; i < num_formats; i++) - result.push_back(in_->ReadValue()); - return result; - } + std::vector ReadParamFormats(); // Why are bind parameter and param values different? - void ReadParamValues(std::vector> &bind_parameters, + void ReadParamValues(std::vector &bind_parameters, std::vector &param_values, const std::vector &param_types, - const std::vector &formats) { - auto num_params = in_->ReadValue(); - for (uint16_t i = 0; i < num_params; i++) { - auto param_len = in_->ReadValue(); - if (param_len == -1) { - // NULL - auto peloton_type = PostgresValueTypeToPelotonValueType(param_types[i]); - bind_parameters.push_back(std::make_pair(peloton_type, - std::string(""))); - param_values.push_back(type::ValueFactory::GetNullValueByType( - peloton_type)); - } else { - (formats[i] == 0) - ?
ProcessTextParamValue(bind_parameters, - param_values, - param_types[i], - param_len) - : ProcessBinaryParamValue(bind_parameters, - param_values, - param_types[i], - param_len); - } - } - } + const std::vector &formats); - void ProcessTextParamValue(std::vector> &bind_parameters, + void ProcessTextParamValue(std::vector &bind_parameters, std::vector &param_values, PostgresValueType type, - int32_t len) { - std::string val = in_->ReadString((size_t) len); - bind_parameters.emplace_back(type::TypeId::VARCHAR, val); - param_values.push_back( - PostgresValueTypeToPelotonValueType(type) == type::TypeId::VARCHAR - ? type::ValueFactory::GetVarcharValue(val) - : type::ValueFactory::GetVarcharValue(val).CastAs( - PostgresValueTypeToPelotonValueType(type))); - } + int32_t len); - void ProcessBinaryParamValue(std::vector> &bind_parameters, + void ProcessBinaryParamValue(std::vector &bind_parameters, std::vector &param_values, PostgresValueType type, - int32_t len) { - switch (type) { - case PostgresValueType::TINYINT: { - PELOTON_ASSERT(len == sizeof(int8_t)); - auto val = in_->ReadValue(); - bind_parameters.emplace_back(type::TypeId::TINYINT, std::to_string(val)); - param_values.push_back( - type::ValueFactory::GetTinyIntValue(val).Copy()); - break; - } - case PostgresValueType::SMALLINT: { - PELOTON_ASSERT(len == sizeof(int16_t)); - auto int_val = in_->ReadValue(); - bind_parameters.push_back( - std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val))); - param_values.push_back( - type::ValueFactory::GetSmallIntValue(int_val).Copy()); - break; - } - case PostgresValueType::INTEGER: { - PELOTON_ASSERT(len == sizeof(int32_t)); - auto val = in_->ReadValue(); - bind_parameters.push_back( - std::make_pair(type::TypeId::INTEGER, std::to_string(val))); - param_values.push_back( - type::ValueFactory::GetIntegerValue(val).Copy()); - break; - } - case PostgresValueType::BIGINT: { - PELOTON_ASSERT(len == sizeof(int64_t)); - auto val = in_->ReadValue(); - bind_parameters.push_back( -
std::make_pair(type::TypeId::BIGINT, std::to_string(val))); - param_values.push_back( - type::ValueFactory::GetBigIntValue(val).Copy()); - break; - } - case PostgresValueType::DOUBLE: { - PELOTON_ASSERT(len == sizeof(double)); - auto val = in_->ReadValue(); - bind_parameters.push_back( - std::make_pair(type::TypeId::DECIMAL, std::to_string(val))); - param_values.push_back( - type::ValueFactory::GetDecimalValue(val).Copy()); - break; - } - case PostgresValueType::VARBINARY: { - auto val = in_->ReadString((size_t) len); - bind_parameters.push_back( - std::make_pair(type::TypeId::VARBINARY, val)); - param_values.push_back( - type::ValueFactory::GetVarbinaryValue( - reinterpret_cast(val.c_str()), - len, - true)); - break; - } - default: - throw NetworkProcessException("Binary Postgres protocol does not support data type " - + PostgresValueTypeToString(type)); - } - } + int32_t len); std::shared_ptr in_; private: diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index a1e661251ad..3e53318a362 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -27,7 +27,7 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { Transition Process(std::shared_ptr in, std::shared_ptr out, - callback_func callback) override; + CallbackFunc callback) override; inline void GetResult() override {} diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 3d100375d5b..3896f0384ef 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -24,7 +24,7 @@ class ProtocolInterpreter { virtual Transition Process(std::shared_ptr in, std::shared_ptr out, - callback_func callback) = 0; + CallbackFunc callback) = 0; // TODO(Tianyu): Do we really need this crap? 
virtual void GetResult() = 0; diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 4cb44bd261b..890da84def8 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -66,7 +66,7 @@ bool ExecuteStatement( ClientProcessState &state, const std::vector &result_format, std::vector &result, - callback_func callback); + CallbackFunc callback); /** * Helper function to submit the executable plan to worker pool @@ -81,7 +81,7 @@ void ExecuteHelper( std::vector &result, const std::vector &result_format, concurrency::TransactionContext *txn, - callback_func callback); + CallbackFunc callback); bool BindParamsForCachePlan( ClientProcessState &state, diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index c0c77082a85..b4694751c5d 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -25,9 +25,128 @@ namespace network { // A lot of the code here should really be moved to traffic cop, and a lot of // the code here can honestly just be deleted. This is going to be a larger // project though, so I want to do the architectural refactor first. 
+std::vector PostgresNetworkCommand::ReadParamTypes() { + std::vector result; + auto num_params = in_->ReadValue(); + for (uint16_t i = 0; i < num_params; i++) + result.push_back(in_->ReadValue()); + return result; +} + +std::vector PostgresNetworkCommand::ReadParamFormats() { + std::vector result; + auto num_formats = in_->ReadValue(); + for (uint16_t i = 0; i < num_formats; i++) + result.push_back(in_->ReadValue()); + return result; +} + +void PostgresNetworkCommand::ReadParamValues(std::vector &bind_parameters, + std::vector &param_values, + const std::vector &param_types, + const std::vector &formats) { + auto num_params = in_->ReadValue(); + for (uint16_t i = 0; i < num_params; i++) { + auto param_len = in_->ReadValue(); + if (param_len == -1) { + // NULL + auto peloton_type = PostgresValueTypeToPelotonValueType(param_types[i]); + bind_parameters.emplace_back(peloton_type, + std::string("")); + param_values.push_back(type::ValueFactory::GetNullValueByType( + peloton_type)); + } else { + (formats[i] == 0) + ? ProcessTextParamValue(bind_parameters, + param_values, + param_types[i], + param_len) + : ProcessBinaryParamValue(bind_parameters, + param_values, + param_types[i], + param_len); + } + } +} + +void PostgresNetworkCommand::ProcessTextParamValue(std::vector &bind_parameters, + std::vector &param_values, + PostgresValueType type, + int32_t len) { + std::string val = in_->ReadString((size_t) len); + bind_parameters.emplace_back(type::TypeId::VARCHAR, val); + param_values.push_back( + PostgresValueTypeToPelotonValueType(type) == type::TypeId::VARCHAR + ?
type::ValueFactory::GetVarcharValue(val) + : type::ValueFactory::GetVarcharValue(val).CastAs( + PostgresValueTypeToPelotonValueType(type))); +} + +void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector &bind_parameters, + std::vector &param_values, + PostgresValueType type, + int32_t len) { + switch (type) { + case PostgresValueType::TINYINT: { + PELOTON_ASSERT(len == sizeof(int8_t)); + auto val = in_->ReadValue(); + bind_parameters.emplace_back(type::TypeId::TINYINT, std::to_string(val)); + param_values.push_back( + type::ValueFactory::GetTinyIntValue(val).Copy()); + break; + } + case PostgresValueType::SMALLINT: { + PELOTON_ASSERT(len == sizeof(int16_t)); + auto int_val = in_->ReadValue(); + bind_parameters.emplace_back(type::TypeId::SMALLINT, std::to_string(int_val)); + param_values.push_back( + type::ValueFactory::GetSmallIntValue(int_val).Copy()); + break; + } + case PostgresValueType::INTEGER: { + PELOTON_ASSERT(len == sizeof(int32_t)); + auto val = in_->ReadValue(); + bind_parameters.emplace_back(type::TypeId::INTEGER, std::to_string(val)); + param_values.push_back( + type::ValueFactory::GetIntegerValue(val).Copy()); + break; + } + case PostgresValueType::BIGINT: { + PELOTON_ASSERT(len == sizeof(int64_t)); + auto val = in_->ReadValue(); + bind_parameters.emplace_back(type::TypeId::BIGINT, std::to_string(val)); + param_values.push_back( + type::ValueFactory::GetBigIntValue(val).Copy()); + break; + } + case PostgresValueType::DOUBLE: { + PELOTON_ASSERT(len == sizeof(double)); + auto val = in_->ReadValue(); + bind_parameters.emplace_back(type::TypeId::DECIMAL, std::to_string(val)); + param_values.push_back( + type::ValueFactory::GetDecimalValue(val).Copy()); + break; + } + case PostgresValueType::VARBINARY: { + auto val = in_->ReadString((size_t) len); + bind_parameters.emplace_back(type::TypeId::VARBINARY, val); + param_values.push_back( + type::ValueFactory::GetVarbinaryValue( + reinterpret_cast(val.c_str()), + len, + true)); + break; + } + default: +
throw NetworkProcessException("Binary Postgres protocol does not support data type " + + PostgresValueTypeToString(type)); + } +} + + Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func) { + CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); auto proto_version = in_->ReadValue(); LOG_INFO("protocol version: %d", proto_version); @@ -67,7 +186,7 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func callback) { + CallbackFunc callback) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); @@ -92,7 +211,7 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, - callback_func) { + CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string statement_name = in_->ReadString(), query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index 402fa644685..bc5be0e3257 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -21,7 +21,7 @@ namespace peloton { namespace network { Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, std::shared_ptr out, - callback_func callback) { + CallbackFunc callback) { if (!TryBuildPacket(in)) return Transition::NEED_READ; std::shared_ptr command = PacketToCommand(); curr_input_packet_.Clear(); diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index 77def9bfb9b..b1cf50cbd89 100644 --- a/src/traffic_cop/tcop.cpp +++ 
b/src/traffic_cop/tcop.cpp @@ -96,7 +96,7 @@ bool tcop::ExecuteStatement( ClientProcessState &state, const std::vector &result_format, std::vector &result, - const callback_func &callback) { + const CallbackFunc &callback) { LOG_TRACE("Execute Statement of name: %s", state.statement_->GetStatementName().c_str()); @@ -155,7 +155,7 @@ void tcop::ExecuteHelper( std::vector &result, const std::vector &result_format, concurrency::TransactionContext *txn, - const callback_func &callback) { + const CallbackFunc &callback) { auto plan = state.statement_->GetPlanTree(); auto params = state.param_values_; From aaa37d1d17aee6e14c87188c769aab3db253b19c Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Thu, 28 Jun 2018 13:50:04 -0400 Subject: [PATCH 25/48] Save work --- src/include/common/internal_types.h | 10 + src/include/common/portal.h | 10 +- src/include/network/connection_handle.h | 6 - src/include/network/network_io_wrappers.h | 3 - .../network/postgres_network_commands.h | 6 +- .../network/postgres_protocol_interpreter.h | 9 +- src/include/network/postgres_protocol_utils.h | 28 +- src/include/traffic_cop/tcop.h | 63 ++-- src/network/connection_handle.cpp | 28 -- src/network/postgres_network_commands.cpp | 355 ++++++++++++++++-- src/network/postgres_protocol_interpreter.cpp | 2 +- src/traffic_cop/tcop.cpp | 56 ++- 12 files changed, 413 insertions(+), 163 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 70495dc44fd..b3df6710502 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1452,6 +1452,16 @@ enum class SSLLevel { using CallbackFunc = std::function; using BindParameter = std::pair; +enum class PostgresDataFormat : int16_t { + TEXT = 0, + BINARY = 1 +}; + +enum class PostgresDescribeType : uchar { + PORTAL = 'P', + STATEMENT = 'S' +}; + // Eigen/Matrix types used in brain // TODO(saatvik): Generalize Eigen utilities across all types typedef std::vector> matrix_t; diff 
--git a/src/include/common/portal.h b/src/include/common/portal.h index f4017904583..a1b7cee6ca6 100644 --- a/src/include/common/portal.h +++ b/src/include/common/portal.h @@ -31,8 +31,7 @@ class Portal { Portal &operator=(Portal &&) = delete; Portal(const std::string &portal_name, std::shared_ptr statement, - std::vector bind_parameters, - std::shared_ptr param_stat); + std::vector bind_parameters); ~Portal(); @@ -40,10 +39,6 @@ class Portal { const std::vector &GetParameters() const; - inline std::shared_ptr GetParamStat() const { - return param_stat_; - } - // Portal name std::string portal_name_; @@ -52,9 +47,6 @@ class Portal { // Values bound to the statement of this portal std::vector bind_parameters_; - - // The serialized params for stats collection - std::shared_ptr param_stat_; }; } // namespace peloton diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 4d3a6f218b0..5c9f4e4ab5a 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -44,16 +44,10 @@ namespace peloton { namespace network { /** -<<<<<<< HEAD * A ConnectionHandle encapsulates all information we need to do IO about * a client connection for its entire duration. This includes a state machine * and the necessary libevent infrastructure for a handler to work on this * connection. -======= - * @brief A ConnectionHandle encapsulates all information about a client - * connection for its entire duration. This includes a state machine and the - * necessary libevent infrastructure for a handler to work on this connection. 
->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 */ class ConnectionHandle { public: diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 7ffc4e74122..661002c5979 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -17,10 +17,7 @@ #include #include "common/exception.h" #include "common/utility.h" -<<<<<<< HEAD #include "network/network_types.h" -======= ->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 #include "network/marshal.h" namespace peloton { diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 0ae64066e10..917a70b10c2 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -49,13 +49,13 @@ class PostgresNetworkCommand { std::vector ReadParamTypes(); - std::vector ReadParamFormats(); + std::vector ReadParamFormats(); // Why are bind parameter and param values different? void ReadParamValues(std::vector &bind_parameters, std::vector ¶m_values, const std::vector ¶m_types, - const std::vector &formats); + const std::vector &formats); void ProcessTextParamValue(std::vector &bind_parameters, std::vector ¶m_values, @@ -67,6 +67,8 @@ class PostgresNetworkCommand { PostgresValueType type, int32_t len); + std::vector ReadResultFormats(size_t tuple_size); + std::shared_ptr in_; private: bool flush_on_complete_; diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 3e53318a362..abd02d2abfa 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -39,11 +39,16 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { inline tcop::ClientProcessState &ClientProcessState() { return state_; } - // TODO(Tianyu): WTF is this? 
+ + // TODO(Tianyu): Remove these later for better responsibility assignment + void CompleteCommand(PostgresPacketWriter &out, const QueryType &query_type, int rows); void ExecQueryMessageGetResult(ResultType status); void ExecExecuteMessageGetResult(ResultType status); + ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); + bool HardcodedExecuteFilter(QueryType query_type); - std::vector result_format_; + NetworkProtocolType protocol_type_; + std::unordered_map> portals_; private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index ab26741420c..91032a6de34 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -60,14 +60,21 @@ class PostgresPacketWriter { } /** - * Write out a packet with a single byte (e.g. SSL_YES or SSL_NO). This is a - * special case since no size field is provided. + * Write out a packet with a single type. Some messages will be + * special cases since no size field is provided. (SSL_YES, SSL_NO) * @param type Type of message to write out */ - inline void WriteSingleBytePacket(NetworkMessageType type) { + inline void WriteSingleTypePacket(NetworkMessageType type) { // Make sure no active packet being constructed PELOTON_ASSERT(curr_packet_len_ == nullptr); - queue_.BufferWriteRawValue(type); + switch (type) { + case NetworkMessageType::SSL_YES: + case NetworkMessageType::SSL_NO: + queue_.BufferWriteRawValue(type); + break; + default: + BeginPacket(type).EndPacket(); + } } /** @@ -121,7 +128,7 @@ class PostgresPacketWriter { } /** - * Append an integer of specified length onto the write queue. (1, 2, 4, or 8 + * Append a value of specified length onto the write queue. (1, 2, 4, or 8 * bytes). It is assumed that these bytes need to be converted to network * byte ordering. * @tparam T type of value to read off. 
Has to be size 1, 2, 4, or 8. @@ -129,7 +136,7 @@ class PostgresPacketWriter { * @return self-reference for chaining */ template - inline PostgresPacketWriter &AppendInt(T val) { + inline PostgresPacketWriter &AppendValue(T val) { // We only want to allow for certain type sizes to be used // After the static assert, the compiler should be smart enough to throw // away the other cases and only leave the relevant return statement. @@ -194,6 +201,15 @@ class PostgresPacketWriter { .EndPacket(); } + inline void WriteTupleDescriptor(const std::vector &tuple_descriptor) { + if (tuple_descriptor.empty()) return; + BeginPacket(NetworkMessageType::ROW_DESCRIPTION) + .AppendValue(tuple_descriptor.size()); + for (auto &col : tuple_descriptor) { + AppendString(std::get<0>(col))() + } + } + /** * End the packet. A packet write must be in progress and said write is not * well-formed until this method is called. diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 890da84def8..e4fa6392745 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include #include "executor/plan_executor.h" #include "optimizer/abstract_optimizer.h" #include "parser/postgresparser.h" @@ -27,7 +28,7 @@ using TcopTxnState = std::pair; // TODO(Tianyu): Probably need a better name // TODO(Tianyu): We can probably get rid of a bunch of fields from here struct ClientProcessState { - size_t thread_id_; + size_t thread_id_ = 0; bool is_queuing_ = false; std::string error_message_, db_name_ = DEFAULT_DB_NAME; std::vector param_values_; @@ -36,57 +37,47 @@ struct ClientProcessState { // The optimizer used for this connection std::unique_ptr optimizer_; // flag of single statement txn + bool single_statement_txn_ = false; + std::vector result_format_; + // flag of single statement txn std::vector result_; - executor::ExecutionResult p_status_; + 
std::stack tcop_txn_state_; + NetworkTransactionStateType txn_state_ = NetworkTransactionStateType::INVALID; + bool skipped_stmt_ = false; + std::string skipped_query_string_; + QueryType skipped_query_type_ = QueryType::QUERY_INVALID; StatementCache statement_cache_; - // Transaction Handling Wrapper - ClientTxnHandle txn_handle_; }; -/** - * Build and optimized a statement from raw query string and store it in state - * @param state Client state context - * @param query_string raw query in string format - * @param statement_name Name of the statement stored in the statement cache - * @return true if the statement is created - */ -bool PrepareStatement( - ClientProcessState &state, const std::string &query_string, - const std::string &statement_name = "unnamed"); +inline std::unique_ptr ParseQuery(const std::string &query_string) { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); + // When the query is empty(such as ";" or ";;", still valid), + // the parse tree is empty, parser will return nullptr. + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) + throw ParserException("Error Parsing SQL statement"); + return sql_stmt_list; +} + +std::shared_ptr PrepareStatement(ClientProcessState &state, + const std::string &statement_name, + const std::string &query_string, + std::unique_ptr &&sql_stmt_list); -/** - * Execute the statemnet attached in state - * @param state Client state context - * @param result_format expected result's format specified by the protocol - * @param result the vector to store the result being returned - * @param callback callback function to be invoke after the execution is finished. 
Only useful if the return type is false - * @return true if the execution is finished - */ -bool ExecuteStatement( +ResultType ExecuteStatement( ClientProcessState &state, - const std::vector &result_format, - std::vector &result, + bool unnamed, CallbackFunc callback); -/** - * Helper function to submit the executable plan to worker pool - * @param state Client state context - * @param result_format expected result's format specified by the protocol - * @param result the vector to store the result being returned - * @param txn Transaction context - * @param callback callback function to be invoke after the execution is finished. Only useful if the return type is false - */ void ExecuteHelper( ClientProcessState &state, std::vector &result, - const std::vector &result_format, concurrency::TransactionContext *txn, CallbackFunc callback); bool BindParamsForCachePlan( ClientProcessState &state, - const std::vector> &, - size_t thread_id = 0); + const std::vector> &); std::vector GenerateTupleDescriptor( parser::SQLStatement *select_stmt); diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index ac57bb93f20..3f31ec4fea0 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -186,34 +186,6 @@ Transition ConnectionHandle::TrySslHandshake() { io_wrapper_); } -Transition ConnectionHandle::TryCloseConnection() { - LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); - // TODO(Tianyu): Handle close failure - Transition close = io_wrapper_->Close(); - if (close != Transition::PROCEED) return close; - // Remove listening event - // Only after the connection is closed is it safe to remove events, - // after this point no object in the system has reference to this - // connection handle and we will need to destruct and exit. 
- conn_handler_->UnregisterEvent(network_event_); - conn_handler_->UnregisterEvent(workpool_event_); - // This object is essentially managed by libevent (which unfortunately does - // not accept shared_ptrs.) and thus as we shut down we need to manually - // deallocate this object. - delete this; - return Transition::NONE; -} - -Transition ConnectionHandle::TrySslHandshake() { - // Flush out all the response first - if (HasResponse()) { - auto write_ret = TryWrite(); - if (write_ret != Transition::PROCEED) return write_ret; - } - return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake( - io_wrapper_); -} - Transition ConnectionHandle::TryCloseConnection() { LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); // TODO(Tianyu): Handle close failure diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index b4694751c5d..f4b0cafe70e 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -14,6 +14,7 @@ #include "network/peloton_server.h" #include "network/postgres_network_commands.h" #include "traffic_cop/tcop.h" +#include "settings/settings_manager.h" #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) @@ -25,7 +26,7 @@ namespace network { // A lot of the code here should really be moved to traffic cop, and a lot of // the code here can honestly just be deleted. This is going to be a larger // project though, so I want to do the architectural refactor first. 
-std::vector PostgresNetworkCommand::ReadParamTypes() { +std::vector PostgresNetworkCommand::ReadParamTypes() { std::vector result; auto num_params = in_->ReadValue(); for (uint16_t i = 0; i < num_params; i++) @@ -33,18 +34,19 @@ std::vector PostgresNetworkCommand::ReadParamTypes() { return result; } -std::vector PostgresNetworkCommand::ReadParamFormats() { - std::vector result; +std::vector PostgresNetworkCommand::ReadParamFormats() { + std::vector result; auto num_formats = in_->ReadValue(); for (uint16_t i = 0; i < num_formats; i++) - result.push_back(in_->ReadValue()); + result.push_back(in_->ReadValue()); return result; } void PostgresNetworkCommand::ReadParamValues(std::vector &bind_parameters, std::vector ¶m_values, const std::vector ¶m_types, - const std::vector &formats) { + const std::vector< + PostgresDataFormat> &formats) { auto num_params = in_->ReadValue(); for (uint16_t i = 0; i < num_params; i++) { auto param_len = in_->ReadValue(); @@ -55,24 +57,29 @@ void PostgresNetworkCommand::ReadParamValues(std::vector &bind_pa std::string("")); param_values.push_back(type::ValueFactory::GetNullValueByType( peloton_type)); - } else { - (formats[i] == 0) - ? 
ProcessTextParamValue(bind_parameters, - param_values, - param_types[i], - param_len) - : ProcessBinaryParamValue(bind_parameters, + } else + switch (formats[i]) { + case PostgresDataFormat::TEXT: + ProcessTextParamValue(bind_parameters, param_values, param_types[i], param_len); - } + break; + case PostgresDataFormat::BINARY: + ProcessBinaryParamValue(bind_parameters, + param_values, + param_types[i], + param_len); + break; + default:throw NetworkProcessException("Unexpected format code"); + } } } void PostgresNetworkCommand::ProcessTextParamValue(std::vector &bind_parameters, std::vector ¶m_values, PostgresValueType type, - int32_t len) { + int32_t len) { std::string val = in_->ReadString((size_t) len); bind_parameters.emplace_back(type::TypeId::VARCHAR, val); param_values.push_back( @@ -85,7 +92,7 @@ void PostgresNetworkCommand::ProcessTextParamValue(std::vector &b void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector &bind_parameters, std::vector ¶m_values, PostgresValueType type, - int32_t len) { + int32_t len) { switch (type) { case PostgresValueType::TINYINT: { PELOTON_ASSERT(len == sizeof(int8_t)); @@ -98,7 +105,8 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector case PostgresValueType::SMALLINT: { PELOTON_ASSERT(len == sizeof(int16_t)); auto int_val = in_->ReadValue(); - bind_parameters.emplace_back(type::TypeId::SMALLINT, std::to_string(int_val)); + bind_parameters.emplace_back(type::TypeId::SMALLINT, + std::to_string(int_val)); param_values.push_back( type::ValueFactory::GetSmallIntValue(int_val).Copy()); break; @@ -138,11 +146,28 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector break; } default: - throw NetworkProcessException("Binary Postgres protocol does not support data type " - + PostgresValueTypeToString(type)); + throw NetworkProcessException( + "Binary Postgres protocol does not support data type " + + PostgresValueTypeToString(type)); } } +std::vector 
PostgresNetworkCommand::ReadResultFormats(size_t tuple_size) { + auto num_format_codes = in_->ReadValue(); + switch (num_format_codes) { + case 0: + // Default text mode + return std::vector(tuple_size, + PostgresDataFormat::TEXT); + case 1: + return std::vector(tuple_size, + in_->ReadValue()); + default:std::vector result; + for (auto i = 0; i < num_format_codes; i++) + result.push_back(in_->ReadValue()); + return result; + } +} Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, @@ -154,10 +179,10 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, if (proto_version == SSL_MESSAGE_VERNO) { // TODO(Tianyu): Should this be moved from PelotonServer into settings? if (PelotonServer::GetSSLLevel() == SSLLevel::SSL_DISABLE) { - out.WriteSingleBytePacket(NetworkMessageType::SSL_NO); + out.WriteSingleTypePacket(NetworkMessageType::SSL_NO); return Transition::PROCEED; } - out.WriteSingleBytePacket(NetworkMessageType::SSL_YES); + out.WriteSingleTypePacket(NetworkMessageType::SSL_YES); return Transition::NEED_SSL_HANDSHAKE; } @@ -190,23 +215,130 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); - if(!tcop::PrepareStatement(state, query)) { + std::unique_ptr sql_stmt_list; + try { + sql_stmt_list = tcop::ParseQuery(query); + } catch (Exception &e) { + tcop::ProcessInvalidStatement(state); out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - state.error_message_}}); + e.what()}}); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - state.param_values_ = std::vector(); - interpreter.result_format_ = - std::vector(state.statement_->GetTupleDescriptor().size(), 0); - auto status = tcop::ExecuteStatement(state, - interpreter.result_format_, state.result_, callback); - if 
(status == ResultType::QUEUING) return Transition::NEED_RESULT; + if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) { + out.WriteEmptyQueryResponse(); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } - interpreter.ExecQueryMessageGetResult(status); + // TODO(Yuchen): Hack. We only process the first statement in the packet now. + // We should store the rest of statements that will not be processed right + // away. For the hack, in most cases, it works. Because for example in psql, + // one packet contains only one query. But when using the pipeline mode in + // Libpqxx, it sends multiple query in one packet. In this case, it's + // incorrect. + auto sql_stmt = sql_stmt_list->PassOutStatement(0); - return Transition::PROCEED; + QueryType query_type = + StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; + + switch (query_type) { + case QueryType::QUERY_PREPARE: { + std::shared_ptr statement(nullptr); + auto prep_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = prep_stmt->name; + statement = tcop::PrepareStatement(state, + stmt_name, + query, + std::move(prep_stmt->query)); + if (statement == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.statement_cache_.AddStatement(statement); + interpreter.CompleteCommand(out, query_type, 0); + // PAVLO: 2017-01-15 + // There used to be code here that would invoke this method passing + // in NetworkMessageType::READY_FOR_QUERY as the argument. But when + // I switched to strong types, this obviously doesn't work. So I + // switched it to be NetworkTransactionStateType::IDLE. I don't know + // we just don't always send back the internal txn state? 
+ out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXECUTE: { + std::vector param_values; + auto + *exec_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = exec_stmt->name; + + auto cached_statement = state.statement_cache_.GetStatement(stmt_name); + if (cached_statement != nullptr) + state.statement_ = cached_statement; + // Did not find statement with same name + else { + std::string error_message = "The prepared statement does not exist"; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + "The prepared statement does not exist"}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); + + if (!tcop::BindParamsForCachePlan(state, exec_stmt->parameters)) { + tcop::ProcessInvalidStatement(state); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + auto status = tcop::ExecuteStatement(state, + false, + callback); + if (state.is_queuing_) return Transition::NEED_RESULT; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXPLAIN: { + auto status = interpreter.ExecQueryExplain(query, + dynamic_cast(*sql_stmt)); + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + default: { + std::string stmt_name = "unamed"; + std::unique_ptr unnamed_sql_stmt_list( + new parser::SQLStatementList()); + unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); + state.statement_ = tcop::PrepareStatement(state, + stmt_name, + query, + std::move(unnamed_sql_stmt_list)); + if (state.statement_ == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + 
out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.param_values_ = std::vector(); + state.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); + auto status = + tcop::ExecuteStatement(state, false, callback); + if (state.is_queuing_) + return Transition::NEED_RESULT; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + } } Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, @@ -214,24 +346,173 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string statement_name = in_->ReadString(), query = in_->ReadString(); - LOG_TRACE("Execute query: %s", query.c_str()); - if(!tcop::PrepareStatement(state, query, statement_name)) { + // In JDBC, one query starts with parsing stage. + // Reset skipped_stmt_ to false for the new query. + state.skipped_stmt_ = false; + std::unique_ptr sql_stmt_list; + QueryType query_type = QueryType::QUERY_OTHER; + try { + sql_stmt_list = tcop::ParseQuery(query); + } catch (Exception &e) { + tcop::ProcessInvalidStatement(state); + state.skipped_stmt_ = true; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + e.what()}}); + return Transition::PROCEED; + } + + // If the query is not supported yet, + // we will skip the rest commands (B,E,..) 
for this query + // For empty query, we still want to get it constructed + // TODO (Tianyi) Consider handle more statement + bool empty = (sql_stmt_list == nullptr || + sql_stmt_list->GetNumStatements() == 0); + if (!empty) { + parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); + query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); + } + bool skip = !interpreter.HardcodedExecuteFilter(query_type); + if (skip) { + state.skipped_stmt_ = true; + state.skipped_query_string_ = query; + state.skipped_query_type_ = query_type; + out.WriteSingleTypePacket(NetworkMessageType::PARSE_COMPLETE); + return Transition::PROCEED; + } + + auto statement = tcop::PrepareStatement(state, + statement_name, + query, + std::move(sql_stmt_list)); + if (statement == nullptr) { + tcop::ProcessInvalidStatement(state); + state.skipped_stmt_ = true; out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } + LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), query.c_str()); - // Read param types - state.statement_->SetParamTypes(ReadParamTypes()); + // Cache the received query + statement->SetParamTypes(ReadParamTypes()); + + // Cache the statement + state.statement_cache_.AddStatement(statement); + out.WriteSingleTypePacket(NetworkMessageType::PARSE_COMPLETE); + return Transition::PROCEED; +} + +Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + std::string portal_name = in_->ReadString(), + statement_name = in_->ReadString(); + if (state.skipped_stmt_) { + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; + } + + std::vector formats = ReadParamFormats(); + + // Get statement info generated in PARSE message + std::shared_ptr + statement = 
state.statement_cache_.GetStatement(statement_name); + if (statement == nullptr) { + std::string error_message = statement_name.empty() + ? "Invalid unnamed statement" + : "The prepared statement does not exist"; + LOG_ERROR("%s", error_message.c_str()); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + error_message}}); + return Transition::PROCEED; + } + + // Empty query + if (statement->GetQueryType() == QueryType::QUERY_INVALID) { + out.BeginPacket(NetworkMessageType::BIND_COMMAND).EndPacket(); + // TODO(Tianyi) This is a hack to respond correct describe message + // as well as execute message + state.skipped_stmt_ = true; + state.skipped_query_string_ = ""; + return Transition::PROCEED; + } + + const auto &query_string = statement->GetQueryString(); + const auto &query_type = statement->GetQueryType(); + + // check if the loaded statement needs to be skipped + state.skipped_stmt_ = false; + if (!interpreter.HardcodedExecuteFilter(query_type)) { + state.skipped_stmt_ = true; + state.skipped_query_string_ = query_string; + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; + } + + // Group the parameter types and the parameters in this vector + std::vector> bind_parameters; + std::vector param_values; - // Send Parse complete response - out.BeginPacket(NetworkMessageType::PARSE_COMPLETE).EndPacket(); + auto param_types = statement->GetParamTypes(); + ReadParamValues(bind_parameters, param_values, param_types, formats); + state.result_format_ = + ReadResultFormats(statement->GetTupleDescriptor().size()); + + if (!param_values.empty()) + statement->GetPlanTree()->SetParameterValues(¶m_values); + // Instead of tree traversal, we should put param values in the + // executor context. 
+ + + + interpreter.portals_[portal_name] = + std::make_shared(portal_name, statement, std::move(param_values)); + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; +} + +Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + if (state.skipped_stmt_) { + // send 'no-data' + out.WriteSingleTypePacket(NetworkMessageType::NO_DATA_RESPONSE); + return Transition::PROCEED; + } + + auto mode = in_->ReadValue(); + std::string portal_name = in_->ReadString(); + switch (mode) { + case PostgresDescribeType::PORTAL: + LOG_TRACE("Describe a portal"); + auto portal_itr = interpreter.portals_.find(portal_name); + // TODO: error handling here + // Ahmed: This is causing the continuously running thread + // Changed the function signature to return boolean + // when false is returned, the connection is closed + if (portal_itr == interpreter.portals_.end()) { + LOG_ERROR("Did not find portal : %s", portal_name.c_str()); + // TODO(Tianyu): Why is this thing swallowing error? + out.WriteTupleDescriptor(std::vector()); + } else + out.WriteTupleDescriptor(portal_itr->second->GetStatement()->GetTupleDescriptor()); + break; + case PostgresDescribeType::STATEMENT: + // TODO(Tianyu): Do we not support this or something? 
+ LOG_TRACE("Describe a prepared statement"); + break; + default: + throw NetworkProcessException("Unexpected Describe type"); + } return Transition::PROCEED; } +Transition:: + } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index bc5be0e3257..85abf2fe9aa 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -27,7 +27,7 @@ Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, curr_input_packet_.Clear(); PostgresPacketWriter writer(*out); if (command->FlushOnComplete()) out->ForceFlush(); - return command->Exec(*this, writer, callback, thread_id_); + return command->Exec(*this, writer, callback); } bool PostgresProtocolInterpreter::TryBuildPacket(std::shared_ptr &in) { diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index b1cf50cbd89..ad6e0947385 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -92,11 +92,10 @@ bool tcop::PrepareStatement(ClientProcessState &state, return true; } -bool tcop::ExecuteStatement( - ClientProcessState &state, - const std::vector &result_format, - std::vector &result, - const CallbackFunc &callback) { +bool tcop::ExecuteStatement(ClientProcessState &state, + const std::vector &result_format, + std::vector &result, + const CallbackFunc &callback) { LOG_TRACE("Execute Statement of name: %s", state.statement_->GetStatementName().c_str()); @@ -139,7 +138,25 @@ bool tcop::ExecuteStatement( state.statement_->SetNeedsReplan(false); } - ExecuteHelper(state, result, result_format, txn, callback); + auto plan = state.statement_->GetPlanTree(); + auto params = state.param_values_; + + auto on_complete = [callback, &](executor::ExecutionResult p_status, + std::vector &&values) { + state.p_status_ = p_status; + state.error_message_ = std::move(p_status.m_error_message); + result = std::move(values); + 
callback(); + }; + + auto &pool = threadpool::MonoQueuePool::GetInstance(); + pool.SubmitTask([txn, on_complete, &] { + executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), + txn, + state.param_values_, + result_format, + on_complete); + }); return false; } } @@ -150,32 +167,5 @@ bool tcop::ExecuteStatement( } } -void tcop::ExecuteHelper( - ClientProcessState &state, - std::vector &result, - const std::vector &result_format, - concurrency::TransactionContext *txn, - const CallbackFunc &callback) { - auto plan = state.statement_->GetPlanTree(); - auto params = state.param_values_; - - auto on_complete = [callback, &](executor::ExecutionResult p_status, - std::vector &&values) { - state.p_status_ = p_status; - state.error_message_ = std::move(p_status.m_error_message); - result = std::move(values); - callback(); - }; - - auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([txn, on_complete, &] { - executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), - txn, - state.param_values_, - result_format, - on_complete); - }); -} - } // namespace tcop } // namespace peloton \ No newline at end of file From 3832d742d064b2659b0e0f3444fd14f651fe414a Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Thu, 28 Jun 2018 18:13:01 -0400 Subject: [PATCH 26/48] Finish porting protocol handler to new code --- src/include/common/internal_types.h | 10 + src/include/common/portal.h | 10 +- src/include/network/connection_handle.h | 6 - src/include/network/network_io_wrappers.h | 3 - .../network/postgres_network_commands.h | 7 +- .../network/postgres_protocol_interpreter.h | 9 +- src/include/network/postgres_protocol_utils.h | 28 +- src/include/traffic_cop/tcop.h | 63 ++- src/network/connection_handle.cpp | 28 -- src/network/postgres_network_commands.cpp | 431 ++++++++++++++++-- src/network/postgres_protocol_interpreter.cpp | 4 +- src/traffic_cop/tcop.cpp | 56 +-- 12 files changed, 489 insertions(+), 166 deletions(-) diff --git 
a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 70495dc44fd..f0d46447e12 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1452,6 +1452,16 @@ enum class SSLLevel { using CallbackFunc = std::function; using BindParameter = std::pair; +enum class PostgresDataFormat : int16_t { + TEXT = 0, + BINARY = 1 +}; + +enum class PostgresNetworkObjectType : uchar { + PORTAL = 'P', + STATEMENT = 'S' +}; + // Eigen/Matrix types used in brain // TODO(saatvik): Generalize Eigen utilities across all types typedef std::vector> matrix_t; diff --git a/src/include/common/portal.h b/src/include/common/portal.h index f4017904583..a1b7cee6ca6 100644 --- a/src/include/common/portal.h +++ b/src/include/common/portal.h @@ -31,8 +31,7 @@ class Portal { Portal &operator=(Portal &&) = delete; Portal(const std::string &portal_name, std::shared_ptr statement, - std::vector bind_parameters, - std::shared_ptr param_stat); + std::vector bind_parameters); ~Portal(); @@ -40,10 +39,6 @@ class Portal { const std::vector &GetParameters() const; - inline std::shared_ptr GetParamStat() const { - return param_stat_; - } - // Portal name std::string portal_name_; @@ -52,9 +47,6 @@ class Portal { // Values bound to the statement of this portal std::vector bind_parameters_; - - // The serialized params for stats collection - std::shared_ptr param_stat_; }; } // namespace peloton diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 4d3a6f218b0..5c9f4e4ab5a 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -44,16 +44,10 @@ namespace peloton { namespace network { /** -<<<<<<< HEAD * A ConnectionHandle encapsulates all information we need to do IO about * a client connection for its entire duration. This includes a state machine * and the necessary libevent infrastructure for a handler to work on this * connection. 
-======= - * @brief A ConnectionHandle encapsulates all information about a client - * connection for its entire duration. This includes a state machine and the - * necessary libevent infrastructure for a handler to work on this connection. ->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 */ class ConnectionHandle { public: diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 7ffc4e74122..661002c5979 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -17,10 +17,7 @@ #include #include "common/exception.h" #include "common/utility.h" -<<<<<<< HEAD #include "network/network_types.h" -======= ->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 #include "network/marshal.h" namespace peloton { diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 0ae64066e10..1274db96cd8 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -49,13 +49,13 @@ class PostgresNetworkCommand { std::vector ReadParamTypes(); - std::vector ReadParamFormats(); + std::vector ReadParamFormats(); // Why are bind parameter and param values different? 
void ReadParamValues(std::vector &bind_parameters, std::vector ¶m_values, const std::vector ¶m_types, - const std::vector &formats); + const std::vector &formats); void ProcessTextParamValue(std::vector &bind_parameters, std::vector ¶m_values, @@ -67,6 +67,8 @@ class PostgresNetworkCommand { PostgresValueType type, int32_t len); + std::vector ReadResultFormats(size_t tuple_size); + std::shared_ptr in_; private: bool flush_on_complete_; @@ -81,7 +83,6 @@ DEFINE_COMMAND(ExecuteCommand, false); DEFINE_COMMAND(SyncCommand, true); DEFINE_COMMAND(CloseCommand, false); DEFINE_COMMAND(TerminateCommand, true); -DEFINE_COMMAND(NullCommand, true); } // namespace network } // namespace peloton diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 3e53318a362..abd02d2abfa 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -39,11 +39,16 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { inline tcop::ClientProcessState &ClientProcessState() { return state_; } - // TODO(Tianyu): WTF is this? 
+ + // TODO(Tianyu): Remove these later for better responsibility assignment + void CompleteCommand(PostgresPacketWriter &out, const QueryType &query_type, int rows); void ExecQueryMessageGetResult(ResultType status); void ExecExecuteMessageGetResult(ResultType status); + ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); + bool HardcodedExecuteFilter(QueryType query_type); - std::vector result_format_; + NetworkProtocolType protocol_type_; + std::unordered_map> portals_; private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index ab26741420c..91032a6de34 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -60,14 +60,21 @@ class PostgresPacketWriter { } /** - * Write out a packet with a single byte (e.g. SSL_YES or SSL_NO). This is a - * special case since no size field is provided. + * Write out a packet with a single type. Some messages will be + * special cases since no size field is provided. (SSL_YES, SSL_NO) * @param type Type of message to write out */ - inline void WriteSingleBytePacket(NetworkMessageType type) { + inline void WriteSingleTypePacket(NetworkMessageType type) { // Make sure no active packet being constructed PELOTON_ASSERT(curr_packet_len_ == nullptr); - queue_.BufferWriteRawValue(type); + switch (type) { + case NetworkMessageType::SSL_YES: + case NetworkMessageType::SSL_NO: + queue_.BufferWriteRawValue(type); + break; + default: + BeginPacket(type).EndPacket(); + } } /** @@ -121,7 +128,7 @@ class PostgresPacketWriter { } /** - * Append an integer of specified length onto the write queue. (1, 2, 4, or 8 + * Append a value of specified length onto the write queue. (1, 2, 4, or 8 * bytes). It is assumed that these bytes need to be converted to network * byte ordering. * @tparam T type of value to read off. 
Has to be size 1, 2, 4, or 8. @@ -129,7 +136,7 @@ class PostgresPacketWriter { * @return self-reference for chaining */ template - inline PostgresPacketWriter &AppendInt(T val) { + inline PostgresPacketWriter &AppendValue(T val) { // We only want to allow for certain type sizes to be used // After the static assert, the compiler should be smart enough to throw // away the other cases and only leave the relevant return statement. @@ -194,6 +201,15 @@ class PostgresPacketWriter { .EndPacket(); } + inline void WriteTupleDescriptor(const std::vector &tuple_descriptor) { + if (tuple_descriptor.empty()) return; + BeginPacket(NetworkMessageType::ROW_DESCRIPTION) + .AppendValue(tuple_descriptor.size()); + for (auto &col : tuple_descriptor) { + AppendString(std::get<0>(col))() + } + } + /** * End the packet. A packet write must be in progress and said write is not * well-formed until this method is called. diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 890da84def8..ed025e8e0a0 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include #include "executor/plan_executor.h" #include "optimizer/abstract_optimizer.h" #include "parser/postgresparser.h" @@ -27,7 +28,7 @@ using TcopTxnState = std::pair; // TODO(Tianyu): Probably need a better name // TODO(Tianyu): We can probably get rid of a bunch of fields from here struct ClientProcessState { - size_t thread_id_; + size_t thread_id_ = 0; bool is_queuing_ = false; std::string error_message_, db_name_ = DEFAULT_DB_NAME; std::vector param_values_; @@ -36,57 +37,47 @@ struct ClientProcessState { // The optimizer used for this connection std::unique_ptr optimizer_; // flag of single statement txn + bool single_statement_txn_ = false; + std::vector result_format_; + // flag of single statement txn std::vector result_; - executor::ExecutionResult p_status_; + 
std::stack tcop_txn_state_; + NetworkTransactionStateType txn_state_ = NetworkTransactionStateType::INVALID; + bool skipped_stmt_ = false; + std::string skipped_query_string_; + QueryType skipped_query_type_ = QueryType::QUERY_INVALID; StatementCache statement_cache_; - // Transaction Handling Wrapper - ClientTxnHandle txn_handle_; + int rows_affected_; }; -/** - * Build and optimized a statement from raw query string and store it in state - * @param state Client state context - * @param query_string raw query in string format - * @param statement_name Name of the statement stored in the statement cache - * @return true if the statement is created - */ -bool PrepareStatement( - ClientProcessState &state, const std::string &query_string, - const std::string &statement_name = "unnamed"); +inline std::unique_ptr ParseQuery(const std::string &query_string) { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); + // When the query is empty(such as ";" or ";;", still valid), + // the parse tree is empty, parser will return nullptr. + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) + throw ParserException("Error Parsing SQL statement"); + return sql_stmt_list; +} + +std::shared_ptr PrepareStatement(ClientProcessState &state, + const std::string &statement_name, + const std::string &query_string, + std::unique_ptr &&sql_stmt_list); -/** - * Execute the statemnet attached in state - * @param state Client state context - * @param result_format expected result's format specified by the protocol - * @param result the vector to store the result being returned - * @param callback callback function to be invoke after the execution is finished. 
Only useful if the return type is false - * @return true if the execution is finished - */ -bool ExecuteStatement( +ResultType ExecuteStatement( ClientProcessState &state, - const std::vector &result_format, - std::vector &result, CallbackFunc callback); -/** - * Helper function to submit the executable plan to worker pool - * @param state Client state context - * @param result_format expected result's format specified by the protocol - * @param result the vector to store the result being returned - * @param txn Transaction context - * @param callback callback function to be invoke after the execution is finished. Only useful if the return type is false - */ void ExecuteHelper( ClientProcessState &state, std::vector &result, - const std::vector &result_format, concurrency::TransactionContext *txn, CallbackFunc callback); bool BindParamsForCachePlan( ClientProcessState &state, - const std::vector> &, - size_t thread_id = 0); + const std::vector> &); std::vector GenerateTupleDescriptor( parser::SQLStatement *select_stmt); diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index ac57bb93f20..3f31ec4fea0 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -186,34 +186,6 @@ Transition ConnectionHandle::TrySslHandshake() { io_wrapper_); } -Transition ConnectionHandle::TryCloseConnection() { - LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); - // TODO(Tianyu): Handle close failure - Transition close = io_wrapper_->Close(); - if (close != Transition::PROCEED) return close; - // Remove listening event - // Only after the connection is closed is it safe to remove events, - // after this point no object in the system has reference to this - // connection handle and we will need to destruct and exit. 
- conn_handler_->UnregisterEvent(network_event_); - conn_handler_->UnregisterEvent(workpool_event_); - // This object is essentially managed by libevent (which unfortunately does - // not accept shared_ptrs.) and thus as we shut down we need to manually - // deallocate this object. - delete this; - return Transition::NONE; -} - -Transition ConnectionHandle::TrySslHandshake() { - // Flush out all the response first - if (HasResponse()) { - auto write_ret = TryWrite(); - if (write_ret != Transition::PROCEED) return write_ret; - } - return NetworkIoWrapperFactory::GetInstance().PerformSslHandshake( - io_wrapper_); -} - Transition ConnectionHandle::TryCloseConnection() { LOG_DEBUG("Attempt to close the connection %d", io_wrapper_->GetSocketFd()); // TODO(Tianyu): Handle close failure diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index b4694751c5d..21506842232 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -14,6 +14,7 @@ #include "network/peloton_server.h" #include "network/postgres_network_commands.h" #include "traffic_cop/tcop.h" +#include "settings/settings_manager.h" #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) @@ -25,7 +26,7 @@ namespace network { // A lot of the code here should really be moved to traffic cop, and a lot of // the code here can honestly just be deleted. This is going to be a larger // project though, so I want to do the architectural refactor first. 
-std::vector PostgresNetworkCommand::ReadParamTypes() { +std::vector PostgresNetworkCommand::ReadParamTypes() { std::vector result; auto num_params = in_->ReadValue(); for (uint16_t i = 0; i < num_params; i++) @@ -33,18 +34,19 @@ std::vector PostgresNetworkCommand::ReadParamTypes() { return result; } -std::vector PostgresNetworkCommand::ReadParamFormats() { - std::vector result; +std::vector PostgresNetworkCommand::ReadParamFormats() { + std::vector result; auto num_formats = in_->ReadValue(); for (uint16_t i = 0; i < num_formats; i++) - result.push_back(in_->ReadValue()); + result.push_back(in_->ReadValue()); return result; } void PostgresNetworkCommand::ReadParamValues(std::vector &bind_parameters, std::vector ¶m_values, const std::vector ¶m_types, - const std::vector &formats) { + const std::vector< + PostgresDataFormat> &formats) { auto num_params = in_->ReadValue(); for (uint16_t i = 0; i < num_params; i++) { auto param_len = in_->ReadValue(); @@ -55,24 +57,29 @@ void PostgresNetworkCommand::ReadParamValues(std::vector &bind_pa std::string("")); param_values.push_back(type::ValueFactory::GetNullValueByType( peloton_type)); - } else { - (formats[i] == 0) - ? 
ProcessTextParamValue(bind_parameters, - param_values, - param_types[i], - param_len) - : ProcessBinaryParamValue(bind_parameters, + } else + switch (formats[i]) { + case PostgresDataFormat::TEXT: + ProcessTextParamValue(bind_parameters, param_values, param_types[i], param_len); - } + break; + case PostgresDataFormat::BINARY: + ProcessBinaryParamValue(bind_parameters, + param_values, + param_types[i], + param_len); + break; + default:throw NetworkProcessException("Unexpected format code"); + } } } void PostgresNetworkCommand::ProcessTextParamValue(std::vector &bind_parameters, std::vector ¶m_values, PostgresValueType type, - int32_t len) { + int32_t len) { std::string val = in_->ReadString((size_t) len); bind_parameters.emplace_back(type::TypeId::VARCHAR, val); param_values.push_back( @@ -85,7 +92,7 @@ void PostgresNetworkCommand::ProcessTextParamValue(std::vector &b void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector &bind_parameters, std::vector ¶m_values, PostgresValueType type, - int32_t len) { + int32_t len) { switch (type) { case PostgresValueType::TINYINT: { PELOTON_ASSERT(len == sizeof(int8_t)); @@ -98,7 +105,8 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector case PostgresValueType::SMALLINT: { PELOTON_ASSERT(len == sizeof(int16_t)); auto int_val = in_->ReadValue(); - bind_parameters.emplace_back(type::TypeId::SMALLINT, std::to_string(int_val)); + bind_parameters.emplace_back(type::TypeId::SMALLINT, + std::to_string(int_val)); param_values.push_back( type::ValueFactory::GetSmallIntValue(int_val).Copy()); break; @@ -138,11 +146,28 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector break; } default: - throw NetworkProcessException("Binary Postgres protocol does not support data type " - + PostgresValueTypeToString(type)); + throw NetworkProcessException( + "Binary Postgres protocol does not support data type " + + PostgresValueTypeToString(type)); } } +std::vector 
PostgresNetworkCommand::ReadResultFormats(size_t tuple_size) { + auto num_format_codes = in_->ReadValue(); + switch (num_format_codes) { + case 0: + // Default text mode + return std::vector(tuple_size, + PostgresDataFormat::TEXT); + case 1: + return std::vector(tuple_size, + in_->ReadValue()); + default:std::vector result; + for (auto i = 0; i < num_format_codes; i++) + result.push_back(in_->ReadValue()); + return result; + } +} Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, @@ -154,10 +179,10 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, if (proto_version == SSL_MESSAGE_VERNO) { // TODO(Tianyu): Should this be moved from PelotonServer into settings? if (PelotonServer::GetSSLLevel() == SSLLevel::SSL_DISABLE) { - out.WriteSingleBytePacket(NetworkMessageType::SSL_NO); + out.WriteSingleTypePacket(NetworkMessageType::SSL_NO); return Transition::PROCEED; } - out.WriteSingleBytePacket(NetworkMessageType::SSL_YES); + out.WriteSingleTypePacket(NetworkMessageType::SSL_YES); return Transition::NEED_SSL_HANDSHAKE; } @@ -190,23 +215,129 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); - if(!tcop::PrepareStatement(state, query)) { + std::unique_ptr sql_stmt_list; + try { + sql_stmt_list = tcop::ParseQuery(query); + } catch (Exception &e) { + tcop::ProcessInvalidStatement(state); out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - state.error_message_}}); + e.what()}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) { + out.WriteEmptyQueryResponse(); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - state.param_values_ = 
std::vector(); - interpreter.result_format_ = - std::vector(state.statement_->GetTupleDescriptor().size(), 0); - auto status = tcop::ExecuteStatement(state, - interpreter.result_format_, state.result_, callback); - if (status == ResultType::QUEUING) return Transition::NEED_RESULT; + // TODO(Yuchen): Hack. We only process the first statement in the packet now. + // We should store the rest of statements that will not be processed right + // away. For the hack, in most cases, it works. Because for example in psql, + // one packet contains only one query. But when using the pipeline mode in + // Libpqxx, it sends multiple query in one packet. In this case, it's + // incorrect. + auto sql_stmt = sql_stmt_list->PassOutStatement(0); - interpreter.ExecQueryMessageGetResult(status); + QueryType query_type = + StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; - return Transition::PROCEED; + switch (query_type) { + case QueryType::QUERY_PREPARE: { + std::shared_ptr statement(nullptr); + auto prep_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = prep_stmt->name; + statement = tcop::PrepareStatement(state, + stmt_name, + query, + std::move(prep_stmt->query)); + if (statement == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.statement_cache_.AddStatement(statement); + interpreter.CompleteCommand(out, query_type, 0); + // PAVLO: 2017-01-15 + // There used to be code here that would invoke this method passing + // in NetworkMessageType::READY_FOR_QUERY as the argument. But when + // I switched to strong types, this obviously doesn't work. So I + // switched it to be NetworkTransactionStateType::IDLE. I don't know + // we just don't always send back the internal txn state? 
+ out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXECUTE: { + std::vector param_values; + auto + *exec_stmt = dynamic_cast(sql_stmt.get()); + std::string stmt_name = exec_stmt->name; + + auto cached_statement = state.statement_cache_.GetStatement(stmt_name); + if (cached_statement != nullptr) + state.statement_ = cached_statement; + // Did not find statement with same name + else { + std::string error_message = "The prepared statement does not exist"; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + "The prepared statement does not exist"}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); + + if (!tcop::BindParamsForCachePlan(state, exec_stmt->parameters)) { + tcop::ProcessInvalidStatement(state); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + + auto status = tcop::ExecuteStatement(state, + callback); + if (state.is_queuing_) return Transition::NEED_RESULT; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + }; + case QueryType::QUERY_EXPLAIN: { + auto status = interpreter.ExecQueryExplain(query, + dynamic_cast(*sql_stmt)); + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + default: { + std::string stmt_name = "unamed"; + std::unique_ptr unnamed_sql_stmt_list( + new parser::SQLStatementList()); + unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); + state.statement_ = tcop::PrepareStatement(state, + stmt_name, + query, + std::move(unnamed_sql_stmt_list)); + if (state.statement_ == nullptr) { + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state.error_message_}}); + 
out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return Transition::PROCEED; + } + state.param_values_ = std::vector(); + state.result_format_ = + std::vector(state.statement_->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); + auto status = + tcop::ExecuteStatement(state, callback); + if (state.is_queuing_) + return Transition::NEED_RESULT; + interpreter.ExecQueryMessageGetResult(status); + return Transition::PROCEED; + } + } } Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, @@ -214,24 +345,250 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string statement_name = in_->ReadString(), query = in_->ReadString(); - LOG_TRACE("Execute query: %s", query.c_str()); - if(!tcop::PrepareStatement(state, query, statement_name)) { + // In JDBC, one query starts with parsing stage. + // Reset skipped_stmt_ to false for the new query. + state.skipped_stmt_ = false; + std::unique_ptr sql_stmt_list; + QueryType query_type = QueryType::QUERY_OTHER; + try { + sql_stmt_list = tcop::ParseQuery(query); + } catch (Exception &e) { + tcop::ProcessInvalidStatement(state); + state.skipped_stmt_ = true; + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + e.what()}}); + return Transition::PROCEED; + } + + // If the query is not supported yet, + // we will skip the rest commands (B,E,..) 
for this query + // For empty query, we still want to get it constructed + // TODO (Tianyi) Consider handle more statement + bool empty = (sql_stmt_list == nullptr || + sql_stmt_list->GetNumStatements() == 0); + if (!empty) { + parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); + query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); + } + bool skip = !interpreter.HardcodedExecuteFilter(query_type); + if (skip) { + state.skipped_stmt_ = true; + state.skipped_query_string_ = query; + state.skipped_query_type_ = query_type; + out.WriteSingleTypePacket(NetworkMessageType::PARSE_COMPLETE); + return Transition::PROCEED; + } + + auto statement = tcop::PrepareStatement(state, + statement_name, + query, + std::move(sql_stmt_list)); + if (statement == nullptr) { + tcop::ProcessInvalidStatement(state); + state.skipped_stmt_ = true; out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); - out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } + LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), query.c_str()); - // Read param types - state.statement_->SetParamTypes(ReadParamTypes()); + // Cache the received query + statement->SetParamTypes(ReadParamTypes()); + + // Cache the statement + state.statement_cache_.AddStatement(statement); + out.WriteSingleTypePacket(NetworkMessageType::PARSE_COMPLETE); + return Transition::PROCEED; +} + +Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + std::string portal_name = in_->ReadString(), + statement_name = in_->ReadString(); + if (state.skipped_stmt_) { + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; + } + + std::vector formats = ReadParamFormats(); + + // Get statement info generated in PARSE message + std::shared_ptr + statement = 
state.statement_cache_.GetStatement(statement_name); + if (statement == nullptr) { + std::string error_message = statement_name.empty() + ? "Invalid unnamed statement" + : "The prepared statement does not exist"; + LOG_ERROR("%s", error_message.c_str()); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + error_message}}); + return Transition::PROCEED; + } + + // Empty query + if (statement->GetQueryType() == QueryType::QUERY_INVALID) { + out.BeginPacket(NetworkMessageType::BIND_COMMAND).EndPacket(); + // TODO(Tianyi) This is a hack to respond correct describe message + // as well as execute message + state.skipped_stmt_ = true; + state.skipped_query_string_ = ""; + return Transition::PROCEED; + } + + const auto &query_string = statement->GetQueryString(); + const auto &query_type = statement->GetQueryType(); - // Send Parse complete response - out.BeginPacket(NetworkMessageType::PARSE_COMPLETE).EndPacket(); + // check if the loaded statement needs to be skipped + state.skipped_stmt_ = false; + if (!interpreter.HardcodedExecuteFilter(query_type)) { + state.skipped_stmt_ = true; + state.skipped_query_string_ = query_string; + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; + } + + // Group the parameter types and the parameters in this vector + std::vector> bind_parameters; + std::vector param_values; + + auto param_types = statement->GetParamTypes(); + ReadParamValues(bind_parameters, param_values, param_types, formats); + state.result_format_ = + ReadResultFormats(statement->GetTupleDescriptor().size()); + + if (!param_values.empty()) + statement->GetPlanTree()->SetParameterValues(¶m_values); + // Instead of tree traversal, we should put param values in the + // executor context. 
+ + + + interpreter.portals_[portal_name] = + std::make_shared(portal_name, statement, std::move(param_values)); + out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); + return Transition::PROCEED; +} + +Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + if (state.skipped_stmt_) { + // send 'no-data' + out.WriteSingleTypePacket(NetworkMessageType::NO_DATA_RESPONSE); + return Transition::PROCEED; + } + + auto mode = in_->ReadValue(); + std::string portal_name = in_->ReadString(); + switch (mode) { + case PostgresNetworkObjectType::PORTAL:LOG_TRACE("Describe a portal"); + auto portal_itr = interpreter.portals_.find(portal_name); + // TODO: error handling here + // Ahmed: This is causing the continuously running thread + // Changed the function signature to return boolean + // when false is returned, the connection is closed + if (portal_itr == interpreter.portals_.end()) { + LOG_ERROR("Did not find portal : %s", portal_name.c_str()); + // TODO(Tianyu): Why is this thing swallowing error? + out.WriteTupleDescriptor(std::vector()); + } else + out.WriteTupleDescriptor(portal_itr->second->GetStatement()->GetTupleDescriptor()); + break; + case PostgresNetworkObjectType::STATEMENT: + // TODO(Tianyu): Do we not support this or something? + LOG_TRACE("Describe a prepared statement"); + break; + default:throw NetworkProcessException("Unexpected Describe type"); + } + return Transition::PROCEED; +} + +Transition ExecuteCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc callback) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_JDBC; + std::string portal_name = in_->ReadString(); + + // covers weird JDBC edge case of sending double BEGIN statements. 
Don't + // execute them + if (state.skipped_stmt_) { + if (state.skipped_query_string_ == "") + out.WriteEmptyQueryResponse(); + else + interpreter.CompleteCommand(out, + state.skipped_query_type_, + state.rows_affected_); + state.skipped_stmt_ = false; + return Transition::PROCEED; + } + + auto portal_itr = interpreter.portals_.find(portal_name); + if (portal_itr == interpreter.portals_.end()) + throw NetworkProcessException("Did not find portal: " + portal_name); + + std::shared_ptr portal = portal_itr->second; + state.statement_ = portal->GetStatement(); + auto param_stat = portal->GetParamStat(); + + if (state.statement_ == nullptr) + throw NetworkProcessException( + "Did not find statement in portal: " + portal_name); + + state.param_values_ = portal->GetParameters(); + auto status = tcop::ExecuteStatement(state, callback); + if (state.is_queuing_) return Transition::NEED_RESULT; + interpreter.ExecExecuteMessageGetResult(status); + return Transition::PROCEED; +} + +Transition SyncCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + out.WriteReadyForQuery(state.txn_state_); return Transition::PROCEED; } +Transition CloseCommand::Exec(PostgresProtocolInterpreter &interpreter, + PostgresPacketWriter &out, + CallbackFunc) { + tcop::ClientProcessState &state = interpreter.ClientProcessState(); + auto close_type = in_->ReadValue(); + std::string name = in_->ReadString(); + switch (close_type) { + case PostgresNetworkObjectType::STATEMENT: { + LOG_TRACE("Deleting statement %s from cache", name.c_str()); + state.statement_cache_.DeleteStatement(name); + break; + } + case 'P': { + LOG_TRACE("Deleting portal %s from cache", name.c_str()); + auto portal_itr = interpreter.portals_.find(name); + if (portal_itr != interpreter.portals_.end()) { + // delete portal if it exists + interpreter.portals_.erase(portal_itr); + } + break; + } + default: + // do 
nothing, simply send close complete + break; + } + // Send close complete response + out.WriteSingleTypePacket(NetworkMessageType::CLOSE_COMPLETE); +} + +Transition TerminateCommand(PostgresProtocolInterpreter &, + PostgresPacketWriter &, + CallbackFunc) { + return Transition::TERMINATE; +}\ } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index bc5be0e3257..442f16ddfdf 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -27,7 +27,7 @@ Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, curr_input_packet_.Clear(); PostgresPacketWriter writer(*out); if (command->FlushOnComplete()) out->ForceFlush(); - return command->Exec(*this, writer, callback, thread_id_); + return command->Exec(*this, writer, callback); } bool PostgresProtocolInterpreter::TryBuildPacket(std::shared_ptr &in) { @@ -95,8 +95,6 @@ std::shared_ptr PostgresProtocolInterpreter::PacketToCom return MAKE_COMMAND(CloseCommand); case NetworkMessageType::TERMINATE_COMMAND: return MAKE_COMMAND(TerminateCommand); - case NetworkMessageType::NULL_COMMAND: - return MAKE_COMMAND(NullCommand); default: throw NetworkProcessException("Unexpected Packet Type: " + std::to_string(static_cast(curr_input_packet_.msg_type_))); diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index b1cf50cbd89..ad6e0947385 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -92,11 +92,10 @@ bool tcop::PrepareStatement(ClientProcessState &state, return true; } -bool tcop::ExecuteStatement( - ClientProcessState &state, - const std::vector &result_format, - std::vector &result, - const CallbackFunc &callback) { +bool tcop::ExecuteStatement(ClientProcessState &state, + const std::vector &result_format, + std::vector &result, + const CallbackFunc &callback) { LOG_TRACE("Execute Statement of name: %s", 
state.statement_->GetStatementName().c_str()); @@ -139,7 +138,25 @@ bool tcop::ExecuteStatement( state.statement_->SetNeedsReplan(false); } - ExecuteHelper(state, result, result_format, txn, callback); + auto plan = state.statement_->GetPlanTree(); + auto params = state.param_values_; + + auto on_complete = [callback, &](executor::ExecutionResult p_status, + std::vector &&values) { + state.p_status_ = p_status; + state.error_message_ = std::move(p_status.m_error_message); + result = std::move(values); + callback(); + }; + + auto &pool = threadpool::MonoQueuePool::GetInstance(); + pool.SubmitTask([txn, on_complete, &] { + executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), + txn, + state.param_values_, + result_format, + on_complete); + }); return false; } } @@ -150,32 +167,5 @@ bool tcop::ExecuteStatement( } } -void tcop::ExecuteHelper( - ClientProcessState &state, - std::vector &result, - const std::vector &result_format, - concurrency::TransactionContext *txn, - const CallbackFunc &callback) { - auto plan = state.statement_->GetPlanTree(); - auto params = state.param_values_; - - auto on_complete = [callback, &](executor::ExecutionResult p_status, - std::vector &&values) { - state.p_status_ = p_status; - state.error_message_ = std::move(p_status.m_error_message); - result = std::move(values); - callback(); - }; - - auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([txn, on_complete, &] { - executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), - txn, - state.param_values_, - result_format, - on_complete); - }); -} - } // namespace tcop } // namespace peloton \ No newline at end of file From 66c5444cfe314e76e62aa6bc15c25431d01104d1 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 29 Jun 2018 14:40:42 -0400 Subject: [PATCH 27/48] Finish porting over old code --- src/common/portal.cpp | 6 +- src/executor/copy_executor.cpp | 12 +- src/include/common/internal_types.h | 4 - src/include/network/connection_handle.h 
| 1 + src/include/network/marshal.h | 155 ++ src/include/network/network_io_utils.h | 14 +- src/include/network/peloton_server.h | 1 - .../network/postgres_network_commands.h | 1 - .../network/postgres_protocol_handler.h | 240 ---- .../network/postgres_protocol_interpreter.h | 13 +- src/include/network/postgres_protocol_utils.h | 49 +- src/include/network/protocol_handler.h | 60 - .../network/protocol_handler_factory.h | 36 - src/include/network/protocol_interpreter.h | 4 - .../traffic_cop/client_transaction_handle.h | 230 --- src/include/traffic_cop/tcop.h | 100 +- src/include/traffic_cop/traffic_cop.h | 202 --- src/network/connection_handle.cpp | 2 - src/network/postgres_network_commands.cpp | 92 +- src/network/postgres_protocol_handler.cpp | 1272 ----------------- src/network/postgres_protocol_interpreter.cpp | 155 +- src/network/protocol_handler.cpp | 38 - src/network/protocol_handler_factory.cpp | 30 - src/traffic_cop/client_transaction_handle.cpp | 107 -- src/traffic_cop/tcop.cpp | 541 +++++-- src/traffic_cop/traffic_cop.cpp | 617 -------- test/sql/testing_sql_util.cpp | 2 +- 27 files changed, 931 insertions(+), 3053 deletions(-) delete mode 100644 src/include/network/postgres_protocol_handler.h delete mode 100644 src/include/network/protocol_handler.h delete mode 100644 src/include/network/protocol_handler_factory.h delete mode 100644 src/include/traffic_cop/client_transaction_handle.h delete mode 100644 src/include/traffic_cop/traffic_cop.h delete mode 100644 src/network/postgres_protocol_handler.cpp delete mode 100644 src/network/protocol_handler.cpp delete mode 100644 src/network/protocol_handler_factory.cpp delete mode 100644 src/traffic_cop/client_transaction_handle.cpp delete mode 100644 src/traffic_cop/traffic_cop.cpp diff --git a/src/common/portal.cpp b/src/common/portal.cpp index 77de6522f50..a5aa7754ba4 100644 --- a/src/common/portal.cpp +++ b/src/common/portal.cpp @@ -18,12 +18,10 @@ namespace peloton { Portal::Portal(const std::string& 
portal_name, std::shared_ptr statement, - std::vector bind_parameters, - std::shared_ptr param_stat) + std::vector bind_parameters) : portal_name_(portal_name), statement_(statement), - bind_parameters_(std::move(bind_parameters)), - param_stat_(param_stat) {} + bind_parameters_(std::move(bind_parameters)) {} Portal::~Portal() { statement_.reset(); } diff --git a/src/executor/copy_executor.cpp b/src/executor/copy_executor.cpp index f499e899708..b10ce872747 100644 --- a/src/executor/copy_executor.cpp +++ b/src/executor/copy_executor.cpp @@ -22,9 +22,9 @@ #include "executor/logical_tile_factory.h" #include "planner/export_external_file_plan.h" #include "storage/table_factory.h" -#include "network/postgres_protocol_handler.h" #include "common/exception.h" #include "common/macros.h" +#include "network/marshal.h" namespace peloton { namespace executor { @@ -202,7 +202,7 @@ bool CopyExecutor::DExecute() { // Read param types types.resize(num_params); //TODO: Instead of passing packet to executor, some data structure more generic is need - network::PostgresProtocolHandler::ReadParamType(&packet, num_params, types); + network::OldReadParamType(&packet, num_params, types); // Write all the types to output file for (int i = 0; i < num_params; i++) { @@ -219,7 +219,7 @@ bool CopyExecutor::DExecute() { // Read param formats formats.resize(num_params); //TODO: Instead of passing packet to executor, some data structure more generic is need - network::PostgresProtocolHandler::ReadParamFormat(&packet, num_params, formats); + network::OldReadParamFormat(&packet, num_params, formats); } else if (origin_col_id == param_val_col_id) { // param_values column @@ -230,9 +230,9 @@ bool CopyExecutor::DExecute() { bind_parameters.resize(num_params); param_values.resize(num_params); //TODO: Instead of passing packet to executor, some data structure more generic is need - network::PostgresProtocolHandler::ReadParamValue(&packet, num_params, types, - bind_parameters, param_values, - formats); + 
network::OldReadParamValue(&packet, num_params, types, + bind_parameters, param_values, + formats); // Write all the values to output file for (int i = 0; i < num_params; i++) { diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 333c1cd0d67..f0d46447e12 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1457,11 +1457,7 @@ enum class PostgresDataFormat : int16_t { BINARY = 1 }; -<<<<<<< HEAD enum class PostgresNetworkObjectType : uchar { -======= -enum class PostgresDescribeType : uchar { ->>>>>>> aaa37d1d17aee6e14c87188c769aab3db253b19c PORTAL = 'P', STATEMENT = 'S' }; diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 5c9f4e4ab5a..700e1100164 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -92,6 +92,7 @@ class ConnectionHandle { inline Transition TryWrite() { if (io_wrapper_->ShouldFlush()) return io_wrapper_->FlushAllWrites(); + return Transition::PROCEED; } inline Transition Process() { diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index 56a33a13a86..f46ef77605e 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -12,6 +12,7 @@ #pragma once +#include "type/value_factory.h" #include "network/network_io_utils.h" @@ -160,5 +161,159 @@ extern void PacketGetByte(InputPacket *rpkt, uchar &result); */ extern void GetStringToken(InputPacket *pkt, std::string &result); +// TODO(Tianyu): These dumb things are here because copy_executor somehow calls +// our network layer. This should NOT be the case. Will remove. 
+size_t OldReadParamType( + InputPacket *pkt, int num_params, std::vector ¶m_types) { + auto begin = pkt->ptr; + // get the type of each parameter + for (int i = 0; i < num_params; i++) { + int param_type = PacketGetInt(pkt, 4); + param_types[i] = param_type; + } + auto end = pkt->ptr; + return end - begin; +} + +size_t OldReadParamFormat(InputPacket *pkt, + int num_params_format, + std::vector &formats) { + auto begin = pkt->ptr; + // get the format of each parameter + for (int i = 0; i < num_params_format; i++) { + formats[i] = PacketGetInt(pkt, 2); + } + auto end = pkt->ptr; + return end - begin; +} + +// For consistency, this function assumes the input vectors has the correct size +size_t OldReadParamValue( + InputPacket *pkt, int num_params, std::vector ¶m_types, + std::vector> &bind_parameters, + std::vector ¶m_values, std::vector &formats) { + auto begin = pkt->ptr; + ByteBuf param; + for (int param_idx = 0; param_idx < num_params; param_idx++) { + int param_len = PacketGetInt(pkt, 4); + // BIND packet NULL parameter case + if (param_len == -1) { + // NULL mode + auto peloton_type = PostgresValueTypeToPelotonValueType( + static_cast(param_types[param_idx])); + bind_parameters[param_idx] = + std::make_pair(peloton_type, std::string("")); + param_values[param_idx] = + type::ValueFactory::GetNullValueByType(peloton_type); + } else { + PacketGetBytes(pkt, param_len, param); + + if (formats[param_idx] == 0) { + // TEXT mode + std::string param_str = std::string(std::begin(param), std::end(param)); + bind_parameters[param_idx] = + std::make_pair(type::TypeId::VARCHAR, param_str); + if ((unsigned int)param_idx >= param_types.size() || + PostgresValueTypeToPelotonValueType( + (PostgresValueType)param_types[param_idx]) == + type::TypeId::VARCHAR) { + param_values[param_idx] = + type::ValueFactory::GetVarcharValue(param_str); + } else { + param_values[param_idx] = + (type::ValueFactory::GetVarcharValue(param_str)) + .CastAs(PostgresValueTypeToPelotonValueType( + 
(PostgresValueType)param_types[param_idx])); + } + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); + } else { + // BINARY mode + PostgresValueType pg_value_type = + static_cast(param_types[param_idx]); + LOG_TRACE("Postgres Protocol Conversion [param_idx=%d]", param_idx); + switch (pg_value_type) { + case PostgresValueType::TINYINT: { + int8_t int_val = 0; + for (size_t i = 0; i < sizeof(int8_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::TINYINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetTinyIntValue(int_val).Copy(); + break; + } + case PostgresValueType::SMALLINT: { + int16_t int_val = 0; + for (size_t i = 0; i < sizeof(int16_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetSmallIntValue(int_val).Copy(); + break; + } + case PostgresValueType::INTEGER: { + int32_t int_val = 0; + for (size_t i = 0; i < sizeof(int32_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::INTEGER, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetIntegerValue(int_val).Copy(); + break; + } + case PostgresValueType::BIGINT: { + int64_t int_val = 0; + for (size_t i = 0; i < sizeof(int64_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::BIGINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetBigIntValue(int_val).Copy(); + break; + } + case PostgresValueType::DOUBLE: { + double float_val = 0; + unsigned long buf = 0; + for (size_t i = 0; i < sizeof(double); ++i) { + buf = (buf << 8) | param[i]; + } + PELOTON_MEMCPY(&float_val, &buf, sizeof(double)); + bind_parameters[param_idx] = std::make_pair( + 
type::TypeId::DECIMAL, std::to_string(float_val)); + param_values[param_idx] = + type::ValueFactory::GetDecimalValue(float_val).Copy(); + break; + } + case PostgresValueType::VARBINARY: { + bind_parameters[param_idx] = std::make_pair( + type::TypeId::VARBINARY, + std::string(reinterpret_cast(¶m[0]), param_len)); + param_values[param_idx] = type::ValueFactory::GetVarbinaryValue( + ¶m[0], param_len, true); + break; + } + default: { + LOG_ERROR( + "Binary Postgres protocol does not support data type '%s' [%d]", + PostgresValueTypeToString(pg_value_type).c_str(), + param_types[param_idx]); + break; + } + } + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); + } + } + } + auto end = pkt->ptr; + return end - begin; +} + } // namespace network } // namespace peloton diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index b1afef54661..c6581298215 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -20,6 +20,7 @@ namespace peloton { namespace network { +#define _CAST(type, val) ((type)(val)) /** * A plain old buffer with a movable cursor, the meaning of which is dependent * on the use case. 
@@ -172,11 +173,15 @@ class ReadBuffer : public Buffer { || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8, "Invalid size for integer"); + auto val = ReadRawValue(); switch (sizeof(T)) { - case 1: return ReadRawValue(); - case 2: return ntohs(ReadRawValue()); - case 4: return ntohl(ReadRawValue()); - case 8: return ntohll(ReadRawValue()); + case 1: return val; + case 2: + return _CAST(T, ntohs(_CAST(uint16_t, val))); + case 4: + return _CAST(T, ntohl(_CAST(uint32_t, val))); + case 8: + return _CAST(T, ntohll(_CAST(uint64_t, val))); // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } @@ -395,6 +400,7 @@ class WriteQueue { private: friend class PostgresPacketWriter; std::vector> buffers_; + size_t offset_ = 0; bool flush_ = false; }; diff --git a/src/include/network/peloton_server.h b/src/include/network/peloton_server.h index eac292a6e1a..076f99c5b31 100644 --- a/src/include/network/peloton_server.h +++ b/src/include/network/peloton_server.h @@ -35,7 +35,6 @@ #include "common/notifiable_task.h" #include "connection_dispatcher_task.h" #include "network_types.h" -#include "protocol_handler.h" #include #include diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 1274db96cd8..72348581ee0 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -16,7 +16,6 @@ #include "common/logger.h" #include "common/macros.h" #include "network/network_types.h" -#include "traffic_cop/traffic_cop.h" #include "network/marshal.h" #include "network/postgres_protocol_utils.h" diff --git a/src/include/network/postgres_protocol_handler.h b/src/include/network/postgres_protocol_handler.h deleted file mode 100644 index 960e2fdfd46..00000000000 --- a/src/include/network/postgres_protocol_handler.h +++ /dev/null @@ -1,240 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// 
Peloton -// -// postgres_protocol_handler.h -// -// Identification: src/include/network/postgres_protocol_handler.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include -#include -#include -#include - -#include "common/cache.h" -#include "common/internal_types.h" -#include "common/portal.h" -#include "common/statement.h" -#include "common/statement_cache.h" -#include "protocol_handler.h" -#include "traffic_cop/traffic_cop.h" - -// Packet content macros -#define NULL_CONTENT_SIZE (-1) - -namespace peloton { - -namespace parser { -class ExplainStatement; -} // namespace parser - -namespace network { - -typedef std::vector> ResponseBuffer; - -class PostgresProtocolHandler : public ProtocolHandler { - public: - PostgresProtocolHandler(tcop::TrafficCop *traffic_cop); - - ~PostgresProtocolHandler(); - /** - * Parse the content in the buffer and process to generate results. - * @param rbuf The read buffer of network - * @param thread_id The thread of current running thread. 
This is used - * to generate txn - * @return @see ProcessResult - */ - ProcessResult Process(ReadBuffer &rbuf, size_t thread_id); - - // Deserialize the parame types from packet - static size_t ReadParamType(InputPacket *pkt, int num_params, - std::vector ¶m_types); - - // Deserialize the parameter format from packet - static size_t ReadParamFormat(InputPacket *pkt, int num_params_format, - std::vector &formats); - - // Deserialize the parameter value from packet - static size_t ReadParamValue( - InputPacket *pkt, int num_params, std::vector ¶m_types, - std::vector> &bind_parameters, - std::vector ¶m_values, std::vector &formats); - - void Reset(); - - void GetResult(); - - private: - //===--------------------------------------------------------------------===// - // STATIC HELPERS - //===--------------------------------------------------------------------===// - - /** - * @brief Parse the input packet from rbuf - * @param rbuf network read buffer - * @param rpkt the postgres rpkt we want to parse to - * @param startup_format whether we want the rpkt to be of startup packet - * format - * (i.e. 
no type byte) - * @return true if the parsing is complete - */ - static bool ParseInputPacket(ReadBuffer &rbuf, InputPacket &rpkt, - bool startup_format); - - /** - * @brief Helper function to extract the body of Postgres packet from the - * read buffer - * @param rbuf network read buffer - * @param rpkt the postgres rpkt we want to parse to - * @return true if the parsing is complete - */ - static bool ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt); - - /** - * @brief Helper function to extract the header of a Postgres packet from the - * read buffer - * @see ParseInputPacket from param and return value - */ - static bool ReadPacketHeader(ReadBuffer &rbuf, InputPacket &rpkt, - bool startup_format); - - //===--------------------------------------------------------------------===// - // PROTOCOL HANDLING FUNCTIONS - //===--------------------------------------------------------------------===// - - /** - * @brief Routine to deal with the first packet from the client - */ - ProcessResult ProcessInitialPacket(InputPacket *pkt); - - /** - * @brief Main Switch function to process general packets - */ - ProcessResult ProcessNormalPacket(InputPacket *pkt, const size_t thread_id); - - /** - * @brief Helper function to process startup packet - * @param proto_version protocol version of the session - */ - ProcessResult ProcessStartupPacket(InputPacket *pkt, int32_t proto_version); - - /** - * Send hardcoded response - */ - void SendStartupResponse(); - - // Generic error protocol packet - void SendErrorResponse( - std::vector> error_status); - - // Sends ready for query packet to the frontend - void SendReadyForQuery(NetworkTransactionStateType txn_status); - - // Sends the attribute headers required by SELECT queries - void PutTupleDescriptor(const std::vector &tuple_descriptor); - - // Send each row, one packet at a time, used by SELECT queries - void SendDataRows(std::vector &results, int colcount); - - // Used to send a packet that indicates the completion of a query. 
Also has - // txn state mgmt - void CompleteCommand(const QueryType &query_type, int rows); - - // Specific response for empty or NULL queries - void SendEmptyQueryResponse(); - - /* Helper function used to make hardcoded ParameterStatus('S') - * packets during startup - */ - void MakeHardcodedParameterStatus( - const std::pair &kv); - - /* We don't support "SET" and "SHOW" SQL commands yet. - * Also, duplicate BEGINs and COMMITs shouldn't be executed. - * This function helps filtering out the execution for such cases - */ - bool HardcodedExecuteFilter(QueryType query_type); - - /* Execute a Simple query protocol message */ - ProcessResult ExecQueryMessage(InputPacket *pkt, const size_t thread_id); - - /* Execute a EXPLAIN query message */ - ResultType ExecQueryExplain(const std::string &query, - parser::ExplainStatement &explain_stmt); - - /* Process the PARSE message of the extended query protocol */ - void ExecParseMessage(InputPacket *pkt); - - /* Process the BIND message of the extended query protocol */ - void ExecBindMessage(InputPacket *pkt); - - /* Process the DESCRIBE message of the extended query protocol */ - ProcessResult ExecDescribeMessage(InputPacket *pkt); - - /* Process the EXECUTE message of the extended query protocol */ - ProcessResult ExecExecuteMessage(InputPacket *pkt, const size_t thread_id); - - /* Process the optional CLOSE message of the extended query protocol */ - void ExecCloseMessage(InputPacket *pkt); - - void ExecExecuteMessageGetResult(ResultType status); - - void ExecQueryMessageGetResult(ResultType status); - - //===--------------------------------------------------------------------===// - // MEMBERS - //===--------------------------------------------------------------------===// - // True if this protocol is handling startup/SSL packets - bool init_stage_; - - NetworkProtocolType protocol_type_; - - // The result-column format code - std::vector result_format_; - - // global txn state - NetworkTransactionStateType txn_state_; 
- - // state to manage skipped queries - bool skipped_stmt_ = false; - std::string skipped_query_string_; - QueryType skipped_query_type_; - - // Statement cache - StatementCache statement_cache_; - - // Portals - std::unordered_map> portals_; - - // packets ready for read - size_t pkt_cntr_; - - // Manage parameter types for unnamed statement - stats::QueryMetric::QueryParamBuf unnamed_stmt_param_types_; - - // Parameter types for statements - // Warning: the data in the param buffer becomes invalid when the value - // stored - // in stat table is destroyed - std::unordered_map - statement_param_types_; - - std::unordered_map cmdline_options_; - - //===--------------------------------------------------------------------===// - // STATIC DATA - //===--------------------------------------------------------------------===// - - static const std::unordered_map - parameter_status_map_; -}; - -} // namespace network -} // namespace peloton diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index abd02d2abfa..15e197351cc 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -15,6 +15,7 @@ #include "network/protocol_interpreter.h" #include "network/postgres_network_commands.h" #include "traffic_cop/tcop.h" +#include "common/portal.h" namespace peloton { namespace network { @@ -23,7 +24,9 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { public: // TODO(Tianyu): Is this even the right thread id? It seems that all the // concurrency code is dependent on this number. 
- explicit PostgresProtocolInterpreter(size_t thread_id) = default; + explicit PostgresProtocolInterpreter(size_t thread_id) { + state_.thread_id_= thread_id; + }; Transition Process(std::shared_ptr in, std::shared_ptr out, @@ -41,20 +44,18 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { // TODO(Tianyu): Remove these later for better responsibility assignment + bool HardcodedExecuteFilter(QueryType query_type); void CompleteCommand(PostgresPacketWriter &out, const QueryType &query_type, int rows); - void ExecQueryMessageGetResult(ResultType status); - void ExecExecuteMessageGetResult(ResultType status); + void ExecQueryMessageGetResult(PostgresPacketWriter &out, ResultType status); + void ExecExecuteMessageGetResult(PostgresPacketWriter &out, ResultType status); ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); - bool HardcodedExecuteFilter(QueryType query_type); - NetworkProtocolType protocol_type_; std::unordered_map> portals_; private: bool startup_ = true; PostgresInputPacket curr_input_packet_{}; std::unordered_map cmdline_options_; tcop::ClientProcessState state_; - bool TryBuildPacket(std::shared_ptr &in); bool TryReadPacketHeader(std::shared_ptr &in); std::shared_ptr PacketToCommand(); diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index 91032a6de34..cda1216df97 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -12,7 +12,9 @@ #pragma once #include "network/network_io_utils.h" +#include "common/statement.h" +#define NULL_CONTENT_SIZE (-1) namespace peloton { namespace network { @@ -144,11 +146,12 @@ class PostgresPacketWriter { || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8, "Invalid size for integer"); + switch (sizeof(T)) { - case 1: return AppendRawValue(val); - case 2: return AppendRawValue(ntohs(val)); - case 4: return AppendRawValue(ntohl(val)); - case 8: 
return AppendRawValue(ntohll(val)); + case 1: return AppendRawValue(val); + case 2: return AppendRawValue(_CAST(T, ntohs(_CAST(uint16_t, val)))); + case 4: return AppendRawValue(_CAST(T, ntohl(_CAST(uint32_t, val)))); + case 8: return AppendRawValue(_CAST(T, ntohll(_CAST(uint64_t, val)))); // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } @@ -203,10 +206,42 @@ class PostgresPacketWriter { inline void WriteTupleDescriptor(const std::vector &tuple_descriptor) { if (tuple_descriptor.empty()) return; - BeginPacket(NetworkMessageType::ROW_DESCRIPTION) - .AppendValue(tuple_descriptor.size()); + BeginPacket(NetworkMessageType::ROW_DESCRIPTION); + AppendValue(tuple_descriptor.size()); for (auto &col : tuple_descriptor) { - AppendString(std::get<0>(col))() + AppendString(std::get<0>(col)); + // TODO: Table Oid (int32) + AppendValue(0); + // TODO: Attr id of column (int16) + AppendValue(0); + // Field data type (int32) + AppendValue(std::get<1>(col)); + // Data type size (int16) + AppendValue(std::get<2>(col)); + // Type modifier (int32) + AppendValue(-1); + AppendValue(0); + } + EndPacket(); + } + + inline void WriteDataRows(const std::vector &results, + size_t num_columns) { + if (results.empty() || num_columns == 0) return; + size_t num_rows = results.size() / num_columns; + for (size_t i = 0; i < num_rows; i++) { + BeginPacket(NetworkMessageType::DATA_ROW) + .AppendValue(num_columns); + for (size_t j = 0; j < num_columns; j++) { + auto content = results[i * num_columns + j]; + if (content.empty()) + AppendValue(NULL_CONTENT_SIZE); + else + AppendValue(content.size()) + .AppendString(content); + + } + EndPacket(); } } diff --git a/src/include/network/protocol_handler.h b/src/include/network/protocol_handler.h deleted file mode 100644 index 785626e4b16..00000000000 --- a/src/include/network/protocol_handler.h +++ /dev/null @@ -1,60 +0,0 @@ -//===----------------------------------------------------------------------===// -// 
-// Peloton -// -// protocol_handler.h -// -// Identification: src/include/network/protocol_handler.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "common/internal_types.h" -#include "marshal.h" -#include "traffic_cop/traffic_cop.h" -// Packet content macros - -namespace peloton { -namespace network { - -typedef std::vector> ResponseBuffer; - -class ProtocolHandler { - public: - ProtocolHandler(tcop::TrafficCop *traffic_cop); - - virtual ~ProtocolHandler(); - - // TODO(Tianyi) Move thread_id to traffic_cop - // TODO(Tianyi) Make wbuf as an parameter here - /** - * Main switch case wrapper to process every packet apart from the startup - * packet. Avoid flushing the response for extended protocols. - */ - virtual ProcessResult Process(ReadBuffer &rbuf, size_t thread_id); - - virtual void Reset(); - - virtual void GetResult(); - - void SetFlushFlag(bool flush) { force_flush_ = flush; } - - bool GetFlushFlag() { return force_flush_; } - - bool force_flush_ = false; - - // TODO declare a response buffer pool so that we can reuse the responses - // so that we don't have to new packet each time - ResponseBuffer responses_; - - InputPacket request_; // Used for reading a single request - - // The traffic cop used for this connection - tcop::TrafficCop *traffic_cop_; -}; - -} // namespace network -} // namespace peloton diff --git a/src/include/network/protocol_handler_factory.h b/src/include/network/protocol_handler_factory.h deleted file mode 100644 index c13cca250b2..00000000000 --- a/src/include/network/protocol_handler_factory.h +++ /dev/null @@ -1,36 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// protocol_handler_factory.h -// -// Identification: src/include/network/protocol_handler_factory.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database 
Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include - -#include "network/protocol_handler.h" - -// Packet content macros - -namespace peloton { - -namespace network { - -enum class ProtocolHandlerType { - Postgres, -}; - -// The factory of ProtocolHandler -class ProtocolHandlerFactory { - public: - static std::unique_ptr CreateProtocolHandler( - ProtocolHandlerType type, tcop::TrafficCop *trafficCop); -}; -} // namespace network -} // namespace peloton diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 3896f0384ef..4b2f5bafdf9 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -20,16 +20,12 @@ namespace network { class ProtocolInterpreter { public: - ProtocolInterpreter(size_t thread_id) : thread_id_(thread_id) {} - virtual Transition Process(std::shared_ptr in, std::shared_ptr out, CallbackFunc callback) = 0; // TODO(Tianyu): Do we really need this crap? 
virtual void GetResult() = 0; - protected: - size_t thread_id_; }; } // namespace network diff --git a/src/include/traffic_cop/client_transaction_handle.h b/src/include/traffic_cop/client_transaction_handle.h deleted file mode 100644 index 2079494b37d..00000000000 --- a/src/include/traffic_cop/client_transaction_handle.h +++ /dev/null @@ -1,230 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// client_transaction_handle.h -// -// Identification: src/include/traffic_cop/client_transaction_handle.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "concurrency/transaction_manager_factory.h" -namespace peloton { -namespace tcop { - -using TxnContext = concurrency::TransactionContext; -using TxnManagerFactory = concurrency::TransactionManagerFactory; - -enum class TransactionState{ - IDLE = 0, - STARTED, - FAILING, - ABORTING, -}; - -class ClientTxnHandle; - -/** - * abtract class to provide a unified interface of txn handling - */ -class AbtractClientTxnHandler { - public: - /** - * @brief Start a txn if there is no txn at the moment this function is called. - * @param thread_id number to generate epoch id in a distributed manner - * @param handle Client transaction context - * @return Current trancation that is started - * @throw TransactionException when no txn can be started (eg. current txn is failed already) - */ - virtual TxnContext *ImplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle) = 0; - - /** - * @brief Force starting a txn - * @param thread_id number to generate epoch id in a distributed manner - * @param handle Client transaction context - * @return Current trancation that is started - * @throw TransactionException when no txn can be started (eg. 
there is already an txn) - */ - virtual TxnContext *ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle) = 0; - - /** - * @brief Implicitly end a txn - * @param handle - * @param handle Client transaction context - */ - virtual void End(ClientTxnHandle &handle) = 0; - - /** - * @brief Explicitly commit a txn - * @param handle Client transaction context - * @throw TransactionException when there is no txn started - */ - virtual bool Commit(ClientTxnHandle &handle) = 0; - - /** - * @brief Explicitly abort a txn - * @param handle Client transaction context - */ - virtual void Abort(ClientTxnHandle &handle) = 0; - -}; - -/** - * Client Transaction handler for Transaction Handler when in Single-Statement Mode - */ -class SingleStmtClientTxnHandler : AbtractClientTxnHandler{ - - /** - * @see AbstractClientTxnHandler - */ - TxnContext *ImplicitBegin(const size_t thread_id = 0, ClientTxnHandle &handle); - - /** - * @brief This function should never be called in this mode - */ - inline TxnContext *ExplicitBegin(const size_t, ClientTxnHandle &) { - throw TransactionException("Should not be called"); - } - - /** - * @see AbstractClientTxnHandler - */ - void End(ClientTxnHandle &handle); - - /** - * @brief This function should never be called in this mode - */ - inline bool Commit(ClientTxnHandle &handle) { - throw TransactionException("Should not be called"); - } - - /** - * @see AbstractClientTxnHandler - */ - void Abort(ClientTxnHandle &handle); - -}; - -class MultiStmtsClientTxnHandler : AbtractClientTxnHandler { - - /** - * @see AbstractClientTxnHandler - */ - TxnContext *ImplicitBegin(const size_t, ClientTxnHandle &handle_); - - /** - * @see AbstractClientTxnHandler - */ - TxnContext *ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle & handle); - - /** - * @see AbstractClientTxnHandler - */ - inline void End(ClientTxnHandle &handle) {} - - /** - * @see AbstractClientTxnHandler - */ - bool Commit(ClientTxnHandle &handle); - - /** - * @see 
AbstractClientTxnHandler - */ - void Abort(ClientTxnHandle &handle); -}; - -/** - * @brief Wrapper class that could provide functions to properly start and end a transaction. - * - * It would operate in either Single-Statement or Multi-Statements mode, using different handler - * - * Serveral general patterns of function calls is : - * 1. - * ImplicitBegin()[By Prepare] -> ImplicitBegin() [By Execute]-> ImplicitEnd(); - * 2. - * ImplicitBegin() -> ExplicitBegin() -> ImplictBegin()[second query] - * ->ImplicitEnd()[second query] -> ExplicitCommit() -> ImplicitEnd(); - * 3. - * ImplicitBegin() -> ExplicitBegin() -> ImplictBegin()[second query] - * ->ImplicitENd()[second query]-> ExplicitAbort() -> ImplicitEnd(); - */ -class ClientTxnHandle { - friend class AbtractClientTxnHandler; - friend class SingleStmtClientTxnHandler; - friend class MultiStmtsClientTxnHandler; - - public: - - /** - * Start a transaction if there is no transaction - * @param thread_id number to generate epoch id in a distributed manner - * @return transaction context - */ - TxnContext *ImplicitBegin(const size_t thread_id = 0); - - /** - * Force starting a transaction if there is no transaction - * @param thread_id number to generate epoch id in a distributed manner - * @return transaction context - */ - TxnContext *ExplicitBegin(const size_t thread_id = 0); - - /** - * Commit/Abort a transaction and do the necessary cleanup - */ - void ImplicitEnd(); - - /** - * Explicitly commit a transaction - * @return if the commit is successful - */ - bool ExplicitCommit(); - - /** - * Explicitly abort a transaction - */ - void ExplicitAbort(); - - /** - * @brief Getter function of txn state - * @return current trancation state - */ - inline TransactionState GetTxnState() { - return txn_state_; - } - - /** - * @brief Getter function of current transaction context - * @return current transaction context - */ - inline TxnContext *GetTxn() { - return txn_; - } - - private: - - TransactionState txn_state_; - - 
TxnContext *txn_; - - bool single_stmt_handler_ = true; - - std::unique_ptr handler_; - - inline void ChangeToSingleStmtHandler() { - handler_ = std::unique_ptr(new SingleStmtClientTxnHandler()); - single_stmt_handler_ = true; - } - - inline void ChangeToMultiStmtsHandler() { - handler_ = std::unique_ptr(new MultiStmtsClientTxnHandler()); - single_stmt_handler_ = false; - } -}; - -} -} \ No newline at end of file diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index ed025e8e0a0..be3a6007904 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -12,12 +12,12 @@ #pragma once #include +#include "catalog/column.h" #include "executor/plan_executor.h" #include "optimizer/abstract_optimizer.h" #include "parser/postgresparser.h" #include "parser/sql_statement.h" #include "common/statement_cache.h" -#include "client_transaction_handle.h" namespace peloton { namespace tcop { @@ -48,54 +48,80 @@ struct ClientProcessState { QueryType skipped_query_type_ = QueryType::QUERY_INVALID; StatementCache statement_cache_; int rows_affected_; + executor::ExecutionResult p_status_; + + // TODO(Tianyu): This is vile, get rid of this + TcopTxnState &GetCurrentTxnState() { + if (tcop_txn_state_.empty()) { + static TcopTxnState + default_state = std::make_pair(nullptr, ResultType::INVALID); + return default_state; + } + return tcop_txn_state_.top(); + } }; -inline std::unique_ptr ParseQuery(const std::string &query_string) { - auto &peloton_parser = parser::PostgresParser::GetInstance(); - auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); - // When the query is empty(such as ";" or ";;", still valid), - // the parse tree is empty, parser will return nullptr. 
- if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) - throw ParserException("Error Parsing SQL statement"); - return sql_stmt_list; -} +// TODO(Tianyu): We use an instance here in expectation that instance variables +// such as parser or others will be here when we refactor singletons, but Tcop +// should not have any Client specific states. +class Tcop { + public: + // TODO(Tianyu): Remove later + inline static Tcop &GetInstance() { + static Tcop tcop; + return tcop; + } + + inline std::unique_ptr ParseQuery(const std::string &query_string) { + auto &peloton_parser = parser::PostgresParser::GetInstance(); + auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); + // When the query is empty(such as ";" or ";;", still valid), + // the parse tree is empty, parser will return nullptr. + if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) + throw ParserException("Error Parsing SQL statement"); + return sql_stmt_list; + } -std::shared_ptr PrepareStatement(ClientProcessState &state, - const std::string &statement_name, - const std::string &query_string, - std::unique_ptr &&sql_stmt_list); + std::shared_ptr PrepareStatement(ClientProcessState &state, + const std::string &statement_name, + const std::string &query_string, + std::unique_ptr &&sql_stmt_list); -ResultType ExecuteStatement( - ClientProcessState &state, - CallbackFunc callback); + ResultType ExecuteStatement( + ClientProcessState &state, + CallbackFunc callback); -void ExecuteHelper( - ClientProcessState &state, - std::vector &result, - concurrency::TransactionContext *txn, - CallbackFunc callback); + bool BindParamsForCachePlan( + ClientProcessState &state, + const std::vector> &); -bool BindParamsForCachePlan( - ClientProcessState &state, - const std::vector> &); + std::vector GenerateTupleDescriptor(ClientProcessState &state, + parser::SQLStatement *select_stmt); -std::vector GenerateTupleDescriptor( - parser::SQLStatement *select_stmt); + static FieldInfo 
GetColumnFieldForValueType(std::string column_name, + type::TypeId column_type); -FieldInfo GetColumnFieldForValueType(std::string column_name, - type::TypeId column_type); + // Get all data tables from a TableRef. + // For multi-way join + // TODO(Bowei) still a HACK + void GetTableColumns(ClientProcessState &state, + parser::TableRef *from_table, + std::vector &target_columns); -void ExecuteStatementPlanGetResult(); + void ExecuteStatementPlanGetResult(ClientProcessState &state); -ResultType ExecuteStatementGetResult(); + ResultType ExecuteStatementGetResult(ClientProcessState &state); -void ProcessInvalidStatement(ClientProcessState &state); + void ProcessInvalidStatement(ClientProcessState &state); -// Get all data tables from a TableRef. -// For multi-way join -// still a HACK -void GetTableColumns(parser::TableRef *from_table, - std::vector &target_tables); + private: + ResultType CommitQueryHelper(ClientProcessState &state); + ResultType BeginQueryHelper(ClientProcessState &state); + ResultType AbortQueryHelper(ClientProcessState &state); + executor::ExecutionResult ExecuteHelper(ClientProcessState &state, + CallbackFunc callback); + +}; } // namespace tcop } // namespace peloton \ No newline at end of file diff --git a/src/include/traffic_cop/traffic_cop.h b/src/include/traffic_cop/traffic_cop.h deleted file mode 100644 index e324b87fe82..00000000000 --- a/src/include/traffic_cop/traffic_cop.h +++ /dev/null @@ -1,202 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// traffic_cop.h -// -// Identification: src/include/traffic_cop/traffic_cop.h -// -// Copyright (c) 2015-17, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include -#include - -// Libevent 2.0 -#include "event.h" - -#include "catalog/column.h" -#include "common/internal_types.h" -#include "common/portal.h" -#include 
"common/statement.h" -#include "executor/plan_executor.h" -#include "optimizer/abstract_optimizer.h" -#include "parser/sql_statement.h" -#include "type/type.h" - -namespace peloton { - -namespace concurrency { -class TransactionContext; -} // namespace concurrency - -namespace tcop { - -//===--------------------------------------------------------------------===// -// TRAFFIC COP -// Helpers for executing statements. -// -// Usage in unit tests: -// auto &traffic_cop = tcop::TrafficCop::GetInstance(); -// traffic_cop.SetTaskCallback(, ); -// txn = txn_manager.BeginTransaction(); -// traffic_cop.SetTcopTxnState(txn); -// std::shared_ptr plan = ; -// traffic_cop.ExecuteHelper(plan, , , ); -// -// traffic_cop.CommitQueryHelper(); -//===--------------------------------------------------------------------===// - -class TrafficCop { - public: - TrafficCop(); - TrafficCop(void (*task_callback)(void *), void *task_callback_arg); - ~TrafficCop(); - DISALLOW_COPY_AND_MOVE(TrafficCop); - - // Static singleton used by unit tests. - static TrafficCop &GetInstance(); - - // Reset this object. - void Reset(); - - // Execute a statement - ResultType ExecuteStatement( - const std::shared_ptr &statement, - const std::vector ¶ms, const bool unnamed, - std::shared_ptr param_stats, - const std::vector &result_format, std::vector &result, - size_t thread_id = 0); - - // Helper to handle txn-specifics for the plan-tree of a statement. 
- executor::ExecutionResult ExecuteHelper( - std::shared_ptr plan, - const std::vector ¶ms, std::vector &result, - const std::vector &result_format, size_t thread_id = 0); - - // Prepare a statement using the parse tree - std::shared_ptr PrepareStatement( - const std::string &statement_name, const std::string &query_string, - std::unique_ptr sql_stmt_list, - size_t thread_id = 0); - - bool BindParamsForCachePlan( - const std::vector> &, - const size_t thread_id = 0); - - std::vector GenerateTupleDescriptor( - parser::SQLStatement *select_stmt); - - FieldInfo GetColumnFieldForValueType(std::string column_name, - type::TypeId column_type); - - void SetTcopTxnState(concurrency::TransactionContext *txn) { - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - - ResultType CommitQueryHelper(); - - void ExecuteStatementPlanGetResult(); - - ResultType ExecuteStatementGetResult(); - - void SetTaskCallback(void (*task_callback)(void *), void *task_callback_arg) { - task_callback_ = task_callback; - task_callback_arg_ = task_callback_arg; - } - - void setRowsAffected(int rows_affected) { rows_affected_ = rows_affected; } - - void ProcessInvalidStatement(); - - int getRowsAffected() { return rows_affected_; } - - void SetStatement(std::shared_ptr statement) { - statement_ = std::move(statement); - } - - std::shared_ptr GetStatement() { return statement_; } - - void SetResult(std::vector result) { - result_ = std::move(result); - } - - std::vector &GetResult() { return result_; } - - void SetParamVal(std::vector param_values) { - param_values_ = std::move(param_values); - } - - std::vector &GetParamVal() { return param_values_; } - - std::string &GetErrorMessage() { return error_message_; } - - void SetQueuing(bool is_queuing) { is_queuing_ = is_queuing; } - - bool GetQueuing() { return is_queuing_; } - - executor::ExecutionResult p_status_; - - void SetDefaultDatabaseName(std::string default_database_name) { - default_database_name_ = std::move(default_database_name); - } 
- - // TODO: this member variable should be in statement_ after parser part - // finished - std::string query_; - - private: - bool is_queuing_; - - std::string error_message_; - - std::vector param_values_; - - std::vector results_; - - // This save currnet statement in the traffic cop - std::shared_ptr statement_; - - // Default database name - std::string default_database_name_ = DEFAULT_DB_NAME; - - int rows_affected_; - - // The optimizer used for this connection - std::unique_ptr optimizer_; - - // flag of single statement txn - bool single_statement_txn_; - - std::vector result_; - - // The current callback to be invoked after execution completes. - void (*task_callback_)(void *); - void *task_callback_arg_; - - // pair of txn ptr and the result so-far for that txn - // use a stack to support nested-txns - using TcopTxnState = std::pair; - std::stack tcop_txn_state_; - - static TcopTxnState &GetDefaultTxnState(); - - TcopTxnState &GetCurrentTxnState(); - - ResultType BeginQueryHelper(size_t thread_id); - - ResultType AbortQueryHelper(); - - // Get all data tables from a TableRef. 
- // For multi-way join - // still a HACK - void GetTableColumns(parser::TableRef *from_table, - std::vector &target_tables); -}; - -} // namespace tcop -} // namespace peloton diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index 3f31ec4fea0..1b453ac702d 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -17,8 +17,6 @@ #include "network/connection_handle.h" #include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" -#include "network/postgres_protocol_handler.h" -#include "network/protocol_handler_factory.h" #include "common/utility.h" #include "settings/settings_manager.h" diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index f2c6094fa4a..80aff5406d5 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -15,6 +15,7 @@ #include "network/postgres_network_commands.h" #include "traffic_cop/tcop.h" #include "settings/settings_manager.h" +#include "planner/abstract_plan.h" #define SSL_MESSAGE_VERNO 80877103 #define PROTO_MAJOR_VERSION(x) ((x) >> 16) @@ -71,8 +72,7 @@ void PostgresNetworkCommand::ReadParamValues(std::vector &bind_pa param_types[i], param_len); break; - default: - throw NetworkProcessException("Unexpected format code"); + default:throw NetworkProcessException("Unexpected format code"); } } } @@ -218,9 +218,9 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, LOG_TRACE("Execute query: %s", query.c_str()); std::unique_ptr sql_stmt_list; try { - sql_stmt_list = tcop::ParseQuery(query); + sql_stmt_list = tcop::Tcop::GetInstance().ParseQuery(query); } catch (Exception &e) { - tcop::ProcessInvalidStatement(state); + tcop::Tcop::GetInstance().ProcessInvalidStatement(state); out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); @@ -232,7 +232,7 @@ 
Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - + // TODO(Yuchen): Hack. We only process the first statement in the packet now. // We should store the rest of statements that will not be processed right // away. For the hack, in most cases, it works. Because for example in psql, @@ -243,17 +243,16 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, QueryType query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); - interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; - + switch (query_type) { case QueryType::QUERY_PREPARE: { std::shared_ptr statement(nullptr); auto prep_stmt = dynamic_cast(sql_stmt.get()); std::string stmt_name = prep_stmt->name; - statement = tcop::PrepareStatement(state, - stmt_name, - query, - std::move(prep_stmt->query)); + statement = tcop::Tcop::GetInstance().PrepareStatement(state, + stmt_name, + query, + std::move(prep_stmt->query)); if (statement == nullptr) { out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); @@ -292,23 +291,24 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, std::vector(state.statement_->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); - if (!tcop::BindParamsForCachePlan(state, exec_stmt->parameters)) { - tcop::ProcessInvalidStatement(state); + if (!tcop::Tcop::GetInstance().BindParamsForCachePlan(state, + exec_stmt->parameters)) { + tcop::Tcop::GetInstance().ProcessInvalidStatement(state); out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); return Transition::PROCEED; } - auto status = tcop::ExecuteStatement(state, callback); + auto status = tcop::Tcop::GetInstance().ExecuteStatement(state, callback); if (state.is_queuing_) return Transition::NEED_RESULT; - 
interpreter.ExecQueryMessageGetResult(status); + interpreter.ExecQueryMessageGetResult(out, status); return Transition::PROCEED; }; case QueryType::QUERY_EXPLAIN: { auto status = interpreter.ExecQueryExplain(query, dynamic_cast(*sql_stmt)); - interpreter.ExecQueryMessageGetResult(status); + interpreter.ExecQueryMessageGetResult(out, status); return Transition::PROCEED; } default: { @@ -316,10 +316,11 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, std::unique_ptr unnamed_sql_stmt_list( new parser::SQLStatementList()); unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); - state.statement_ = tcop::PrepareStatement(state, - stmt_name, - query, - std::move(unnamed_sql_stmt_list)); + state.statement_ = tcop::Tcop::GetInstance().PrepareStatement(state, + stmt_name, + query, + std::move( + unnamed_sql_stmt_list)); if (state.statement_ == nullptr) { out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); @@ -331,10 +332,10 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, std::vector(state.statement_->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); auto status = - tcop::ExecuteStatement(state, callback); + tcop::Tcop::GetInstance().ExecuteStatement(state, callback); if (state.is_queuing_) return Transition::NEED_RESULT; - interpreter.ExecQueryMessageGetResult(status); + interpreter.ExecQueryMessageGetResult(out, status); return Transition::PROCEED; } } @@ -352,9 +353,9 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, std::unique_ptr sql_stmt_list; QueryType query_type = QueryType::QUERY_OTHER; try { - sql_stmt_list = tcop::ParseQuery(query); + sql_stmt_list = tcop::Tcop::GetInstance().ParseQuery(query); } catch (Exception &e) { - tcop::ProcessInvalidStatement(state); + tcop::Tcop::GetInstance().ProcessInvalidStatement(state); state.skipped_stmt_ = true; out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); 
@@ -380,12 +381,13 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } - auto statement = tcop::PrepareStatement(state, - statement_name, - query, - std::move(sql_stmt_list)); + auto statement = tcop::Tcop::GetInstance().PrepareStatement(state, + statement_name, + query, + std::move( + sql_stmt_list)); if (statement == nullptr) { - tcop::ProcessInvalidStatement(state); + tcop::Tcop::GetInstance().ProcessInvalidStatement(state); state.skipped_stmt_ = true; out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, state.error_message_}}); @@ -415,7 +417,7 @@ Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } - std::vector formats = ReadParamFormats(); + std::vector formats = ReadParamFormats(); // Get statement info generated in PARSE message std::shared_ptr @@ -455,7 +457,7 @@ Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, // Group the parameter types and the parameters in this vector std::vector> bind_parameters; std::vector param_values; - + auto param_types = statement->GetParamTypes(); ReadParamValues(bind_parameters, param_values, param_types, formats); state.result_format_ = @@ -483,12 +485,12 @@ Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, out.WriteSingleTypePacket(NetworkMessageType::NO_DATA_RESPONSE); return Transition::PROCEED; } - + auto mode = in_->ReadValue(); std::string portal_name = in_->ReadString(); switch (mode) { - case PostgresNetworkObjectType::PORTAL:LOG_TRACE("Describe a portal"); - + case PostgresNetworkObjectType::PORTAL: { + LOG_TRACE("Describe a portal"); auto portal_itr = interpreter.portals_.find(portal_name); // TODO: error handling here // Ahmed: This is causing the continuously running thread @@ -501,11 +503,13 @@ Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, } else 
out.WriteTupleDescriptor(portal_itr->second->GetStatement()->GetTupleDescriptor()); break; + } case PostgresNetworkObjectType::STATEMENT: // TODO(Tianyu): Do we not support this or something? LOG_TRACE("Describe a prepared statement"); break; - default:throw NetworkProcessException("Unexpected Describe type"); + default: + throw NetworkProcessException("Unexpected Describe type"); } return Transition::PROCEED; } @@ -514,7 +518,6 @@ Transition ExecuteCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc callback) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); - interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_JDBC; std::string portal_name = in_->ReadString(); // covers weird JDBC edge case of sending double BEGIN statements. Don't @@ -536,16 +539,15 @@ Transition ExecuteCommand::Exec(PostgresProtocolInterpreter &interpreter, std::shared_ptr portal = portal_itr->second; state.statement_ = portal->GetStatement(); - auto param_stat = portal->GetParamStat(); if (state.statement_ == nullptr) throw NetworkProcessException( "Did not find statement in portal: " + portal_name); state.param_values_ = portal->GetParameters(); - auto status = tcop::ExecuteStatement(state, callback); + auto status = tcop::Tcop::GetInstance().ExecuteStatement(state, callback); if (state.is_queuing_) return Transition::NEED_RESULT; - interpreter.ExecExecuteMessageGetResult(status); + interpreter.ExecExecuteMessageGetResult(out, status); return Transition::PROCEED; } @@ -569,13 +571,12 @@ Transition CloseCommand::Exec(PostgresProtocolInterpreter &interpreter, state.statement_cache_.DeleteStatement(name); break; } - case 'P': { + case PostgresNetworkObjectType::PORTAL: { LOG_TRACE("Deleting portal %s from cache", name.c_str()); auto portal_itr = interpreter.portals_.find(name); - if (portal_itr != interpreter.portals_.end()) { + if (portal_itr != interpreter.portals_.end()) // delete portal if it exists 
interpreter.portals_.erase(portal_itr); - } break; } default: @@ -584,11 +585,12 @@ Transition CloseCommand::Exec(PostgresProtocolInterpreter &interpreter, } // Send close complete response out.WriteSingleTypePacket(NetworkMessageType::CLOSE_COMPLETE); + return Transition::PROCEED; } -Transition TerminateCommand(PostgresProtocolInterpreter &, - PostgresPacketWriter &, - CallbackFunc) { +Transition TerminateCommand::Exec(PostgresProtocolInterpreter &, + PostgresPacketWriter &, + CallbackFunc) { return Transition::TERMINATE; } } // namespace network diff --git a/src/network/postgres_protocol_handler.cpp b/src/network/postgres_protocol_handler.cpp deleted file mode 100644 index ee85e6136af..00000000000 --- a/src/network/postgres_protocol_handler.cpp +++ /dev/null @@ -1,1272 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// postgres_protocol_handler.cpp -// -// Identification: src/network/postgres_protocol_handler.cpp -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include -#include -#include - -#include "common/cache.h" -#include "common/internal_types.h" -#include "common/macros.h" -#include "common/portal.h" -#include "expression/expression_util.h" -#include "network/marshal.h" -#include "network/peloton_server.h" -#include "network/postgres_protocol_handler.h" -#include "parser/postgresparser.h" -#include "parser/statements.h" -#include "planner/plan_util.h" -#include "settings/settings_manager.h" -#include "traffic_cop/traffic_cop.h" -#include "type/value.h" -#include "type/value_factory.h" -#include "util/string_util.h" - -#define SSL_MESSAGE_VERNO 80877103 -#define PROTO_MAJOR_VERSION(x) ((x) >> 16) - -namespace peloton { -namespace network { - -// TODO: Remove hardcoded auth strings -// Hardcoded authentication strings used during session startup. 
To be removed -const std::unordered_map - // clang-format off - PostgresProtocolHandler::parameter_status_map_ = - boost::assign::map_list_of("application_name", "psql") - ("client_encoding", "UTF8") - ("DateStyle", "ISO, MDY") - ("integer_datetimes", "on") - ("IntervalStyle", "postgres") - ("is_superuser", "on") - ("server_encoding", "UTF8") - ("server_version", "9.5devel") - ("session_authorization", "postgres") - ("standard_conforming_strings", "on") - ("TimeZone", "US/Eastern"); -// clang-format on - -PostgresProtocolHandler::PostgresProtocolHandler(tcop::TrafficCop *traffic_cop) - : ProtocolHandler(traffic_cop), - init_stage_(true), - txn_state_(NetworkTransactionStateType::IDLE) {} - -PostgresProtocolHandler::~PostgresProtocolHandler() {} - -void PostgresProtocolHandler::SendStartupResponse() { - std::unique_ptr response(new OutputPacket()); - - // send auth-ok ('R') - response->msg_type = NetworkMessageType::AUTHENTICATION_REQUEST; - PacketPutInt(response.get(), 0, 4); - responses_.push_back(std::move(response)); - - // Send the parameterStatus map ('S') - for (auto it = parameter_status_map_.begin(); - it != parameter_status_map_.end(); it++) { - MakeHardcodedParameterStatus(*it); - } - - // ready-for-query packet -> 'Z' - SendReadyForQuery(NetworkTransactionStateType::IDLE); - - // we need to send the response right away - SetFlushFlag(true); -} - -bool PostgresProtocolHandler::HardcodedExecuteFilter(QueryType query_type) { - switch (query_type) { - // Skip SET - case QueryType::QUERY_SET: - case QueryType::QUERY_SHOW: - return false; - // Skip duplicate BEGIN - case QueryType::QUERY_BEGIN: - if (txn_state_ == NetworkTransactionStateType::BLOCK) { - return false; - } - break; - // Skip duplicate Commits and Rollbacks - case QueryType::QUERY_COMMIT: - case QueryType::QUERY_ROLLBACK: - if (txn_state_ == NetworkTransactionStateType::IDLE) { - return false; - } - default: - break; - } - return true; -} - -// The Simple Query Protocol -ProcessResult 
PostgresProtocolHandler::ExecQueryMessage( - InputPacket *pkt, const size_t thread_id) { - std::string query; - std::string error_message; - PacketGetString(pkt, pkt->len, query); - LOG_TRACE("Execute query: %s", query.c_str()); - std::unique_ptr sql_stmt_list; - try { - auto &peloton_parser = parser::PostgresParser::GetInstance(); - sql_stmt_list = peloton_parser.BuildParseTree(query); - - // When the query is empty(such as ";" or ";;", still valid), - // the pare tree is empty, parser will return nullptr. - if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { - throw ParserException("Error Parsing SQL statement"); - } - } // If the statement is invalid or not supported yet - catch (Exception &e) { - traffic_cop_->ProcessInvalidStatement(); - error_message = e.what(); - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - - if (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0) { - SendEmptyQueryResponse(); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - - // TODO(Yuchen): Hack. We only process the first statement in the packet now. - // We should store the rest of statements that will not be processed right - // away. For the hack, in most cases, it works. Because for example in psql, - // one packet contains only one query. But when using the pipeline mode in - // Libpqxx, it sends multiple query in one packet. In this case, it's - // incorrect. 
- auto sql_stmt = sql_stmt_list->PassOutStatement(0); - - QueryType query_type = - StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt.get()); - protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; - - switch (query_type) { - case QueryType::QUERY_PREPARE: { - std::shared_ptr statement(nullptr); - auto prep_stmt = dynamic_cast(sql_stmt.get()); - std::string stmt_name = prep_stmt->name; - statement = traffic_cop_->PrepareStatement(stmt_name, query, - std::move(prep_stmt->query)); - if (statement.get() == nullptr) { - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - statement_cache_.AddStatement(statement); - - CompleteCommand(query_type, 0); - - // PAVLO: 2017-01-15 - // There used to be code here that would invoke this method passing - // in NetworkMessageType::READY_FOR_QUERY as the argument. But when - // I switched to strong types, this obviously doesn't work. So I - // switched it to be NetworkTransactionStateType::IDLE. I don't know - // we just don't always send back the internal txn state? 
- SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - }; - case QueryType::QUERY_EXECUTE: { - std::vector param_values; - parser::ExecuteStatement *exec_stmt = - static_cast(sql_stmt.get()); - std::string stmt_name = exec_stmt->name; - - auto cached_statement = statement_cache_.GetStatement(stmt_name); - if (cached_statement.get() != nullptr) { - traffic_cop_->SetStatement(cached_statement); - } - // Did not find statement with same name - else { - std::string error_message = "The prepared statement does not exist"; - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - std::vector result_format( - traffic_cop_->GetStatement()->GetTupleDescriptor().size(), 0); - result_format_ = result_format; - - if (!traffic_cop_->BindParamsForCachePlan(exec_stmt->parameters)) { - traffic_cop_->ProcessInvalidStatement(); - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - - bool unnamed = false; - auto status = traffic_cop_->ExecuteStatement( - traffic_cop_->GetStatement(), traffic_cop_->GetParamVal(), unnamed, - nullptr, result_format_, traffic_cop_->GetResult(), thread_id); - if (traffic_cop_->GetQueuing()) { - return ProcessResult::PROCESSING; - } - ExecQueryMessageGetResult(status); - return ProcessResult::COMPLETE; - }; - case QueryType::QUERY_EXPLAIN: { - auto status = ExecQueryExplain( - query, static_cast(*sql_stmt)); - ExecQueryMessageGetResult(status); - return ProcessResult::COMPLETE; - } - default: { - std::string stmt_name = "unamed"; - std::unique_ptr unnamed_sql_stmt_list( - new parser::SQLStatementList()); - unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); - traffic_cop_->SetStatement(traffic_cop_->PrepareStatement( - stmt_name, query, 
std::move(unnamed_sql_stmt_list))); - if (traffic_cop_->GetStatement().get() == nullptr) { - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return ProcessResult::COMPLETE; - } - traffic_cop_->SetParamVal(std::vector()); - bool unnamed = false; - result_format_ = std::vector( - traffic_cop_->GetStatement()->GetTupleDescriptor().size(), 0); - auto status = traffic_cop_->ExecuteStatement( - traffic_cop_->GetStatement(), traffic_cop_->GetParamVal(), unnamed, - nullptr, result_format_, traffic_cop_->GetResult(), thread_id); - if (traffic_cop_->GetQueuing()) { - return ProcessResult::PROCESSING; - } - ExecQueryMessageGetResult(status); - return ProcessResult::COMPLETE; - } - } -} - -ResultType PostgresProtocolHandler::ExecQueryExplain( - const std::string &query, parser::ExplainStatement &explain_stmt) { - std::unique_ptr unnamed_sql_stmt_list( - new parser::SQLStatementList()); - unnamed_sql_stmt_list->PassInStatement(std::move(explain_stmt.real_sql_stmt)); - auto stmt = traffic_cop_->PrepareStatement("explain", query, - std::move(unnamed_sql_stmt_list)); - ResultType status; - if (stmt != nullptr) { - traffic_cop_->SetStatement(stmt); - std::vector plan_info = StringUtil::Split( - planner::PlanUtil::GetInfo(stmt->GetPlanTree().get()), '\n'); - const std::vector tuple_descriptor = { - traffic_cop_->GetColumnFieldForValueType("Query plan", - type::TypeId::VARCHAR)}; - stmt->SetTupleDescriptor(tuple_descriptor); - traffic_cop_->SetResult(plan_info); - status = ResultType::SUCCESS; - } else { - status = ResultType::FAILURE; - } - return status; -} - -void PostgresProtocolHandler::ExecQueryMessageGetResult(ResultType status) { - std::vector tuple_descriptor; - if (status == ResultType::SUCCESS) { - tuple_descriptor = traffic_cop_->GetStatement()->GetTupleDescriptor(); - } else if (status == ResultType::FAILURE) { // check status - 
SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return; - } else if (status == ResultType::TO_ABORT) { - std::string error_message = - "current transaction is aborted, commands ignored until end of " - "transaction block"; - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return; - } - - // send the attribute names - PutTupleDescriptor(tuple_descriptor); - - // send the result rows - SendDataRows(traffic_cop_->GetResult(), tuple_descriptor.size()); - - CompleteCommand(traffic_cop_->GetStatement()->GetQueryType(), - traffic_cop_->getRowsAffected()); - - SendReadyForQuery(NetworkTransactionStateType::IDLE); -} - -/* - * exec_parse_message - handle PARSE message - */ -void PostgresProtocolHandler::ExecParseMessage(InputPacket *pkt) { - std::string statement_name, query, query_type_string; - GetStringToken(pkt, statement_name); - GetStringToken(pkt, query); - - // In JDBC, one query starts with parsing stage. - // Reset skipped_stmt_ to false for the new query. - skipped_stmt_ = false; - std::unique_ptr sql_stmt_list; - QueryType query_type = QueryType::QUERY_OTHER; - try { - LOG_TRACE("%s, %s", statement_name.c_str(), query.c_str()); - auto &peloton_parser = parser::PostgresParser::GetInstance(); - sql_stmt_list = peloton_parser.BuildParseTree(query); - if (sql_stmt_list.get() != nullptr && !sql_stmt_list->is_valid) { - throw ParserException("Error parsing SQL statement"); - } - } catch (Exception &e) { - traffic_cop_->ProcessInvalidStatement(); - skipped_stmt_ = true; - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, e.what()}}); - return; - } - - // If the query is not supported yet, - // we will skip the rest commands (B,E,..) 
for this query - // For empty query, we still want to get it constructed - // TODO (Tianyi) Consider handle more statement - bool empty = (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0); - if (!empty) { - parser::SQLStatement *sql_stmt = sql_stmt_list->GetStatement(0); - query_type = StatementTypeToQueryType(sql_stmt->GetType(), sql_stmt); - } - bool skip = !HardcodedExecuteFilter(query_type); - if (skip) { - skipped_stmt_ = true; - skipped_query_string_ = query; - skipped_query_type_ = query_type; - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::PARSE_COMPLETE; - responses_.push_back(std::move(response)); - return; - } - - // Prepare statement - std::shared_ptr statement(nullptr); - - statement = traffic_cop_->PrepareStatement(statement_name, query, - std::move(sql_stmt_list)); - if (statement.get() == nullptr) { - traffic_cop_->ProcessInvalidStatement(); - skipped_stmt_ = true; - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - return; - } - LOG_TRACE("PrepareStatement[%s] => %s", statement_name.c_str(), - query.c_str()); - // Read number of params - int num_params = PacketGetInt(pkt, 2); - - // Read param types - std::vector param_types(num_params); - auto type_buf_begin = pkt->Begin() + pkt->ptr; - auto type_buf_len = ReadParamType(pkt, num_params, param_types); - - // Cache the received query - bool unnamed_query = statement_name.empty(); - statement->SetParamTypes(param_types); - - // Stat - if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { - // Make a copy of param types for stat collection - stats::QueryMetric::QueryParamBuf query_type_buf; - query_type_buf.len = type_buf_len; - query_type_buf.buf = PacketCopyBytes(type_buf_begin, type_buf_len); - - // Unnamed statement - if (unnamed_query) { - unnamed_stmt_param_types_ = query_type_buf; - } else { - 
statement_param_types_[statement_name] = query_type_buf; - } - } - - // Cache the statement - statement_cache_.AddStatement(statement); - - // Send Parse complete response - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::PARSE_COMPLETE; - responses_.push_back(std::move(response)); -} - -void PostgresProtocolHandler::ExecBindMessage(InputPacket *pkt) { - std::string portal_name, statement_name; - // BIND message - GetStringToken(pkt, portal_name); - GetStringToken(pkt, statement_name); - - if (skipped_stmt_) { - // send bind complete - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::BIND_COMPLETE; - responses_.push_back(std::move(response)); - return; - } - - // Read parameter format - int num_params_format = PacketGetInt(pkt, 2); - std::vector formats(num_params_format); - - auto format_buf_begin = pkt->Begin() + pkt->ptr; - auto format_buf_len = ReadParamFormat(pkt, num_params_format, formats); - - int num_params = PacketGetInt(pkt, 2); - // error handling - if (num_params_format != num_params) { - std::string error_message = - "Malformed request: num_params_format is not equal to num_params"; - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - return; - } - - // Get statement info generated in PARSE message - std::shared_ptr statement; - stats::QueryMetric::QueryParamBuf param_type_buf; - - statement = statement_cache_.GetStatement(statement_name); - - if (statement.get() == nullptr) { - std::string error_message = statement_name.empty() - ? 
"Invalid unnamed statement" - : "The prepared statement does not exist"; - LOG_ERROR("%s", error_message.c_str()); - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - return; - } - - // Empty query - if (statement->GetQueryType() == QueryType::QUERY_INVALID) { - std::unique_ptr response(new OutputPacket()); - // Send Bind complete response - response->msg_type = NetworkMessageType::BIND_COMPLETE; - responses_.push_back(std::move(response)); - // TODO(Tianyi) This is a hack to respond correct describe message - // as well as execute message - skipped_stmt_ = true; - skipped_query_string_ = ""; - return; - } - - // UNNAMED STATEMENT - if (statement_name.empty()) { - param_type_buf = unnamed_stmt_param_types_; - // NAMED STATEMENT - } else { - param_type_buf = statement_param_types_[statement_name]; - } - - const auto &query_string = statement->GetQueryString(); - const auto &query_type = statement->GetQueryType(); - - // check if the loaded statement needs to be skipped - skipped_stmt_ = false; - if (HardcodedExecuteFilter(query_type) == false) { - skipped_stmt_ = true; - skipped_query_string_ = query_string; - std::unique_ptr response(new OutputPacket()); - // Send Bind complete response - response->msg_type = NetworkMessageType::BIND_COMPLETE; - responses_.push_back(std::move(response)); - return; - } - - // Group the parameter types and the parameters in this vector - std::vector> bind_parameters(num_params); - std::vector param_values(num_params); - - auto param_types = statement->GetParamTypes(); - - auto val_buf_begin = pkt->Begin() + pkt->ptr; - auto val_buf_len = ReadParamValue(pkt, num_params, param_types, - bind_parameters, param_values, formats); - - int format_codes_number = PacketGetInt(pkt, 2); - LOG_TRACE("format_codes_number: %d", format_codes_number); - // Set the result-column format code - if (format_codes_number == 0) { - // using the default text format - result_format_ = - 
std::vector(statement->GetTupleDescriptor().size(), 0); - } else if (format_codes_number == 1) { - // get the format code from packet - auto result_format = PacketGetInt(pkt, 2); - result_format_ = - std::vector(statement->GetTupleDescriptor().size(), result_format); - } else { - // get the format code for each column - result_format_.clear(); - for (int format_code_idx = 0; format_code_idx < format_codes_number; - ++format_code_idx) { - result_format_.push_back(PacketGetInt(pkt, 2)); - LOG_TRACE("format code: %d", *result_format_.rbegin()); - } - } - - if (param_values.size() > 0) { - statement->GetPlanTree()->SetParameterValues(¶m_values); - // Instead of tree traversal, we should put param values in the - // executor context. - } - - std::shared_ptr param_stat(nullptr); - if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID && - num_params > 0) { - // Make a copy of format for stat collection - stats::QueryMetric::QueryParamBuf param_format_buf; - param_format_buf.len = format_buf_len; - param_format_buf.buf = PacketCopyBytes(format_buf_begin, format_buf_len); - PELOTON_ASSERT(format_buf_len > 0); - - // Make a copy of value for stat collection - stats::QueryMetric::QueryParamBuf param_val_buf; - param_val_buf.len = val_buf_len; - param_val_buf.buf = PacketCopyBytes(val_buf_begin, val_buf_len); - PELOTON_ASSERT(val_buf_len > 0); - - param_stat.reset(new stats::QueryMetric::QueryParams( - param_format_buf, param_type_buf, param_val_buf, num_params)); - } - - // Construct a portal. - // Notice that this will move param_values so no value will be left there. 
- auto portal = - new Portal(portal_name, statement, std::move(param_values), param_stat); - std::shared_ptr portal_reference(portal); - - auto itr = portals_.find(portal_name); - // Found portal name in portal map - if (itr != portals_.end()) { - itr->second = portal_reference; - } - // Create a new entry in portal map - else { - portals_.insert(std::make_pair(portal_name, portal_reference)); - } - // send bind complete - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::BIND_COMPLETE; - responses_.push_back(std::move(response)); -} - -size_t PostgresProtocolHandler::ReadParamType( - InputPacket *pkt, int num_params, std::vector ¶m_types) { - auto begin = pkt->ptr; - // get the type of each parameter - for (int i = 0; i < num_params; i++) { - int param_type = PacketGetInt(pkt, 4); - param_types[i] = param_type; - } - auto end = pkt->ptr; - return end - begin; -} - -size_t PostgresProtocolHandler::ReadParamFormat(InputPacket *pkt, - int num_params_format, - std::vector &formats) { - auto begin = pkt->ptr; - // get the format of each parameter - for (int i = 0; i < num_params_format; i++) { - formats[i] = PacketGetInt(pkt, 2); - } - auto end = pkt->ptr; - return end - begin; -} - -// For consistency, this function assumes the input vectors has the correct size -size_t PostgresProtocolHandler::ReadParamValue( - InputPacket *pkt, int num_params, std::vector ¶m_types, - std::vector> &bind_parameters, - std::vector ¶m_values, std::vector &formats) { - auto begin = pkt->ptr; - ByteBuf param; - for (int param_idx = 0; param_idx < num_params; param_idx++) { - int param_len = PacketGetInt(pkt, 4); - // BIND packet NULL parameter case - if (param_len == -1) { - // NULL mode - auto peloton_type = PostgresValueTypeToPelotonValueType( - static_cast(param_types[param_idx])); - bind_parameters[param_idx] = - std::make_pair(peloton_type, std::string("")); - param_values[param_idx] = - type::ValueFactory::GetNullValueByType(peloton_type); - } 
else { - PacketGetBytes(pkt, param_len, param); - - if (formats[param_idx] == 0) { - // TEXT mode - std::string param_str = std::string(std::begin(param), std::end(param)); - bind_parameters[param_idx] = - std::make_pair(type::TypeId::VARCHAR, param_str); - if ((unsigned int)param_idx >= param_types.size() || - PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx]) == - type::TypeId::VARCHAR) { - param_values[param_idx] = - type::ValueFactory::GetVarcharValue(param_str); - } else { - param_values[param_idx] = - (type::ValueFactory::GetVarcharValue(param_str)) - .CastAs(PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx])); - } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != - type::TypeId::INVALID); - } else { - // BINARY mode - PostgresValueType pg_value_type = - static_cast(param_types[param_idx]); - LOG_TRACE("Postgres Protocol Conversion [param_idx=%d]", param_idx); - switch (pg_value_type) { - case PostgresValueType::TINYINT: { - int8_t int_val = 0; - for (size_t i = 0; i < sizeof(int8_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::TINYINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetTinyIntValue(int_val).Copy(); - break; - } - case PostgresValueType::SMALLINT: { - int16_t int_val = 0; - for (size_t i = 0; i < sizeof(int16_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetSmallIntValue(int_val).Copy(); - break; - } - case PostgresValueType::INTEGER: { - int32_t int_val = 0; - for (size_t i = 0; i < sizeof(int32_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::INTEGER, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetIntegerValue(int_val).Copy(); - 
break; - } - case PostgresValueType::BIGINT: { - int64_t int_val = 0; - for (size_t i = 0; i < sizeof(int64_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::BIGINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetBigIntValue(int_val).Copy(); - break; - } - case PostgresValueType::DOUBLE: { - double float_val = 0; - unsigned long buf = 0; - for (size_t i = 0; i < sizeof(double); ++i) { - buf = (buf << 8) | param[i]; - } - PELOTON_MEMCPY(&float_val, &buf, sizeof(double)); - bind_parameters[param_idx] = std::make_pair( - type::TypeId::DECIMAL, std::to_string(float_val)); - param_values[param_idx] = - type::ValueFactory::GetDecimalValue(float_val).Copy(); - break; - } - case PostgresValueType::VARBINARY: { - bind_parameters[param_idx] = std::make_pair( - type::TypeId::VARBINARY, - std::string(reinterpret_cast(¶m[0]), param_len)); - param_values[param_idx] = type::ValueFactory::GetVarbinaryValue( - ¶m[0], param_len, true); - break; - } - default: { - LOG_ERROR( - "Binary Postgres protocol does not support data type '%s' [%d]", - PostgresValueTypeToString(pg_value_type).c_str(), - param_types[param_idx]); - break; - } - } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != - type::TypeId::INVALID); - } - } - } - auto end = pkt->ptr; - return end - begin; -} - -ProcessResult PostgresProtocolHandler::ExecDescribeMessage(InputPacket *pkt) { - if (skipped_stmt_) { - // send 'no-data' message - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::NO_DATA_RESPONSE; - responses_.push_back(std::move(response)); - return ProcessResult::COMPLETE; - } - - ByteBuf mode; - std::string portal_name; - PacketGetBytes(pkt, 1, mode); - GetStringToken(pkt, portal_name); - if (mode[0] == 'P') { - LOG_TRACE("Describe a portal"); - auto portal_itr = portals_.find(portal_name); - - // TODO: error handling here - // Ahmed: This is causing the continuously 
running thread - // Changed the function signature to return boolean - // when false is returned, the connection is closed - if (portal_itr == portals_.end()) { - LOG_ERROR("Did not find portal : %s", portal_name.c_str()); - std::vector tuple_descriptor; - PutTupleDescriptor(tuple_descriptor); - return ProcessResult::COMPLETE; - } - - auto portal = portal_itr->second; - if (portal == nullptr) { - LOG_ERROR("Portal does not exist : %s", portal_name.c_str()); - std::vector tuple_descriptor; - PutTupleDescriptor(tuple_descriptor); - return ProcessResult::TERMINATE; - } - - auto statement = portal->GetStatement(); - PutTupleDescriptor(statement->GetTupleDescriptor()); - } else { - LOG_TRACE("Describe a prepared statement"); - } - return ProcessResult::COMPLETE; -} - -ProcessResult PostgresProtocolHandler::ExecExecuteMessage( - InputPacket *pkt, const size_t thread_id) { - // EXECUTE message - protocol_type_ = NetworkProtocolType::POSTGRES_JDBC; - std::string error_message, portal_name; - GetStringToken(pkt, portal_name); - - // covers weird JDBC edge case of sending double BEGIN statements. 
Don't - // execute them - if (skipped_stmt_) { - if (skipped_query_string_ == "") { - SendEmptyQueryResponse(); - } else { - CompleteCommand(skipped_query_type_, traffic_cop_->getRowsAffected()); - } - skipped_stmt_ = false; - return ProcessResult::COMPLETE; - } - - auto portal = portals_[portal_name]; - if (portal.get() == nullptr) { - LOG_ERROR("Did not find portal : %s", portal_name.c_str()); - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - SendReadyForQuery(txn_state_); - return ProcessResult::TERMINATE; - } - - traffic_cop_->SetStatement(portal->GetStatement()); - - auto param_stat = portal->GetParamStat(); - if (traffic_cop_->GetStatement().get() == nullptr) { - LOG_ERROR("Did not find statement in portal : %s", portal_name.c_str()); - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - SendReadyForQuery(txn_state_); - return ProcessResult::TERMINATE; - } - - auto statement_name = traffic_cop_->GetStatement()->GetStatementName(); - bool unnamed = statement_name.empty(); - traffic_cop_->SetParamVal(portal->GetParameters()); - - auto status = traffic_cop_->ExecuteStatement( - traffic_cop_->GetStatement(), traffic_cop_->GetParamVal(), unnamed, - param_stat, result_format_, traffic_cop_->GetResult(), thread_id); - if (traffic_cop_->GetQueuing()) { - return ProcessResult::PROCESSING; - } - ExecExecuteMessageGetResult(status); - return ProcessResult::COMPLETE; -} - -void PostgresProtocolHandler::ExecExecuteMessageGetResult(ResultType status) { - const auto &query_type = traffic_cop_->GetStatement()->GetQueryType(); - switch (status) { - case ResultType::FAILURE: - LOG_ERROR("Failed to execute: %s", - traffic_cop_->GetErrorMessage().c_str()); - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - traffic_cop_->GetErrorMessage()}}); - return; - - case ResultType::ABORTED: { - // It's not an ABORT query but Peloton aborts the transaction - if (query_type != 
QueryType::QUERY_ROLLBACK) { - LOG_DEBUG("Failed to execute: Conflicting txn aborted"); - // Send an error response if the abort is not due to ROLLBACK query - SendErrorResponse({{NetworkMessageType::SQLSTATE_CODE_ERROR, - SqlStateErrorCodeToString( - SqlStateErrorCode::SERIALIZATION_ERROR)}}); - } - return; - } - case ResultType::TO_ABORT: { - // User keeps issuing queries in a transaction that should be aborted - std::string error_message = - "current transaction is aborted, commands ignored until end of " - "transaction block"; - SendErrorResponse( - {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); - SendReadyForQuery(NetworkTransactionStateType::IDLE); - return; - } - default: { - auto tuple_descriptor = - traffic_cop_->GetStatement()->GetTupleDescriptor(); - SendDataRows(traffic_cop_->GetResult(), tuple_descriptor.size()); - CompleteCommand(query_type, traffic_cop_->getRowsAffected()); - return; - } - } -} - -void PostgresProtocolHandler::GetResult() { - traffic_cop_->ExecuteStatementPlanGetResult(); - auto status = traffic_cop_->ExecuteStatementGetResult(); - switch (protocol_type_) { - case NetworkProtocolType::POSTGRES_JDBC: - LOG_TRACE("JDBC result"); - ExecExecuteMessageGetResult(status); - break; - case NetworkProtocolType::POSTGRES_PSQL: - LOG_TRACE("PSQL result"); - ExecQueryMessageGetResult(status); - } -} - -void PostgresProtocolHandler::ExecCloseMessage(InputPacket *pkt) { - uchar close_type = 0; - std::string name; - PacketGetByte(pkt, close_type); - PacketGetString(pkt, 0, name); - switch (close_type) { - case 'S': { - LOG_TRACE("Deleting statement %s from cache", name.c_str()); - statement_cache_.DeleteStatement(name); - break; - } - case 'P': { - LOG_TRACE("Deleting portal %s from cache", name.c_str()); - auto portal_itr = portals_.find(name); - if (portal_itr != portals_.end()) { - // delete portal if it exists - portals_.erase(portal_itr); - } - break; - } - default: - // do nothing, simply send close complete - break; - } - // 
Send close complete response - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::CLOSE_COMPLETE; - responses_.push_back(std::move(response)); -} - -bool PostgresProtocolHandler::ParseInputPacket(ReadBuffer &rbuf, - InputPacket &rpkt, - bool startup_format) { - if (!rpkt.header_parsed && !ReadPacketHeader(rbuf, rpkt, startup_format)) - return false; - - if (rpkt.is_initialized == false) { - // packet needs to be initialized with rest of the contents - if (PostgresProtocolHandler::ReadPacket(rbuf, rpkt) == false) { - // need more data - return false; - } - } - return true; -} - -// The function tries to do a preliminary read to fetch the size value and -// then reads the rest of the packet. -// Assume: Packet length field is always 32-bit int -bool PostgresProtocolHandler::ReadPacketHeader(ReadBuffer &rbuf, - InputPacket &rpkt, - bool startup) { - // All packets other than the startup packet have a 5 bytes header - size_t header_size = startup ? sizeof(int32_t) : sizeof(int32_t) + 1; - // check if header bytes are available - if (!rbuf.HasMore(header_size)) return false; -<<<<<<< HEAD - if (!startup) rpkt.msg_type = rbuf.ReadRawValue(); -======= - if (!startup) rpkt.msg_type = rbuf.ReadValue(); ->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 - - // get packet size from the header - // extract packet contents size - // content lengths should exclude the length bytes -<<<<<<< HEAD - rpkt.len = rbuf.ReadValue() - sizeof(uint32_t); -======= - rpkt.len = ntohl(rbuf.ReadValue()) - sizeof(uint32_t); ->>>>>>> a045cfc95bf349742a8101aee65e22efd9ec8096 - - // do we need to use the extended buffer for this packet? 
- rpkt.is_extended = (rpkt.len > rbuf.Capacity()); - - if (rpkt.is_extended) { - LOG_TRACE("Using extended buffer for pkt size:%ld", rpkt.len); - // reserve space for the extended buffer - rpkt.ReserveExtendedBuffer(); - } - // we have processed the data, move buffer pointer - rpkt.header_parsed = true; - return true; -} - -// Tries to read the contents of a single packet, returns true on success, false -// on failure. -bool PostgresProtocolHandler::ReadPacket(ReadBuffer &rbuf, InputPacket &rpkt) { - if (rpkt.is_extended) { - // extended packet mode - auto bytes_available = rbuf.BytesAvailable(); - auto bytes_required = rpkt.ExtendedBytesRequired(); - // read minimum of the two ranges - auto read_size = std::min(bytes_available, bytes_required); - rpkt.AppendToExtendedBuffer(rbuf.Begin() + rbuf.offset_, - rbuf.Begin() + rbuf.offset_ + read_size); - // data has been copied, move ptr - rbuf.offset_ += read_size; - if (bytes_required > bytes_available) { - // more data needs to be read - return false; - } - // all the data has been read - rpkt.InitializePacket(); - return true; - } else { - if (rbuf.HasMore(rpkt.len) == false) { - // data not available yet, return - return false; - } - // Initialize the packet's "contents" - rpkt.InitializePacket(rbuf.offset_, rbuf.Begin()); - // We have processed the data, move buffer pointer - rbuf.offset_ += rpkt.len; - } - - return true; -} - -/* - * process_startup_packet - Processes the startup packet - * (after the size field of the header). 
- */ -ProcessResult PostgresProtocolHandler::ProcessInitialPacket(InputPacket *pkt) { - int32_t proto_version = PacketGetInt(pkt, sizeof(int32_t)); - LOG_INFO("protocol version: %d", proto_version); - - force_flush_ = true; - // TODO(Yuchen): consider more about return value - if (proto_version == SSL_MESSAGE_VERNO) { - LOG_TRACE("process SSL MESSAGE"); - std::unique_ptr response(new OutputPacket()); - bool ssl_able = (PelotonServer::GetSSLLevel() != SSLLevel::SSL_DISABLE); - response->msg_type = - ssl_able ? NetworkMessageType::SSL_YES : NetworkMessageType::SSL_NO; - response->single_type_pkt = true; - responses_.push_back(std::move(response)); - return ssl_able ? ProcessResult::NEED_SSL_HANDSHAKE - : ProcessResult::COMPLETE; - } else { - LOG_TRACE("process startup packet"); - return ProcessStartupPacket(pkt, proto_version); - } -} - -ProcessResult PostgresProtocolHandler::ProcessStartupPacket( - InputPacket *pkt, int32_t proto_version) { - std::string token, value; - - // Only protocol version 3 is supported - if (PROTO_MAJOR_VERSION(proto_version) != 3) { - LOG_ERROR("Protocol error: Only protocol version 3 is supported."); - SendErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - "Protocol Version Not Support"}}); - return ProcessResult::TERMINATE; - } - - // TODO(Yuchen): check for more malformed cases - while (pkt->ptr < pkt->len) { - GetStringToken(pkt, token); - LOG_TRACE("Option key is %s", token.c_str()); - if (pkt->ptr >= pkt->len) break; - GetStringToken(pkt, value); - LOG_TRACE("Option value is %s", token.c_str()); - cmdline_options_[token] = value; - if (token.compare("database") == 0) { - traffic_cop_->SetDefaultDatabaseName(value); - } - } - - // Send AuthRequestOK to client - // TODO(Yuchen): Peloton does not do any kind of trust authentication now. - // For example, no password authentication. 
- SendStartupResponse(); - - init_stage_ = false; - force_flush_ = true; - return ProcessResult::COMPLETE; -} - -ProcessResult PostgresProtocolHandler::Process(ReadBuffer &rbuf, - const size_t thread_id) { - if (!ParseInputPacket(rbuf, request_, init_stage_)) - return ProcessResult::MORE_DATA_REQUIRED; - - ProcessResult process_status = - init_stage_ ? ProcessInitialPacket(&request_) - : ProcessNormalPacket(&request_, thread_id); - - request_.Reset(); - - return process_status; -} - -ProcessResult PostgresProtocolHandler::ProcessNormalPacket( - InputPacket *pkt, const size_t thread_id) { - LOG_TRACE("Message type: %c", static_cast(pkt->msg_type)); - // We don't set force_flush to true for `PBDE` messages because they're - // part of the extended protocol. Buffer responses and don't flush until - // we see a SYNC - switch (pkt->msg_type) { - case NetworkMessageType::SIMPLE_QUERY_COMMAND: { - LOG_TRACE("SIMPLE_QUERY_COMMAND"); - SetFlushFlag(true); - return ExecQueryMessage(pkt, thread_id); - } - case NetworkMessageType::PARSE_COMMAND: { - LOG_TRACE("PARSE_COMMAND"); - ExecParseMessage(pkt); - } break; - case NetworkMessageType::BIND_COMMAND: { - LOG_TRACE("BIND_COMMAND"); - ExecBindMessage(pkt); - } break; - case NetworkMessageType::DESCRIBE_COMMAND: { - LOG_TRACE("DESCRIBE_COMMAND"); - return ExecDescribeMessage(pkt); - } - case NetworkMessageType::EXECUTE_COMMAND: { - LOG_TRACE("EXECUTE_COMMAND"); - return ExecExecuteMessage(pkt, thread_id); - } - case NetworkMessageType::SYNC_COMMAND: { - LOG_TRACE("SYNC_COMMAND"); - SendReadyForQuery(txn_state_); - SetFlushFlag(true); - } break; - case NetworkMessageType::CLOSE_COMMAND: { - LOG_TRACE("CLOSE_COMMAND"); - ExecCloseMessage(pkt); - } break; - case NetworkMessageType::TERMINATE_COMMAND: { - LOG_TRACE("TERMINATE_COMMAND"); - SetFlushFlag(true); - return ProcessResult::TERMINATE; - } - case NetworkMessageType::NULL_COMMAND: { - LOG_TRACE("NULL"); - SetFlushFlag(true); - return ProcessResult::TERMINATE; - } - default: { 
- LOG_ERROR("Packet type not supported yet: %d (%c)", - static_cast(pkt->msg_type), - static_cast(pkt->msg_type)); - } - } - return ProcessResult::COMPLETE; -} -void PostgresProtocolHandler::MakeHardcodedParameterStatus( - const std::pair &kv) { - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::PARAMETER_STATUS; - PacketPutStringWithTerminator(response.get(), kv.first); - PacketPutStringWithTerminator(response.get(), kv.second); - responses_.push_back(std::move(response)); -} - -void PostgresProtocolHandler::PutTupleDescriptor( - const std::vector &tuple_descriptor) { - if (tuple_descriptor.empty()) return; - - std::unique_ptr pkt(new OutputPacket()); - pkt->msg_type = NetworkMessageType::ROW_DESCRIPTION; - PacketPutInt(pkt.get(), tuple_descriptor.size(), 2); - - for (auto col : tuple_descriptor) { - PacketPutStringWithTerminator(pkt.get(), std::get<0>(col)); - // TODO: Table Oid (int32) - PacketPutInt(pkt.get(), 0, 4); - // TODO: Attr id of column (int16) - PacketPutInt(pkt.get(), 0, 2); - // Field data type (int32) - PacketPutInt(pkt.get(), std::get<1>(col), 4); - // Data type size (int16) - PacketPutInt(pkt.get(), std::get<2>(col), 2); - // Type modifier (int32) - PacketPutInt(pkt.get(), -1, 4); - // Format code for text - PacketPutInt(pkt.get(), 0, 2); - } - responses_.push_back(std::move(pkt)); -} - -void PostgresProtocolHandler::SendDataRows(std::vector &results, - int colcount) { - if (results.empty() || colcount == 0) return; - - size_t numrows = results.size() / colcount; - - // 1 packet per row - for (size_t i = 0; i < numrows; i++) { - std::unique_ptr pkt(new OutputPacket()); - pkt->msg_type = NetworkMessageType::DATA_ROW; - PacketPutInt(pkt.get(), colcount, 2); - for (int j = 0; j < colcount; j++) { - auto content = results[i * colcount + j]; - if (content.size() == 0) { - // content is NULL - PacketPutInt(pkt.get(), NULL_CONTENT_SIZE, 4); - // no value bytes follow - } else { - // length of the row attribute - 
PacketPutInt(pkt.get(), content.size(), 4); - // contents of the row attribute - PacketPutString(pkt.get(), content); - } - } - responses_.push_back(std::move(pkt)); - } - traffic_cop_->setRowsAffected(numrows); -} - -void PostgresProtocolHandler::CompleteCommand(const QueryType &query_type, - int rows) { - std::unique_ptr pkt(new OutputPacket()); - pkt->msg_type = NetworkMessageType::COMMAND_COMPLETE; - std::string tag = QueryTypeToString(query_type); - switch (query_type) { - /* After Begin, we enter a txn block */ - case QueryType::QUERY_BEGIN: - txn_state_ = NetworkTransactionStateType::BLOCK; - break; - /* After commit, we end the txn block */ - case QueryType::QUERY_COMMIT: - /* After rollback, the txn block is ended */ - case QueryType::QUERY_ROLLBACK: - txn_state_ = NetworkTransactionStateType::IDLE; - break; - case QueryType::QUERY_INSERT: - tag += " 0 " + std::to_string(rows); - break; - case QueryType::QUERY_CREATE_TABLE: - case QueryType::QUERY_CREATE_DB: - case QueryType::QUERY_CREATE_INDEX: - case QueryType::QUERY_CREATE_TRIGGER: - case QueryType::QUERY_PREPARE: - break; - default: - tag += " " + std::to_string(rows); - } - PacketPutStringWithTerminator(pkt.get(), tag); - responses_.push_back(std::move(pkt)); -} - -/* - * put_empty_query_response - Informs the client that an empty query was sent - */ -void PostgresProtocolHandler::SendEmptyQueryResponse() { - std::unique_ptr response(new OutputPacket()); - response->msg_type = NetworkMessageType::EMPTY_QUERY_RESPONSE; - responses_.push_back(std::move(response)); -} - -/* - * send_error_response - Sends the passed string as an error response. 
- * For now, it only supports the human readable 'M' message body - */ -void PostgresProtocolHandler::SendErrorResponse( - std::vector> error_status) { - std::unique_ptr pkt(new OutputPacket()); - pkt->msg_type = NetworkMessageType::ERROR_RESPONSE; - - for (auto entry : error_status) { - PacketPutByte(pkt.get(), static_cast(entry.first)); - PacketPutStringWithTerminator(pkt.get(), entry.second); - } - - // put null terminator - PacketPutByte(pkt.get(), 0); - - // don't care if write finished or not, we are closing anyway - responses_.push_back(std::move(pkt)); -} - -void PostgresProtocolHandler::SendReadyForQuery( - NetworkTransactionStateType txn_status) { - std::unique_ptr pkt(new OutputPacket()); - pkt->msg_type = NetworkMessageType::READY_FOR_QUERY; - - PacketPutByte(pkt.get(), static_cast(txn_status)); - - responses_.push_back(std::move(pkt)); -} - -void PostgresProtocolHandler::Reset() { - ProtocolHandler::Reset(); - statement_cache_.Clear(); - result_format_.clear(); - traffic_cop_->Reset(); - txn_state_ = NetworkTransactionStateType::IDLE; - skipped_stmt_ = false; - skipped_query_string_.clear(); - portals_.clear(); -} - -} // namespace network -} // namespace peloton diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index 442f16ddfdf..96938f628fc 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#pragma once +#include "planner/plan_util.h" #include "network/postgres_protocol_interpreter.h" #define MAKE_COMMAND(type) \ @@ -101,5 +101,158 @@ std::shared_ptr PostgresProtocolInterpreter::PacketToCom } } +void PostgresProtocolInterpreter::CompleteCommand(PostgresPacketWriter &out, + const QueryType &query_type, + int rows) { + + std::string tag = QueryTypeToString(query_type); + switch (query_type) { + /* After Begin, we enter a txn block */ + 
case QueryType::QUERY_BEGIN: + state_.txn_state_ = NetworkTransactionStateType::BLOCK; + break; + /* After commit, we end the txn block */ + case QueryType::QUERY_COMMIT: + /* After rollback, the txn block is ended */ + case QueryType::QUERY_ROLLBACK: + state_.txn_state_ = NetworkTransactionStateType::IDLE; + break; + case QueryType::QUERY_INSERT: + tag += " 0 " + std::to_string(rows); + break; + case QueryType::QUERY_CREATE_TABLE: + case QueryType::QUERY_CREATE_DB: + case QueryType::QUERY_CREATE_INDEX: + case QueryType::QUERY_CREATE_TRIGGER: + case QueryType::QUERY_PREPARE: + break; + default: + tag += " " + std::to_string(rows); + } + out.BeginPacket(NetworkMessageType::COMMAND_COMPLETE) + .AppendString(tag); +} + +void PostgresProtocolInterpreter::ExecQueryMessageGetResult(PostgresPacketWriter &out, + ResultType status) { + std::vector tuple_descriptor; + if (status == ResultType::SUCCESS) { + tuple_descriptor = state_.statement_->GetTupleDescriptor(); + } else if (status == ResultType::FAILURE) { // check status + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state_.error_message_}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return; + } else if (status == ResultType::TO_ABORT) { + std::string error_message = + "current transaction is aborted, commands ignored until end of " + "transaction block"; + out.WriteErrorResponse( + {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return; + } + + // send the attribute names + out.WriteTupleDescriptor(tuple_descriptor); + out.WriteDataRows(state_.result_, tuple_descriptor.size()); + // TODO(Tianyu): WTF? 
+ state_.rows_affected_ = state_.result_.size() / tuple_descriptor.size(); + + CompleteCommand(out, + state_.statement_->GetQueryType(), + state_.rows_affected_); + + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); +} + +void PostgresProtocolInterpreter::ExecExecuteMessageGetResult(PostgresPacketWriter &out, peloton::ResultType status) { + const auto &query_type = state_.statement_->GetQueryType(); + switch (status) { + case ResultType::FAILURE: + LOG_ERROR("Failed to execute: %s", + state_.error_message_.c_str()); + out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + state_.error_message_}}); + return; + case ResultType::ABORTED: { + // It's not an ABORT query but Peloton aborts the transaction + if (query_type != QueryType::QUERY_ROLLBACK) { + LOG_DEBUG("Failed to execute: Conflicting txn aborted"); + // Send an error response if the abort is not due to ROLLBACK query + out.WriteErrorResponse({{NetworkMessageType::SQLSTATE_CODE_ERROR, + SqlStateErrorCodeToString( + SqlStateErrorCode::SERIALIZATION_ERROR)}}); + } + return; + } + case ResultType::TO_ABORT: { + // User keeps issuing queries in a transaction that should be aborted + std::string error_message = + "current transaction is aborted, commands ignored until end of " + "transaction block"; + out.WriteErrorResponse( + {{NetworkMessageType::HUMAN_READABLE_ERROR, error_message}}); + out.WriteReadyForQuery(NetworkTransactionStateType::IDLE); + return; + } + default: { + auto tuple_descriptor = + state_.statement_->GetTupleDescriptor(); + out.WriteDataRows(state_.result_, tuple_descriptor.size()); + state_.rows_affected_ = state_.result_.size() / tuple_descriptor.size(); + CompleteCommand(out, query_type, state_.rows_affected_); + return; + } + } +} + +ResultType PostgresProtocolInterpreter::ExecQueryExplain(const std::string &query, + peloton::parser::ExplainStatement &explain_stmt) { + std::unique_ptr unnamed_sql_stmt_list( + new parser::SQLStatementList()); + 
unnamed_sql_stmt_list->PassInStatement(std::move(explain_stmt.real_sql_stmt)); + auto stmt = tcop::Tcop::GetInstance().PrepareStatement(state_, "explain", query, + std::move(unnamed_sql_stmt_list)); + ResultType status; + if (stmt != nullptr) { + state_.statement_ = stmt; + std::vector plan_info = StringUtil::Split( + planner::PlanUtil::GetInfo(stmt->GetPlanTree().get()), '\n'); + const std::vector tuple_descriptor = { + tcop::Tcop::GetInstance().GetColumnFieldForValueType("Query plan", + type::TypeId::VARCHAR)}; + stmt->SetTupleDescriptor(tuple_descriptor); + state_.result_ = plan_info; + status = ResultType::SUCCESS; + } else { + status = ResultType::FAILURE; + } + return status; +} + +bool PostgresProtocolInterpreter::HardcodedExecuteFilter(peloton::QueryType query_type) { + switch (query_type) { + // Skip SET + case QueryType::QUERY_SET: + case QueryType::QUERY_SHOW: + return false; + // Skip duplicate BEGIN + case QueryType::QUERY_BEGIN: + if (state_.txn_state_ == NetworkTransactionStateType::BLOCK) { + return false; + } + break; + // Skip duplicate Commits and Rollbacks + case QueryType::QUERY_COMMIT: + case QueryType::QUERY_ROLLBACK: + if (state_.txn_state_ == NetworkTransactionStateType::IDLE) { + return false; + } + default: + break; + } + return true; +} } // namespace network } // namespace peloton \ No newline at end of file diff --git a/src/network/protocol_handler.cpp b/src/network/protocol_handler.cpp deleted file mode 100644 index 20a56351f85..00000000000 --- a/src/network/protocol_handler.cpp +++ /dev/null @@ -1,38 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// protocol_handler.cpp -// -// Identification: src/network/protocol_handler.cpp -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include "network/protocol_handler.h" - -#include - -namespace peloton { -namespace 
network { - -ProtocolHandler::ProtocolHandler(tcop::TrafficCop *traffic_cop) { - this->traffic_cop_ = traffic_cop; -} - -ProtocolHandler::~ProtocolHandler() {} - -ProcessResult ProtocolHandler::Process(ReadBuffer &, const size_t) { - return ProcessResult::TERMINATE; -} - -void ProtocolHandler::Reset() { - SetFlushFlag(false); - responses_.clear(); - request_.Reset(); -} - -void ProtocolHandler::GetResult() {} -} // namespace network -} // namespace peloton diff --git a/src/network/protocol_handler_factory.cpp b/src/network/protocol_handler_factory.cpp deleted file mode 100644 index 9df0d5fad86..00000000000 --- a/src/network/protocol_handler_factory.cpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// protocol_handler_factory.cpp -// -// Identification: src/network/protocol_handler_factory.cpp -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include "network/protocol_handler_factory.h" -#include "network/postgres_protocol_handler.h" - -namespace peloton { -namespace network { -std::unique_ptr ProtocolHandlerFactory::CreateProtocolHandler( - ProtocolHandlerType type, tcop::TrafficCop *traffic_cop) { - switch (type) { - case ProtocolHandlerType::Postgres: { - return std::unique_ptr( - new PostgresProtocolHandler(traffic_cop)); - } - default: - return nullptr; - } -} -} // namespace network -} // namespace peloton diff --git a/src/traffic_cop/client_transaction_handle.cpp b/src/traffic_cop/client_transaction_handle.cpp deleted file mode 100644 index 61513455cc8..00000000000 --- a/src/traffic_cop/client_transaction_handle.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// client_transaction_handle.cpp -// -// Identification: src/traffic_cop/client_transaction_handle.cpp -// -// 
Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include "traffic_cop/client_transaction_handle.h" -#include -namespace peloton { -namespace tcop { - - /* Function implementations of SingleStmtClientTxnHandler */ - TxnContext *SingleStmtClientTxnHandler::ImplicitBegin(const size_t thread_id, ClientTxnHandle &handle) { - switch (handle.txn_state_) { - case TransactionState::IDLE: { - handle.txn_ = TxnManagerFactory::GetInstance().BeginTransaction(thread_id); - handle.txn_state_ =TransactionState::STARTED; - } - case TransactionState::STARTED: - case TransactionState::FAILING: - case TransactionState::ABORTING: - break; - } - return handle.txn_; - } - - void SingleStmtClientTxnHandler::End(ClientTxnHandle &handle) { - // TODO Implement this function - } - - void SingleStmtClientTxnHandler::Abort(ClientTxnHandle &handle) { - // TODO Implement this function - } - - - /* Function implementations of MultiStmtsClientTxnHandler */ - TxnContext *MultiStmtsClientTxnHandler::ImplicitBegin(const size_t, ClientTxnHandle &handle_) { - return handle_.GetTxn(); - } - - TxnContext *MultiStmtsClientTxnHandler::ExplicitBegin(const size_t thread_id = 0, ClientTxnHandle & handle){ - switch (handle.txn_state_) { - case TransactionState::IDLE: { - handle.txn_ = TxnManagerFactory::GetInstance().BeginTransaction(thread_id); - handle.txn_state_ = TransactionState::STARTED; - } - case TransactionState::STARTED: - TxnManagerFactory::GetInstance().AbortTransaction(handle.txn_); - handle.txn_state_ = TransactionState::ABORTING; - throw TransactionException("Current Transaction started already"); - case TransactionState::FAILING: - case TransactionState::ABORTING: - break; - } - return handle.txn_; - } - - bool MultiStmtsClientTxnHandler::Commit(ClientTxnHandle &handler) { - // TODO implement this function - return false; - } - - void MultiStmtsClientTxnHandler::Abort(ClientTxnHandle 
&handler) { - // TODO implement this function - } - - /* Function implementations of ClientTxnHandle */ - TxnContext *ClientTxnHandle::ImplicitBegin(const size_t thread_id) { - return handler_->ImplicitBegin(thread_id, *this); - } - - TxnContext *ClientTxnHandle::ExplicitBegin(const size_t thread_id) { - if (single_stmt_handler_) { - ChangeToMultiStmtsHandler(); - } - return handler_->ExplicitBegin(thread_id, *this); - } - - void ClientTxnHandle::ImplicitEnd() { - handler_->End(*this); - if (txn_state_ == TransactionState::IDLE && !single_stmt_handler_) { - ChangeToSingleStmtHandler(); - } - } - - void ClientTxnHandle::ExplicitAbort() { - handler_->Abort(*this); - if (!single_stmt_handler_) - ChangeToSingleStmtHandler(); - } - - bool ClientTxnHandle::ExplicitCommit() { - bool success = handler_->Commit(*this); - if (success && !single_stmt_handler_) { - ChangeToSingleStmtHandler(); - } - return success; - } - -} -} \ No newline at end of file diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index ad6e0947385..3f79c662f05 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -14,59 +14,89 @@ #include "planner/plan_util.h" #include "binder/bind_node_visitor.h" #include "traffic_cop/tcop.h" +#include "expression/expression_util.h" +#include "concurrency/transaction_context.h" +#include "concurrency/transaction_manager_factory.h" namespace peloton { namespace tcop { -// Prepare a statement -bool tcop::PrepareStatement(ClientProcessState &state, - const std::string &query_string, - const std::string &statement_name) { - try { - state.txn_handle_.ImplicitBegin(state.thread_id_); - // parse the query - auto &peloton_parser = parser::PostgresParser::GetInstance(); - auto sql_stmt_list = peloton_parser.BuildParseTree(query_string); - - // When the query is empty(such as ";" or ";;", still valid), - // the parse tree is empty, parser will return nullptr. 
- if (sql_stmt_list != nullptr && !sql_stmt_list->is_valid) - throw ParserException("Error Parsing SQL statement"); - - // TODO(Yuchen): Hack. We only process the first statement in the packet now. - // We should store the rest of statements that will not be processed right - // away. For the hack, in most cases, it works. Because for example in psql, - // one packet contains only one query. But when using the pipeline mode in - // Libpqxx, it sends multiple query in one packet. In this case, it's - // incorrect. - StatementType stmt_type = sql_stmt_list->GetStatement(0)->GetType(); - QueryType query_type = - StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); - - auto statement = std::make_shared(statement_name, - query_type, - query_string, - std::move(sql_stmt_list)); - - // Empty statement edge case - if (statement->GetStmtParseTreeList() == nullptr || - statement->GetStmtParseTreeList()->GetNumStatements() == 0) { - state.statement_cache_.AddStatement( - std::make_shared(statement_name, - QueryType::QUERY_INVALID, - query_string, - std::move(statement->PassStmtParseTreeList()))); - return true; +std::shared_ptr Tcop::PrepareStatement(ClientProcessState &state, + const std::string &statement_name, + const std::string &query_string, + std::unique_ptr &&sql_stmt_list) { + LOG_TRACE("Prepare Statement query: %s", query_string.c_str()); + + // Empty statement + // TODO (Tianyi) Read through the parser code to see if this is appropriate + if (sql_stmt_list == nullptr || sql_stmt_list->GetNumStatements() == 0) + // TODO (Tianyi) Do we need another query type called QUERY_EMPTY? 
+ return std::make_shared(statement_name, + QueryType::QUERY_INVALID, + query_string, + std::move(sql_stmt_list)); + + StatementType stmt_type = sql_stmt_list->GetStatement(0)->GetType(); + QueryType query_type = + StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); + auto statement = std::make_shared(statement_name, + query_type, + query_string, + std::move(sql_stmt_list)); + + // TODO(Tianyu): Issue #1441. Hopefully Tianyi will fix this in his later + // refactor + + // We can learn transaction's states, BEGIN, COMMIT, ABORT, or ROLLBACK from + // member variables, tcop_txn_state_. We can also get single-statement txn or + // multi-statement txn from member variable single_statement_txn_ + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + // --multi-statements except BEGIN in a transaction + if (!state.tcop_txn_state_.empty()) { + state.single_statement_txn_ = false; + // multi-statment txn has been aborted, just skip this query, + // and do not need to parse or execute this query anymore. 
+ // Do not return nullptr in case that 'COMMIT' cannot be execute, + // because nullptr will directly return ResultType::FAILURE to + // packet_manager + if (state.tcop_txn_state_.top().second == ResultType::ABORTED) + return statement; + } else { + // Begin new transaction when received single-statement query or "BEGIN" + // from multi-statement query + if (statement->GetQueryType() == + QueryType::QUERY_BEGIN) { // only begin a new transaction + // note this transaction is not single-statement transaction + LOG_TRACE("BEGIN"); + state.single_statement_txn_ = false; + } else { + // single statement + LOG_TRACE("SINGLE TXN"); + state.single_statement_txn_ = true; + } + auto txn = txn_manager.BeginTransaction(state.thread_id_); + // this shouldn't happen + if (txn == nullptr) { + LOG_TRACE("Begin txn failed"); } + // initialize the current result as success + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + } + + if (settings::SettingsManager::GetBool(settings::SettingId::brain)) { + state.tcop_txn_state_.top().first->AddQueryString(query_string.c_str()); + } + // TODO(Tianyi) Move Statement Planing into Statement's method + // to increase coherence + try { // Run binder - auto bind_node_visitor = binder::BindNodeVisitor(state.txn_handle_.GetTxn(), - state.db_name_); + auto bind_node_visitor = binder::BindNodeVisitor( + state.tcop_txn_state_.top().first, state.db_name_); bind_node_visitor.BindNameToNode( statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = state.optimizer_-> - BuildPelotonPlanTree(statement->GetStmtParseTreeList(), - state.txn_handle_.GetTxn()); + auto plan = state.optimizer_->BuildPelotonPlanTree( + statement->GetStmtParseTreeList(), state.tcop_txn_state_.top().first); statement->SetPlanTree(plan); // Get the tables that our plan references so that we know how to // invalidate it at a later point when the catalog changes @@ -75,27 +105,28 @@ bool tcop::PrepareStatement(ClientProcessState &state, 
statement->SetReferencedTables(table_oids); if (query_type == QueryType::QUERY_SELECT) { - auto tuple_descriptor = GenerateTupleDescriptor( + auto tuple_descriptor = GenerateTupleDescriptor(state, statement->GetStmtParseTreeList()->GetStatement(0)); statement->SetTupleDescriptor(tuple_descriptor); LOG_TRACE("select query, finish setting"); } - - state.statement_cache_.AddStatement(statement); - } catch (Exception &e) { - // TODO(Tianyi) implicit end the txn here state.error_message_ = e.what(); - return false; + tcop::Tcop::GetInstance().ProcessInvalidStatement(state); + return nullptr; } - // TODO(Tianyi) catch txn exception - return true; + +#ifdef LOG_DEBUG_ENABLED + if (statement->GetPlanTree().get() != nullptr) { + LOG_TRACE("Statement Prepared: %s", statement->GetInfo().c_str()); + LOG_TRACE("%s", statement->GetPlanTree().get()->GetInfo().c_str()); + } +#endif + return statement; } -bool tcop::ExecuteStatement(ClientProcessState &state, - const std::vector &result_format, - std::vector &result, - const CallbackFunc &callback) { +ResultType Tcop::ExecuteStatement(ClientProcessState &state, + CallbackFunc callback) { LOG_TRACE("Execute Statement of name: %s", state.statement_->GetStatementName().c_str()); @@ -110,61 +141,375 @@ bool tcop::ExecuteStatement(ClientProcessState &state, try { switch (state.statement_->GetQueryType()) { - case QueryType::QUERY_BEGIN: { - state.txn_handle_.ExplicitBegin(state.thread_id_); - return true; - } - case QueryType::QUERY_COMMIT: { - if (!state.txn_handle_.ExplicitCommit()) { - state.p_status_.m_result = ResultType::FAILURE; - //TODO set error message - } - return true; - } - case QueryType::QUERY_ROLLBACK: { - state.txn_handle_.ExplicitAbort(); - return true; - } - default: { + case QueryType::QUERY_BEGIN:return BeginQueryHelper(state); + case QueryType::QUERY_COMMIT:return CommitQueryHelper(state); + case QueryType::QUERY_ROLLBACK:return AbortQueryHelper(state); + default: // The statement may be out of date // It needs to 
be replan - auto txn = state.txn_handle_.ImplicitBegin(state.thread_id_); if (state.statement_->GetNeedsReplan()) { // TODO(Tianyi) Move Statement Replan into Statement's method // to increase coherence + auto bind_node_visitor = binder::BindNodeVisitor( + state.tcop_txn_state_.top().first, state.db_name_); + bind_node_visitor.BindNameToNode( + state.statement_->GetStmtParseTreeList()->GetStatement(0)); auto plan = state.optimizer_->BuildPelotonPlanTree( - state.statement_->GetStmtParseTreeList(), txn); + state.statement_->GetStmtParseTreeList(), + state.tcop_txn_state_.top().first); state.statement_->SetPlanTree(plan); - state.statement_->SetNeedsReplan(false); + state.statement_->SetNeedsReplan(true); } - auto plan = state.statement_->GetPlanTree(); - auto params = state.param_values_; - - auto on_complete = [callback, &](executor::ExecutionResult p_status, - std::vector &&values) { - state.p_status_ = p_status; - state.error_message_ = std::move(p_status.m_error_message); - result = std::move(values); - callback(); - }; - - auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([txn, on_complete, &] { - executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), - txn, - state.param_values_, - result_format, - on_complete); - }); - return false; - } + ExecuteHelper(state, callback); + if (state.is_queuing_) + return ResultType::QUEUING; + else + return ExecuteStatementGetResult(state); } } catch (Exception &e) { - state.p_status_.m_result = ResultType::FAILURE; state.error_message_ = e.what(); - return true; + return ResultType::FAILURE; + } +} + +bool Tcop::BindParamsForCachePlan(ClientProcessState &state, + const std::vector> &exprs) { + if (state.tcop_txn_state_.empty()) { + state.single_statement_txn_ = true; + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(state.thread_id_); + // this shouldn't happen + if (txn == nullptr) { + LOG_ERROR("Begin txn failed"); + } 
+ // initialize the current result as success + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + } + // Run binder + auto bind_node_visitor = + binder::BindNodeVisitor(state.tcop_txn_state_.top().first, + state.db_name_); + + std::vector param_values; + for (const auto &expr :exprs) { + if (!expression::ExpressionUtil::IsValidStaticExpression(expr.get())) { + state.error_message_ = "Invalid Expression Type"; + return false; + } + expr->Accept(&bind_node_visitor); + // TODO(Yuchen): need better check for nullptr argument + param_values.push_back(expr->Evaluate(nullptr, nullptr, nullptr)); + } + if (!param_values.empty()) { + state.statement_->GetPlanTree()->SetParameterValues(¶m_values); + } + state.param_values_ = param_values; + return true; +} + +std::vector Tcop::GenerateTupleDescriptor(ClientProcessState &state, + parser::SQLStatement *sql_stmt) { + std::vector tuple_descriptor; + if (sql_stmt->GetType() != StatementType::SELECT) return tuple_descriptor; + auto select_stmt = (parser::SelectStatement *) sql_stmt; + + // TODO(Bowei): this is a hack which I don't have time to fix now + // but it replaces a worse hack that was here before + // What should happen here is that plan nodes should store + // the schema of their expected results and here we should just read + // it and put it in the tuple descriptor + + // Get the columns information and set up + // the columns description for the returned results + // Set up the table + std::vector all_columns; + + // Check if query only has one Table + // Example : SELECT * FROM A; + GetTableColumns(state, select_stmt->from_table.get(), all_columns); + + int count = 0; + for (auto &expr : select_stmt->select_list) { + count++; + if (expr->GetExpressionType() == ExpressionType::STAR) { + for (const auto &column : all_columns) { + tuple_descriptor.push_back( + GetColumnFieldForValueType(column.GetName(), column.GetType())); + } + } else { + std::string col_name; + if (expr->alias.empty()) { + col_name = 
expr->expr_name_.empty() + ? std::string("expr") + std::to_string(count) + : expr->expr_name_; + } else { + col_name = expr->alias; + } + tuple_descriptor.push_back( + GetColumnFieldForValueType(col_name, expr->GetValueType())); + } + } + + return tuple_descriptor; +} + +FieldInfo Tcop::GetColumnFieldForValueType(std::string column_name, + type::TypeId column_type) { + PostgresValueType field_type; + size_t field_size; + switch (column_type) { + case type::TypeId::BOOLEAN: + case type::TypeId::TINYINT: { + field_type = PostgresValueType::BOOLEAN; + field_size = 1; + break; + } + case type::TypeId::SMALLINT: { + field_type = PostgresValueType::SMALLINT; + field_size = 2; + break; + } + case type::TypeId::INTEGER: { + field_type = PostgresValueType::INTEGER; + field_size = 4; + break; + } + case type::TypeId::BIGINT: { + field_type = PostgresValueType::BIGINT; + field_size = 8; + break; + } + case type::TypeId::DECIMAL: { + field_type = PostgresValueType::DOUBLE; + field_size = 8; + break; + } + case type::TypeId::VARCHAR: + case type::TypeId::VARBINARY: { + field_type = PostgresValueType::TEXT; + field_size = 255; + break; + } + case type::TypeId::DATE: { + field_type = PostgresValueType::DATE; + field_size = 4; + break; + } + case type::TypeId::TIMESTAMP: { + field_type = PostgresValueType::TIMESTAMPS; + field_size = 64; // FIXME: Bytes??? + break; + } + default: { + // Type not Identified + LOG_ERROR("Unrecognized field type '%s' for field '%s'", + TypeIdToString(column_type).c_str(), column_name.c_str()); + field_type = PostgresValueType::TEXT; + field_size = 255; + break; + } } + // HACK: Convert the type into a oid_t + // This ugly and I don't like it one bit... 
+ return std::make_tuple(column_name, static_cast(field_type), + field_size); +} + +void Tcop::GetTableColumns(ClientProcessState &state, + parser::TableRef *from_table, + std::vector &target_columns) { + if (from_table == nullptr) return; + + // Query derived table + if (from_table->select != nullptr) { + for (auto &expr : from_table->select->select_list) { + if (expr->GetExpressionType() == ExpressionType::STAR) + GetTableColumns(state, from_table->select->from_table.get(), target_columns); + else + target_columns.emplace_back(expr->GetValueType(), 0, + expr->GetExpressionName()); + } + } else if (from_table->list.empty()) { + if (from_table->join == nullptr) { + auto columns = + catalog::Catalog::GetInstance()->GetTableWithName( + from_table->GetDatabaseName(), from_table->GetSchemaName(), + from_table->GetTableName(), state.GetCurrentTxnState().first) + ->GetSchema() + ->GetColumns(); + target_columns.insert(target_columns.end(), columns.begin(), + columns.end()); + } else { + GetTableColumns(state, from_table->join->left.get(), target_columns); + GetTableColumns(state, from_table->join->right.get(), target_columns); + } + } + // Query has multiple tables. 
Recursively add all tables + else + for (auto &table : from_table->list) + GetTableColumns(state, table.get(), target_columns); +} + +void Tcop::ExecuteStatementPlanGetResult(ClientProcessState &state) { + if (state.p_status_.m_result == ResultType::FAILURE) return; + + auto txn_result = state.GetCurrentTxnState().first->GetResult(); + if (state.single_statement_txn_ || txn_result == ResultType::FAILURE) { + LOG_TRACE("About to commit/abort: single stmt: %d,txn_result: %s", + state.single_statement_txn_, + ResultTypeToString(txn_result).c_str()); + switch (txn_result) { + case ResultType::SUCCESS: + // Commit single statement + LOG_TRACE("Commit Transaction"); + state.p_status_.m_result = CommitQueryHelper(state); + break; + case ResultType::FAILURE: + default: + // Abort + LOG_TRACE("Abort Transaction"); + if (state.single_statement_txn_) { + LOG_TRACE("Tcop_txn_state size: %lu", tcop_txn_state_.size()); + state.p_status_.m_result = AbortQueryHelper(state); + } else { + state.tcop_txn_state_.top().second = ResultType::ABORTED; + state.p_status_.m_result = ResultType::ABORTED; + } + } + } +} + +ResultType Tcop::ExecuteStatementGetResult(ClientProcessState &state) { + LOG_TRACE("Statement executed. 
Result: %s", + ResultTypeToString(p_status_.m_result).c_str()); + state.rows_affected_ = state.p_status_.m_processed; + LOG_TRACE("rows_changed %d", state.p_status_.m_processed); + state.is_queuing_ = false; + return state.p_status_.m_result; +} + +void Tcop::ProcessInvalidStatement(ClientProcessState &state) { + if (state.single_statement_txn_) { + LOG_TRACE("SINGLE ABORT!"); + AbortQueryHelper(state); + } else { // multi-statment txn + if (state.tcop_txn_state_.top().second != ResultType::ABORTED) { + state.tcop_txn_state_.top().second = ResultType::ABORTED; + } + } +} + +ResultType Tcop::CommitQueryHelper(ClientProcessState &state) { +// do nothing if we have no active txns + if (state.tcop_txn_state_.empty()) return ResultType::NOOP; + auto &curr_state = state.tcop_txn_state_.top(); + state.tcop_txn_state_.pop(); + auto txn = curr_state.first; + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + // I catch the exception (ex. table not found) explicitly, + // If this exception is caused by a query in a transaction, + // I will block following queries in that transaction until 'COMMIT' or + // 'ROLLBACK' After receive 'COMMIT', see if it is rollback or really commit. 
+ if (curr_state.second != ResultType::ABORTED) { + // txn committed + return txn_manager.CommitTransaction(txn); + } else { + // otherwise, rollback + return txn_manager.AbortTransaction(txn); + } +} + +ResultType Tcop::BeginQueryHelper(ClientProcessState &state) { + if (state.tcop_txn_state_.empty()) { + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(state.thread_id_); + // this shouldn't happen + if (txn == nullptr) { + LOG_DEBUG("Begin txn failed"); + return ResultType::FAILURE; + } + // initialize the current result as success + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + } + return ResultType::SUCCESS; +} + +ResultType Tcop::AbortQueryHelper(ClientProcessState &state) { + // do nothing if we have no active txns + if (state.tcop_txn_state_.empty()) return ResultType::NOOP; + auto &curr_state = state.tcop_txn_state_.top(); + state.tcop_txn_state_.pop(); + // explicitly abort the txn only if it has not aborted already + if (curr_state.second != ResultType::ABORTED) { + auto txn = curr_state.first; + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto result = txn_manager.AbortTransaction(txn); + return result; + } else { + delete curr_state.first; + // otherwise, the txn has already been aborted + return ResultType::ABORTED; + } +} + +executor::ExecutionResult Tcop::ExecuteHelper(ClientProcessState &state, + CallbackFunc callback) { + auto &curr_state = state.GetCurrentTxnState(); + + concurrency::TransactionContext *txn; + if (!state.tcop_txn_state_.empty()) { + txn = curr_state.first; + } else { + // No active txn, single-statement txn + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + // new txn, reset result status + curr_state.second = ResultType::SUCCESS; + state.single_statement_txn_ = true; + txn = txn_manager.BeginTransaction(state.thread_id_); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + } + + // 
skip if already aborted + if (curr_state.second == ResultType::ABORTED) { + // If the transaction state is ABORTED, the transaction should be aborted + // but Peloton didn't explicitly abort it yet since it didn't receive a + // COMMIT/ROLLBACK. + // Here, it receive queries other than COMMIT/ROLLBACK in an broken + // transaction, + // it should tell the client that these queries will not be executed. + state.p_status_.m_result = ResultType::TO_ABORT; + return state.p_status_; + } + + auto on_complete = [callback, &state](executor::ExecutionResult p_status, + std::vector &&values) { + state.p_status_ = p_status; + // TODO (Tianyi) I would make a decision on keeping one of p_status or + // error_message in my next PR + state.error_message_ = std::move(p_status.m_error_message); + state.result_ = std::move(values); + callback(); + }; + // TODO(Tianyu): Eliminate this copy, which is here to coerce the type + std::vector formats; + for (auto format : state.result_format_) + formats.push_back((int) format); + + auto &pool = threadpool::MonoQueuePool::GetInstance(); + pool.SubmitTask([on_complete, &state, &txn, &formats] { + executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), + txn, + state.param_values_, + formats, + on_complete); + }); + + state.is_queuing_ = true; + + LOG_TRACE("Check Tcop_txn_state Size After ExecuteHelper %lu", + state.tcop_txn_state_.size()); + return state.p_status_; } } // namespace tcop diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp deleted file mode 100644 index a8794d9cfb0..00000000000 --- a/src/traffic_cop/traffic_cop.cpp +++ /dev/null @@ -1,617 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// traffic_cop.cpp -// -// Identification: src/traffic_cop/traffic_cop.cpp -// -// Copyright (c) 2015-17, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - 
-#include "traffic_cop/traffic_cop.h" - -#include - -#include "binder/bind_node_visitor.h" -#include "common/internal_types.h" -#include "concurrency/transaction_context.h" -#include "concurrency/transaction_manager_factory.h" -#include "expression/expression_util.h" -#include "optimizer/optimizer.h" -#include "planner/plan_util.h" -#include "settings/settings_manager.h" -#include "threadpool/mono_queue_pool.h" - -namespace peloton { -namespace tcop { - -TrafficCop::TrafficCop() - : is_queuing_(false), - rows_affected_(0), - optimizer_(new optimizer::Optimizer()), - single_statement_txn_(true) {} - -TrafficCop::TrafficCop(void (*task_callback)(void *), void *task_callback_arg) - : optimizer_(new optimizer::Optimizer()), - single_statement_txn_(true), - task_callback_(task_callback), - task_callback_arg_(task_callback_arg) {} - -void TrafficCop::Reset() { - std::stack new_tcop_txn_state; - // clear out the stack - swap(tcop_txn_state_, new_tcop_txn_state); - optimizer_->Reset(); - results_.clear(); - param_values_.clear(); - setRowsAffected(0); -} - -TrafficCop::~TrafficCop() { - // Abort all running transactions - while (!tcop_txn_state_.empty()) { - AbortQueryHelper(); - } -} - -/* Singleton accessor - * NOTE: Used by in unit tests ONLY - */ -TrafficCop &TrafficCop::GetInstance() { - static TrafficCop tcop; - tcop.Reset(); - return tcop; -} - -TrafficCop::TcopTxnState &TrafficCop::GetDefaultTxnState() { - static TcopTxnState default_state; - default_state = std::make_pair(nullptr, ResultType::INVALID); - return default_state; -} - -TrafficCop::TcopTxnState &TrafficCop::GetCurrentTxnState() { - if (tcop_txn_state_.empty()) { - return GetDefaultTxnState(); - } - return tcop_txn_state_.top(); -} - -ResultType TrafficCop::BeginQueryHelper(size_t thread_id) { - if (tcop_txn_state_.empty()) { - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto txn = txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) 
{ - LOG_DEBUG("Begin txn failed"); - return ResultType::FAILURE; - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - return ResultType::SUCCESS; -} - -ResultType TrafficCop::CommitQueryHelper() { - // do nothing if we have no active txns - if (tcop_txn_state_.empty()) return ResultType::NOOP; - auto &curr_state = tcop_txn_state_.top(); - tcop_txn_state_.pop(); - auto txn = curr_state.first; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // I catch the exception (ex. table not found) explicitly, - // If this exception is caused by a query in a transaction, - // I will block following queries in that transaction until 'COMMIT' or - // 'ROLLBACK' After receive 'COMMIT', see if it is rollback or really commit. - if (curr_state.second != ResultType::ABORTED) { - // txn committed - return txn_manager.CommitTransaction(txn); - } else { - // otherwise, rollback - return txn_manager.AbortTransaction(txn); - } -} - -ResultType TrafficCop::AbortQueryHelper() { - // do nothing if we have no active txns - if (tcop_txn_state_.empty()) return ResultType::NOOP; - auto &curr_state = tcop_txn_state_.top(); - tcop_txn_state_.pop(); - // explicitly abort the txn only if it has not aborted already - if (curr_state.second != ResultType::ABORTED) { - auto txn = curr_state.first; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto result = txn_manager.AbortTransaction(txn); - return result; - } else { - delete curr_state.first; - // otherwise, the txn has already been aborted - return ResultType::ABORTED; - } -} - -ResultType TrafficCop::ExecuteStatementGetResult() { - LOG_TRACE("Statement executed. 
Result: %s", - ResultTypeToString(p_status_.m_result).c_str()); - setRowsAffected(p_status_.m_processed); - LOG_TRACE("rows_changed %d", p_status_.m_processed); - is_queuing_ = false; - return p_status_.m_result; -} - -/* - * Execute a statement that needs a plan(so, BEGIN, COMMIT, ROLLBACK does not - * come here). - * Begin a new transaction if necessary. - * If the current transaction is already broken(for example due to previous - * invalid - * queries), directly return - * Otherwise, call ExecutePlan() - */ -executor::ExecutionResult TrafficCop::ExecuteHelper( - std::shared_ptr plan, - const std::vector ¶ms, std::vector &result, - const std::vector &result_format, size_t thread_id) { - auto &curr_state = GetCurrentTxnState(); - - concurrency::TransactionContext *txn; - if (!tcop_txn_state_.empty()) { - txn = curr_state.first; - } else { - // No active txn, single-statement txn - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // new txn, reset result status - curr_state.second = ResultType::SUCCESS; - single_statement_txn_ = true; - txn = txn_manager.BeginTransaction(thread_id); - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - - // skip if already aborted - if (curr_state.second == ResultType::ABORTED) { - // If the transaction state is ABORTED, the transaction should be aborted - // but Peloton didn't explicitly abort it yet since it didn't receive a - // COMMIT/ROLLBACK. - // Here, it receive queries other than COMMIT/ROLLBACK in an broken - // transaction, - // it should tell the client that these queries will not be executed. 
- p_status_.m_result = ResultType::TO_ABORT; - return p_status_; - } - - auto on_complete = [&result, this](executor::ExecutionResult p_status, - std::vector &&values) { - this->p_status_ = p_status; - // TODO (Tianyi) I would make a decision on keeping one of p_status or - // error_message in my next PR - this->error_message_ = std::move(p_status.m_error_message); - result = std::move(values); - task_callback_(task_callback_arg_); - }; - - auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([plan, txn, ¶ms, &result_format, on_complete] { - executor::PlanExecutor::ExecutePlan(plan, txn, params, result_format, - on_complete); - }); - - is_queuing_ = true; - - LOG_TRACE("Check Tcop_txn_state Size After ExecuteHelper %lu", - tcop_txn_state_.size()); - return p_status_; -} - -void TrafficCop::ExecuteStatementPlanGetResult() { - if (p_status_.m_result == ResultType::FAILURE) return; - - auto txn_result = GetCurrentTxnState().first->GetResult(); - if (single_statement_txn_ || txn_result == ResultType::FAILURE) { - LOG_TRACE("About to commit/abort: single stmt: %d,txn_result: %s", - single_statement_txn_, ResultTypeToString(txn_result).c_str()); - switch (txn_result) { - case ResultType::SUCCESS: - // Commit single statement - LOG_TRACE("Commit Transaction"); - p_status_.m_result = CommitQueryHelper(); - break; - - case ResultType::FAILURE: - default: - // Abort - LOG_TRACE("Abort Transaction"); - if (single_statement_txn_) { - LOG_TRACE("Tcop_txn_state size: %lu", tcop_txn_state_.size()); - p_status_.m_result = AbortQueryHelper(); - } else { - tcop_txn_state_.top().second = ResultType::ABORTED; - p_status_.m_result = ResultType::ABORTED; - } - } - } -} - -/* - * Prepare a statement based on parse tree. Begin a transaction if necessary. - * If the query is not issued in a transaction (if txn_stack is empty and it's - * not - * BEGIN query), Peloton will create a new transation for it. single_stmt - * transaction. 
- * Otherwise, it's a multi_stmt transaction. - * TODO(Yuchen): We do not need a query string to prepare a statement and the - * query string may - * contain the information of multiple statements rather than the single one. - * Hack here. We store - * the query string inside Statement objects for printing infomation. - */ -std::shared_ptr TrafficCop::PrepareStatement( - const std::string &stmt_name, const std::string &query_string, - std::unique_ptr sql_stmt_list, - const size_t thread_id UNUSED_ATTRIBUTE) { - LOG_TRACE("Prepare Statement query: %s", query_string.c_str()); - - // Empty statement - // TODO (Tianyi) Read through the parser code to see if this is appropriate - if (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0) { - // TODO (Tianyi) Do we need another query type called QUERY_EMPTY? - std::shared_ptr statement = - std::make_shared(stmt_name, QueryType::QUERY_INVALID, - query_string, std::move(sql_stmt_list)); - return statement; - } - - StatementType stmt_type = sql_stmt_list->GetStatement(0)->GetType(); - QueryType query_type = - StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); - std::shared_ptr statement = std::make_shared( - stmt_name, query_type, query_string, std::move(sql_stmt_list)); - - // We can learn transaction's states, BEGIN, COMMIT, ABORT, or ROLLBACK from - // member variables, tcop_txn_state_. We can also get single-statement txn or - // multi-statement txn from member variable single_statement_txn_ - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // --multi-statements except BEGIN in a transaction - if (!tcop_txn_state_.empty()) { - single_statement_txn_ = false; - // multi-statment txn has been aborted, just skip this query, - // and do not need to parse or execute this query anymore. 
- // Do not return nullptr in case that 'COMMIT' cannot be execute, - // because nullptr will directly return ResultType::FAILURE to - // packet_manager - if (tcop_txn_state_.top().second == ResultType::ABORTED) { - return statement; - } - } else { - // Begin new transaction when received single-statement query or "BEGIN" - // from multi-statement query - if (statement->GetQueryType() == - QueryType::QUERY_BEGIN) { // only begin a new transaction - // note this transaction is not single-statement transaction - LOG_TRACE("BEGIN"); - single_statement_txn_ = false; - } else { - // single statement - LOG_TRACE("SINGLE TXN"); - single_statement_txn_ = true; - } - auto txn = txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) { - LOG_TRACE("Begin txn failed"); - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - - if (settings::SettingsManager::GetBool(settings::SettingId::brain)) { - tcop_txn_state_.top().first->AddQueryString(query_string.c_str()); - } - - // TODO(Tianyi) Move Statement Planing into Statement's method - // to increase coherence - try { - // Run binder - auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, default_database_name_); - bind_node_visitor.BindNameToNode( - statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = optimizer_->BuildPelotonPlanTree( - statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); - statement->SetPlanTree(plan); - // Get the tables that our plan references so that we know how to - // invalidate it at a later point when the catalog changes - const std::set table_oids = - planner::PlanUtil::GetTablesReferenced(plan.get()); - statement->SetReferencedTables(table_oids); - - if (query_type == QueryType::QUERY_SELECT) { - auto tuple_descriptor = GenerateTupleDescriptor( - statement->GetStmtParseTreeList()->GetStatement(0)); - statement->SetTupleDescriptor(tuple_descriptor); - 
LOG_TRACE("select query, finish setting"); - } - } catch (Exception &e) { - error_message_ = e.what(); - ProcessInvalidStatement(); - return nullptr; - } - -#ifdef LOG_DEBUG_ENABLED - if (statement->GetPlanTree().get() != nullptr) { - LOG_TRACE("Statement Prepared: %s", statement->GetInfo().c_str()); - LOG_TRACE("%s", statement->GetPlanTree().get()->GetInfo().c_str()); - } -#endif - return statement; -} - -/* - * Do nothing if there is no active transaction; - * If single-stmt transaction, abort it; - * If multi-stmt transaction, just set transaction state to 'ABORTED'. - * The multi-stmt txn will be explicitly aborted when receiving 'Commit' or - * 'Rollback'. - */ -void TrafficCop::ProcessInvalidStatement() { - if (single_statement_txn_) { - LOG_TRACE("SINGLE ABORT!"); - AbortQueryHelper(); - } else { // multi-statment txn - if (tcop_txn_state_.top().second != ResultType::ABORTED) { - tcop_txn_state_.top().second = ResultType::ABORTED; - } - } -} - -bool TrafficCop::BindParamsForCachePlan( - const std::vector> - ¶meters, - const size_t thread_id UNUSED_ATTRIBUTE) { - if (tcop_txn_state_.empty()) { - single_statement_txn_ = true; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto txn = txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) { - LOG_ERROR("Begin txn failed"); - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - // Run binder - auto bind_node_visitor = binder::BindNodeVisitor(tcop_txn_state_.top().first, - default_database_name_); - - std::vector param_values; - for (const std::unique_ptr ¶m : - parameters) { - if (!expression::ExpressionUtil::IsValidStaticExpression(param.get())) { - error_message_ = "Invalid Expression Type"; - return false; - } - param->Accept(&bind_node_visitor); - // TODO(Yuchen): need better check for nullptr argument - param_values.push_back(param->Evaluate(nullptr, nullptr, nullptr)); - } - if 
(param_values.size() > 0) { - statement_->GetPlanTree()->SetParameterValues(¶m_values); - } - SetParamVal(param_values); - return true; -} - -void TrafficCop::GetTableColumns(parser::TableRef *from_table, - std::vector &target_columns) { - if (from_table == nullptr) return; - - // Query derived table - if (from_table->select != NULL) { - for (auto &expr : from_table->select->select_list) { - if (expr->GetExpressionType() == ExpressionType::STAR) - GetTableColumns(from_table->select->from_table.get(), target_columns); - else - target_columns.push_back(catalog::Column(expr->GetValueType(), 0, - expr->GetExpressionName())); - } - } else if (from_table->list.empty()) { - if (from_table->join == NULL) { - auto columns = - static_cast( - catalog::Catalog::GetInstance()->GetTableWithName( - from_table->GetDatabaseName(), from_table->GetSchemaName(), - from_table->GetTableName(), GetCurrentTxnState().first)) - ->GetSchema() - ->GetColumns(); - target_columns.insert(target_columns.end(), columns.begin(), - columns.end()); - } else { - GetTableColumns(from_table->join->left.get(), target_columns); - GetTableColumns(from_table->join->right.get(), target_columns); - } - } - // Query has multiple tables. 
Recursively add all tables - else { - for (auto &table : from_table->list) { - GetTableColumns(table.get(), target_columns); - } - } -} - -std::vector TrafficCop::GenerateTupleDescriptor( - parser::SQLStatement *sql_stmt) { - std::vector tuple_descriptor; - if (sql_stmt->GetType() != StatementType::SELECT) return tuple_descriptor; - auto select_stmt = (parser::SelectStatement *)sql_stmt; - - // TODO: this is a hack which I don't have time to fix now - // but it replaces a worse hack that was here before - // What should happen here is that plan nodes should store - // the schema of their expected results and here we should just read - // it and put it in the tuple descriptor - - // Get the columns information and set up - // the columns description for the returned results - // Set up the table - std::vector all_columns; - - // Check if query only has one Table - // Example : SELECT * FROM A; - GetTableColumns(select_stmt->from_table.get(), all_columns); - - int count = 0; - for (auto &expr : select_stmt->select_list) { - count++; - if (expr->GetExpressionType() == ExpressionType::STAR) { - for (auto column : all_columns) { - tuple_descriptor.push_back( - GetColumnFieldForValueType(column.GetName(), column.GetType())); - } - } else { - std::string col_name; - if (expr->alias.empty()) { - col_name = expr->expr_name_.empty() - ? 
std::string("expr") + std::to_string(count) - : expr->expr_name_; - } else { - col_name = expr->alias; - } - tuple_descriptor.push_back( - GetColumnFieldForValueType(col_name, expr->GetValueType())); - } - } - - return tuple_descriptor; -} - -// TODO: move it to postgres_protocal_handler.cpp -FieldInfo TrafficCop::GetColumnFieldForValueType(std::string column_name, - type::TypeId column_type) { - PostgresValueType field_type; - size_t field_size; - switch (column_type) { - case type::TypeId::BOOLEAN: - case type::TypeId::TINYINT: { - field_type = PostgresValueType::BOOLEAN; - field_size = 1; - break; - } - case type::TypeId::SMALLINT: { - field_type = PostgresValueType::SMALLINT; - field_size = 2; - break; - } - case type::TypeId::INTEGER: { - field_type = PostgresValueType::INTEGER; - field_size = 4; - break; - } - case type::TypeId::BIGINT: { - field_type = PostgresValueType::BIGINT; - field_size = 8; - break; - } - case type::TypeId::DECIMAL: { - field_type = PostgresValueType::DOUBLE; - field_size = 8; - break; - } - case type::TypeId::VARCHAR: - case type::TypeId::VARBINARY: { - field_type = PostgresValueType::TEXT; - field_size = 255; - break; - } - case type::TypeId::DATE: { - field_type = PostgresValueType::DATE; - field_size = 4; - break; - } - case type::TypeId::TIMESTAMP: { - field_type = PostgresValueType::TIMESTAMPS; - field_size = 64; // FIXME: Bytes??? - break; - } - default: { - // Type not Identified - LOG_ERROR("Unrecognized field type '%s' for field '%s'", - TypeIdToString(column_type).c_str(), column_name.c_str()); - field_type = PostgresValueType::TEXT; - field_size = 255; - break; - } - } - // HACK: Convert the type into a oid_t - // This ugly and I don't like it one bit... 
- return std::make_tuple(column_name, static_cast(field_type), - field_size); -} - -ResultType TrafficCop::ExecuteStatement( - const std::shared_ptr &statement, - const std::vector ¶ms, UNUSED_ATTRIBUTE bool unnamed, - std::shared_ptr param_stats, - const std::vector &result_format, std::vector &result, - size_t thread_id) { - // TODO(Tianyi) Further simplify this API - if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { - stats::BackendStatsContext::GetInstance()->InitQueryMetric( - statement, std::move(param_stats)); - } - - LOG_TRACE("Execute Statement of name: %s", - statement->GetStatementName().c_str()); - LOG_TRACE("Execute Statement of query: %s", - statement->GetQueryString().c_str()); - LOG_TRACE("Execute Statement Plan:\n%s", - planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - LOG_TRACE("Execute Statement Query Type: %s", - statement->GetQueryTypeString().c_str()); - LOG_TRACE("----QueryType: %d--------", - static_cast(statement->GetQueryType())); - - try { - switch (statement->GetQueryType()) { - case QueryType::QUERY_BEGIN: { - return BeginQueryHelper(thread_id); - } - case QueryType::QUERY_COMMIT: { - return CommitQueryHelper(); - } - case QueryType::QUERY_ROLLBACK: { - return AbortQueryHelper(); - } - default: - // The statement may be out of date - // It needs to be replan - if (statement->GetNeedsReplan()) { - // TODO(Tianyi) Move Statement Replan into Statement's method - // to increase coherence - auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, default_database_name_); - bind_node_visitor.BindNameToNode( - statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = optimizer_->BuildPelotonPlanTree( - statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); - statement->SetPlanTree(plan); - statement->SetNeedsReplan(true); - } - - ExecuteHelper(statement->GetPlanTree(), params, result, result_format, - thread_id); - if 
(GetQueuing()) { - return ResultType::QUEUING; - } else { - return ExecuteStatementGetResult(); - } - } - } catch (Exception &e) { - error_message_ = e.what(); - return ResultType::FAILURE; - } -} - -} // namespace tcop -} // namespace peloton diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index 9549c794a91..234eabf4374 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -23,7 +23,7 @@ #include "optimizer/rule.h" #include "parser/postgresparser.h" #include "planner/plan_util.h" -#include "traffic_cop/traffic_cop.h" +#include "traffic_cop/tcop.h" namespace peloton { From deaa03b3808b63d8254b5d5a2fcacd93c528fbc0 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 29 Jun 2018 14:52:05 -0400 Subject: [PATCH 28/48] Fix compilation --- src/include/network/marshal.h | 147 +------------------------------- src/network/marshal.cpp | 152 ++++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 143 deletions(-) diff --git a/src/include/network/marshal.h b/src/include/network/marshal.h index f46ef77605e..621c488b541 100644 --- a/src/include/network/marshal.h +++ b/src/include/network/marshal.h @@ -163,157 +163,18 @@ extern void GetStringToken(InputPacket *pkt, std::string &result); // TODO(Tianyu): These dumb things are here because copy_executor somehow calls // our network layer. This should NOT be the case. Will remove. 
-size_t OldReadParamType( - InputPacket *pkt, int num_params, std::vector ¶m_types) { - auto begin = pkt->ptr; - // get the type of each parameter - for (int i = 0; i < num_params; i++) { - int param_type = PacketGetInt(pkt, 4); - param_types[i] = param_type; - } - auto end = pkt->ptr; - return end - begin; -} +extern size_t OldReadParamType( + InputPacket *pkt, int num_params, std::vector ¶m_types); size_t OldReadParamFormat(InputPacket *pkt, int num_params_format, - std::vector &formats) { - auto begin = pkt->ptr; - // get the format of each parameter - for (int i = 0; i < num_params_format; i++) { - formats[i] = PacketGetInt(pkt, 2); - } - auto end = pkt->ptr; - return end - begin; -} + std::vector &formats); // For consistency, this function assumes the input vectors has the correct size size_t OldReadParamValue( InputPacket *pkt, int num_params, std::vector ¶m_types, std::vector> &bind_parameters, - std::vector ¶m_values, std::vector &formats) { - auto begin = pkt->ptr; - ByteBuf param; - for (int param_idx = 0; param_idx < num_params; param_idx++) { - int param_len = PacketGetInt(pkt, 4); - // BIND packet NULL parameter case - if (param_len == -1) { - // NULL mode - auto peloton_type = PostgresValueTypeToPelotonValueType( - static_cast(param_types[param_idx])); - bind_parameters[param_idx] = - std::make_pair(peloton_type, std::string("")); - param_values[param_idx] = - type::ValueFactory::GetNullValueByType(peloton_type); - } else { - PacketGetBytes(pkt, param_len, param); - - if (formats[param_idx] == 0) { - // TEXT mode - std::string param_str = std::string(std::begin(param), std::end(param)); - bind_parameters[param_idx] = - std::make_pair(type::TypeId::VARCHAR, param_str); - if ((unsigned int)param_idx >= param_types.size() || - PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx]) == - type::TypeId::VARCHAR) { - param_values[param_idx] = - type::ValueFactory::GetVarcharValue(param_str); - } else { - param_values[param_idx] = - 
(type::ValueFactory::GetVarcharValue(param_str)) - .CastAs(PostgresValueTypeToPelotonValueType( - (PostgresValueType)param_types[param_idx])); - } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != - type::TypeId::INVALID); - } else { - // BINARY mode - PostgresValueType pg_value_type = - static_cast(param_types[param_idx]); - LOG_TRACE("Postgres Protocol Conversion [param_idx=%d]", param_idx); - switch (pg_value_type) { - case PostgresValueType::TINYINT: { - int8_t int_val = 0; - for (size_t i = 0; i < sizeof(int8_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::TINYINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetTinyIntValue(int_val).Copy(); - break; - } - case PostgresValueType::SMALLINT: { - int16_t int_val = 0; - for (size_t i = 0; i < sizeof(int16_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetSmallIntValue(int_val).Copy(); - break; - } - case PostgresValueType::INTEGER: { - int32_t int_val = 0; - for (size_t i = 0; i < sizeof(int32_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::INTEGER, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetIntegerValue(int_val).Copy(); - break; - } - case PostgresValueType::BIGINT: { - int64_t int_val = 0; - for (size_t i = 0; i < sizeof(int64_t); ++i) { - int_val = (int_val << 8) | param[i]; - } - bind_parameters[param_idx] = - std::make_pair(type::TypeId::BIGINT, std::to_string(int_val)); - param_values[param_idx] = - type::ValueFactory::GetBigIntValue(int_val).Copy(); - break; - } - case PostgresValueType::DOUBLE: { - double float_val = 0; - unsigned long buf = 0; - for (size_t i = 0; i < sizeof(double); ++i) { - buf = (buf << 8) | param[i]; - } - 
PELOTON_MEMCPY(&float_val, &buf, sizeof(double)); - bind_parameters[param_idx] = std::make_pair( - type::TypeId::DECIMAL, std::to_string(float_val)); - param_values[param_idx] = - type::ValueFactory::GetDecimalValue(float_val).Copy(); - break; - } - case PostgresValueType::VARBINARY: { - bind_parameters[param_idx] = std::make_pair( - type::TypeId::VARBINARY, - std::string(reinterpret_cast(¶m[0]), param_len)); - param_values[param_idx] = type::ValueFactory::GetVarbinaryValue( - ¶m[0], param_len, true); - break; - } - default: { - LOG_ERROR( - "Binary Postgres protocol does not support data type '%s' [%d]", - PostgresValueTypeToString(pg_value_type).c_str(), - param_types[param_idx]); - break; - } - } - PELOTON_ASSERT(param_values[param_idx].GetTypeId() != - type::TypeId::INVALID); - } - } - } - auto end = pkt->ptr; - return end - begin; -} + std::vector ¶m_values, std::vector &formats); } // namespace network } // namespace peloton diff --git a/src/network/marshal.cpp b/src/network/marshal.cpp index 314dca1d5ea..6105daa5bba 100644 --- a/src/network/marshal.cpp +++ b/src/network/marshal.cpp @@ -162,5 +162,157 @@ void PacketPutCbytes(OutputPacket *pkt, const uchar *b, int len) { pkt->len += len; } +size_t OldReadParamType( + InputPacket *pkt, int num_params, std::vector ¶m_types) { + auto begin = pkt->ptr; + // get the type of each parameter + for (int i = 0; i < num_params; i++) { + int param_type = PacketGetInt(pkt, 4); + param_types[i] = param_type; + } + auto end = pkt->ptr; + return end - begin; +} + +size_t OldReadParamFormat(InputPacket *pkt, + int num_params_format, + std::vector &formats) { + auto begin = pkt->ptr; + // get the format of each parameter + for (int i = 0; i < num_params_format; i++) { + formats[i] = PacketGetInt(pkt, 2); + } + auto end = pkt->ptr; + return end - begin; +} + +// For consistency, this function assumes the input vectors has the correct size +size_t OldReadParamValue( + InputPacket *pkt, int num_params, std::vector ¶m_types, + 
std::vector> &bind_parameters, + std::vector ¶m_values, std::vector &formats) { + auto begin = pkt->ptr; + ByteBuf param; + for (int param_idx = 0; param_idx < num_params; param_idx++) { + int param_len = PacketGetInt(pkt, 4); + // BIND packet NULL parameter case + if (param_len == -1) { + // NULL mode + auto peloton_type = PostgresValueTypeToPelotonValueType( + static_cast(param_types[param_idx])); + bind_parameters[param_idx] = + std::make_pair(peloton_type, std::string("")); + param_values[param_idx] = + type::ValueFactory::GetNullValueByType(peloton_type); + } else { + PacketGetBytes(pkt, param_len, param); + + if (formats[param_idx] == 0) { + // TEXT mode + std::string param_str = std::string(std::begin(param), std::end(param)); + bind_parameters[param_idx] = + std::make_pair(type::TypeId::VARCHAR, param_str); + if ((unsigned int)param_idx >= param_types.size() || + PostgresValueTypeToPelotonValueType( + (PostgresValueType)param_types[param_idx]) == + type::TypeId::VARCHAR) { + param_values[param_idx] = + type::ValueFactory::GetVarcharValue(param_str); + } else { + param_values[param_idx] = + (type::ValueFactory::GetVarcharValue(param_str)) + .CastAs(PostgresValueTypeToPelotonValueType( + (PostgresValueType)param_types[param_idx])); + } + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); + } else { + // BINARY mode + PostgresValueType pg_value_type = + static_cast(param_types[param_idx]); + LOG_TRACE("Postgres Protocol Conversion [param_idx=%d]", param_idx); + switch (pg_value_type) { + case PostgresValueType::TINYINT: { + int8_t int_val = 0; + for (size_t i = 0; i < sizeof(int8_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::TINYINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetTinyIntValue(int_val).Copy(); + break; + } + case PostgresValueType::SMALLINT: { + int16_t int_val = 0; + for (size_t i = 0; i < sizeof(int16_t); ++i) { + 
int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::SMALLINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetSmallIntValue(int_val).Copy(); + break; + } + case PostgresValueType::INTEGER: { + int32_t int_val = 0; + for (size_t i = 0; i < sizeof(int32_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::INTEGER, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetIntegerValue(int_val).Copy(); + break; + } + case PostgresValueType::BIGINT: { + int64_t int_val = 0; + for (size_t i = 0; i < sizeof(int64_t); ++i) { + int_val = (int_val << 8) | param[i]; + } + bind_parameters[param_idx] = + std::make_pair(type::TypeId::BIGINT, std::to_string(int_val)); + param_values[param_idx] = + type::ValueFactory::GetBigIntValue(int_val).Copy(); + break; + } + case PostgresValueType::DOUBLE: { + double float_val = 0; + unsigned long buf = 0; + for (size_t i = 0; i < sizeof(double); ++i) { + buf = (buf << 8) | param[i]; + } + PELOTON_MEMCPY(&float_val, &buf, sizeof(double)); + bind_parameters[param_idx] = std::make_pair( + type::TypeId::DECIMAL, std::to_string(float_val)); + param_values[param_idx] = + type::ValueFactory::GetDecimalValue(float_val).Copy(); + break; + } + case PostgresValueType::VARBINARY: { + bind_parameters[param_idx] = std::make_pair( + type::TypeId::VARBINARY, + std::string(reinterpret_cast(¶m[0]), param_len)); + param_values[param_idx] = type::ValueFactory::GetVarbinaryValue( + ¶m[0], param_len, true); + break; + } + default: { + LOG_ERROR( + "Binary Postgres protocol does not support data type '%s' [%d]", + PostgresValueTypeToString(pg_value_type).c_str(), + param_types[param_idx]); + break; + } + } + PELOTON_ASSERT(param_values[param_idx].GetTypeId() != + type::TypeId::INVALID); + } + } + } + auto end = pkt->ptr; + return end - begin; +} + } // namespace network } // namespace peloton From 
86fc600c1ce1930542a9b7bce7b130f78e710534 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 29 Jun 2018 17:20:59 -0400 Subject: [PATCH 29/48] Fix test cases. There are still some issues with test failure. --- src/include/traffic_cop/tcop.h | 28 +- test/binder/binder_test.cpp | 28 +- test/codegen/update_translator_test.cpp | 1 - test/executor/copy_test.cpp | 33 +- test/executor/create_index_test.cpp | 82 ++--- test/executor/index_scan_test.cpp | 1 - test/executor/update_test.cpp | 131 ++++---- test/include/sql/testing_sql_util.h | 3 +- test/network/exception_test.cpp | 2 - test/network/prepare_stmt_test.cpp | 1 - test/network/select_all_test.cpp | 2 - test/network/simple_query_test.cpp | 2 - test/network/ssl_test.cpp | 2 - test/optimizer/old_optimizer_test.cpp | 425 ++++++++++++------------ test/optimizer/optimizer_test.cpp | 131 +++++--- test/sql/aggregate_sql_test.cpp | 1 - test/sql/testing_sql_util.cpp | 143 +++++--- test/statistics/stats_test.cpp | 1 - test/statistics/testing_stats_util.cpp | 151 ++++----- 19 files changed, 637 insertions(+), 531 deletions(-) diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index be3a6007904..12a94bd6d58 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -18,6 +18,7 @@ #include "parser/postgresparser.h" #include "parser/sql_statement.h" #include "common/statement_cache.h" +#include "optimizer/optimizer.h" namespace peloton { namespace tcop { @@ -35,7 +36,7 @@ struct ClientProcessState { // This save currnet statement in the traffic cop std::shared_ptr statement_; // The optimizer used for this connection - std::unique_ptr optimizer_; + std::unique_ptr optimizer_{new optimizer::Optimizer()}; // flag of single statement txn bool single_statement_txn_ = false; std::vector result_format_; @@ -47,7 +48,7 @@ struct ClientProcessState { std::string skipped_query_string_; QueryType skipped_query_type_ = QueryType::QUERY_INVALID; StatementCache statement_cache_; - int 
rows_affected_; + int rows_affected_ = 0; executor::ExecutionResult p_status_; // TODO(Tianyu): This is vile, get rid of this @@ -59,6 +60,28 @@ struct ClientProcessState { } return tcop_txn_state_.top(); } + + // TODO(Tianyu): This is also vile, get rid of this. This is only used for testing + void Reset() { + thread_id_ = 0; + is_queuing_ = false; + error_message_ = ""; + db_name_ = DEFAULT_DB_NAME; + param_values_.clear(); + statement_.reset(); + optimizer_.reset(new optimizer::Optimizer()); + single_statement_txn_ = false; + result_format_.clear(); + result_.clear(); + tcop_txn_state_ = std::stack(); + txn_state_ = NetworkTransactionStateType::INVALID; + skipped_stmt_ = false; + skipped_query_string_ = ""; + skipped_query_type_ = QueryType::QUERY_INVALID; + statement_cache_.Clear(); + rows_affected_ = 0; + p_status_ = executor::ExecutionResult(); + } }; // TODO(Tianyu): We use an instance here in expectation that instance variables @@ -114,7 +137,6 @@ class Tcop { void ProcessInvalidStatement(ClientProcessState &state); - private: ResultType CommitQueryHelper(ClientProcessState &state); ResultType BeginQueryHelper(ClientProcessState &state); ResultType AbortQueryHelper(ClientProcessState &state); diff --git a/test/binder/binder_test.cpp b/test/binder/binder_test.cpp index b5b266c21cf..bef408159c4 100644 --- a/test/binder/binder_test.cpp +++ b/test/binder/binder_test.cpp @@ -22,11 +22,11 @@ #include "expression/tuple_value_expression.h" #include "optimizer/optimizer.h" #include "parser/postgresparser.h" -#include "traffic_cop/traffic_cop.h" #include "executor/testing_executor_util.h" #include "sql/testing_sql_util.h" #include "type/value_factory.h" +#include "traffic_cop/tcop.h" using std::make_shared; using std::make_tuple; @@ -60,10 +60,12 @@ void SetupTables(std::string database_name) { LOG_INFO("database %s created!", database_name.c_str()); auto &parser = parser::PostgresParser::GetInstance(); - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - 
traffic_cop.SetDefaultDatabaseName(database_name); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; + state.db_name_ = database_name; + auto callback = [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }; optimizer::Optimizer optimizer; @@ -72,7 +74,7 @@ void SetupTables(std::string database_name) { for (auto &sql : createTableSQLs) { LOG_INFO("%s", sql.c_str()); txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); vector params; vector result; @@ -86,18 +88,18 @@ void SetupTables(std::string database_name) { statement->SetPlanTree( optimizer.BuildPelotonPlanTree(parse_tree_list, txn)); + state.statement_ = std::move(statement); TestingSQLUtil::counter_.store(1); - auto status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, - result, result_format); - if (traffic_cop.GetQueuing()) { + auto status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Table create result: %s", ResultTypeToString(status.m_result).c_str()); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); } } diff --git a/test/codegen/update_translator_test.cpp b/test/codegen/update_translator_test.cpp index b14f4506384..a4b7abc60f1 100644 --- a/test/codegen/update_translator_test.cpp +++ b/test/codegen/update_translator_test.cpp @@ -26,7 +26,6 @@ #include "planner/create_plan.h" #include "planner/seq_scan_plan.h" #include "planner/plan_util.h" -#include "traffic_cop/traffic_cop.h" namespace peloton { namespace test { diff --git 
a/test/executor/copy_test.cpp b/test/executor/copy_test.cpp index 9b9291f4111..7e948c17e8b 100644 --- a/test/executor/copy_test.cpp +++ b/test/executor/copy_test.cpp @@ -25,7 +25,6 @@ #include "optimizer/rule.h" #include "parser/postgresparser.h" #include "planner/seq_scan_plan.h" -#include "traffic_cop/traffic_cop.h" #include "gtest/gtest.h" #include "statistics/testing_stats_util.h" @@ -49,14 +48,16 @@ TEST_F(CopyTests, Copying) { std::unique_ptr optimizer( new optimizer::Optimizer); - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); + auto &traffic_cop = tcop::Tcop::GetInstance(); + auto callback = [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }; + tcop::ClientProcessState state; // Create a table without primary key TestingStatsUtil::CreateTable(false); txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); std::string short_string = "eeeeeeeeee"; std::string long_string = short_string + short_string + short_string + short_string + short_string + @@ -89,18 +90,18 @@ TEST_F(CopyTests, Copying) { // Execute insert auto statement = TestingStatsUtil::GetInsertStmt(12345, insert_str); std::vector params; - std::vector result_format(statement->GetTupleDescriptor().size(), 0); - std::vector result; - + std::vector result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - executor::ExecutionResult status = traffic_cop.ExecuteHelper( - statement->GetPlanTree(), params, result, result_format); - - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + executor::ExecutionResult status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - 
traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } EXPECT_EQ(status.m_result, peloton::ResultType::SUCCESS); @@ -108,7 +109,7 @@ TEST_F(CopyTests, Copying) { ResultTypeToString(status.m_result).c_str()); } LOG_TRACE("Tuples inserted!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // Now Copying end-to-end LOG_TRACE("Copying a table..."); diff --git a/test/executor/create_index_test.cpp b/test/executor/create_index_test.cpp index 9d53642935e..658d1067164 100644 --- a/test/executor/create_index_test.cpp +++ b/test/executor/create_index_test.cpp @@ -12,7 +12,6 @@ #include #include "sql/testing_sql_util.h" -#include "traffic_cop/traffic_cop.h" #include "binder/bind_node_visitor.h" #include "catalog/catalog.h" @@ -32,7 +31,6 @@ #include "planner/insert_plan.h" #include "planner/plan_util.h" #include "planner/update_plan.h" -#include "traffic_cop/traffic_cop.h" #include "gtest/gtest.h" @@ -56,18 +54,20 @@ TEST_F(CreateIndexTests, CreatingIndex) { std::unique_ptr optimizer; optimizer.reset(new optimizer::Optimizer); - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); + auto &traffic_cop = tcop::Tcop::GetInstance(); + auto callback = [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }; // Create a table first txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + tcop::ClientProcessState state; + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Creating table"); LOG_INFO( "Query: CREATE TABLE department_table(dept_id INT PRIMARY KEY,student_id " "INT, dept_name TEXT);"); - std::unique_ptr statement; + std::shared_ptr statement; statement.reset(new Statement("CREATE", "CREATE TABLE department_table(dept_id INT " 
"PRIMARY KEY, student_id INT, dept_name " @@ -95,28 +95,31 @@ TEST_F(CreateIndexTests, CreatingIndex) { std::vector result; LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - std::vector result_format; - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + std::vector result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - executor::ExecutionResult status = traffic_cop.ExecuteHelper( - statement->GetPlanTree(), params, result, result_format); + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + executor::ExecutionResult status = traffic_cop.ExecuteHelper(state, callback); - if (traffic_cop.GetQueuing()) { + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Table Created"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); txn = txn_manager.BeginTransaction(); // Inserting a tuple end-to-end - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Inserting a tuple..."); LOG_INFO( "Query: INSERT INTO department_table(dept_id,student_id ,dept_name) " @@ -144,26 +147,28 @@ TEST_F(CreateIndexTests, CreatingIndex) { planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); LOG_INFO("Executing plan..."); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); - + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple inserted!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // Now Updating end-to-end txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Creating and Index"); LOG_INFO("Query: CREATE INDEX saif ON department_table (student_id);"); statement.reset(new Statement( @@ -186,22 +191,23 @@ TEST_F(CreateIndexTests, CreatingIndex) { planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); LOG_INFO("Executing plan..."); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); - + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("INDEX CREATED!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); txn = txn_manager.BeginTransaction(); auto target_table_ = catalog::Catalog::GetInstance()->GetTableWithName( diff --git a/test/executor/index_scan_test.cpp b/test/executor/index_scan_test.cpp index c22f22bcb89..8be5e406ce4 100644 --- a/test/executor/index_scan_test.cpp +++ b/test/executor/index_scan_test.cpp @@ -32,7 +32,6 @@ #include "planner/index_scan_plan.h" #include "planner/insert_plan.h" #include "storage/data_table.h" -#include "traffic_cop/traffic_cop.h" #include "type/value_factory.h" using ::testing::NotNull; diff --git a/test/executor/update_test.cpp b/test/executor/update_test.cpp index faa5526efea..9f0f16d1c1f 100644 --- a/test/executor/update_test.cpp +++ b/test/executor/update_test.cpp @@ -46,7 +46,6 @@ #include "planner/update_plan.h" #include "storage/data_table.h" #include "storage/tile_group_factory.h" -#include "traffic_cop/traffic_cop.h" #include "type/value.h" #include "type/value_factory.h" @@ -164,9 +163,12 @@ TEST_F(UpdateTests, UpdatingOld) { std::unique_ptr optimizer( new optimizer::Optimizer); - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); + auto &traffic_cop = tcop::Tcop::GetInstance(); + auto callback = [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }; + tcop::ClientProcessState state; + // Create a table first LOG_INFO("Creating a table..."); auto id_column = catalog::Column( @@ -199,17 +201,17 @@ TEST_F(UpdateTests, UpdatingOld) { // Inserting a tuple end-to-end txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Inserting a tuple..."); LOG_INFO( "Query: INSERT INTO department_table(dept_id,manager_id,dept_name) " "VALUES (1,12,'hello_1');"); - 
std::unique_ptr statement; - statement.reset(new Statement("INSERT", - "INSERT INTO " - "department_table(dept_id,manager_id,dept_name)" - " VALUES (1,12,'hello_1');")); + auto statement = std::make_shared( + "INSERT", + "INSERT INTO " + "department_table(dept_id,manager_id,dept_name)" + " VALUES (1,12,'hello_1');"); auto &peloton_parser = parser::PostgresParser::GetInstance(); LOG_INFO("Building parse tree..."); auto insert_stmt = peloton_parser.BuildParseTree( @@ -228,31 +230,33 @@ TEST_F(UpdateTests, UpdatingOld) { statement->SetPlanTree(optimizer->BuildPelotonPlanTree(insert_stmt, txn)); LOG_INFO("Building plan tree completed!"); std::vector params; - std::vector result; LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - std::vector result_format; - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + std::vector result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - executor::ExecutionResult status = traffic_cop.ExecuteHelper( - statement->GetPlanTree(), params, result, result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + executor::ExecutionResult status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple inserted!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); LOG_INFO("%s", table->GetInfo().c_str()); // Now Updating end-to-end txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Updating a tuple..."); LOG_INFO( @@ -277,25 +281,29 @@ TEST_F(UpdateTests, UpdatingOld) { LOG_INFO("Building plan tree completed!"); LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple Updated!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); LOG_INFO("%s", table->GetInfo().c_str()); txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Updating another tuple..."); LOG_INFO( @@ -322,25 +330,29 @@ TEST_F(UpdateTests, UpdatingOld) { LOG_INFO("Building plan tree completed!"); LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple Updated!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); LOG_INFO("%s", table->GetInfo().c_str()); txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Updating primary key..."); LOG_INFO("Query: UPDATE department_table SET dept_id = 2 WHERE dept_id = 1"); statement.reset(new Statement( @@ -361,26 +373,30 @@ TEST_F(UpdateTests, UpdatingOld) { LOG_INFO("Building plan tree completed!"); LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple Updated!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); LOG_INFO("%s", table->GetInfo().c_str()); // Deleting now txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Deleting a tuple..."); LOG_INFO("Query: DELETE FROM department_table WHERE dept_name = 'CS'"); @@ -403,20 +419,23 @@ TEST_F(UpdateTests, UpdatingOld) { LOG_INFO("Building plan tree completed!"); LOG_INFO("Executing plan...\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, callback); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple deleted!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // free the database just created txn = txn_manager.BeginTransaction(); diff --git a/test/include/sql/testing_sql_util.h b/test/include/sql/testing_sql_util.h index 762ffff4f89..87f2a97e0a8 100644 --- a/test/include/sql/testing_sql_util.h +++ b/test/include/sql/testing_sql_util.h @@ -15,7 +15,7 @@ #include #include "common/statement.h" -#include "traffic_cop/traffic_cop.h" +#include "traffic_cop/tcop.h" namespace peloton { @@ -95,7 +95,6 @@ class TestingSQLUtil { static int GetRandomInteger(const int lower_bound, const int upper_bound); static void UtilTestTaskCallback(void *arg); - static tcop::TrafficCop traffic_cop_; static std::atomic_int counter_; // inline static void SetTrafficCopCounter() { // counter_.store(1); diff --git a/test/network/exception_test.cpp b/test/network/exception_test.cpp index 08ecc98c9a5..3120f79e063 100644 --- a/test/network/exception_test.cpp +++ b/test/network/exception_test.cpp @@ -18,8 +18,6 @@ #include "gtest/gtest.h" #include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" -#include "network/postgres_protocol_handler.h" -#include "network/protocol_handler_factory.h" #include "util/string_util.h" namespace peloton { diff --git a/test/network/prepare_stmt_test.cpp b/test/network/prepare_stmt_test.cpp index 07e7ebb76c8..cccd19abd78 100644 --- a/test/network/prepare_stmt_test.cpp +++ b/test/network/prepare_stmt_test.cpp @@ -15,7 +15,6 @@ #include "common/logger.h" #include "gtest/gtest.h" #include "network/peloton_server.h" -#include "network/postgres_protocol_handler.h" #include "util/string_util.h" #include "network/network_io_wrapper_factory.h" diff --git a/test/network/select_all_test.cpp b/test/network/select_all_test.cpp index 1f5552b7aa9..9c50b650425 100644 --- a/test/network/select_all_test.cpp +++ b/test/network/select_all_test.cpp @@ 
-14,11 +14,9 @@ #include "gtest/gtest.h" #include "common/logger.h" #include "network/peloton_server.h" -#include "network/protocol_handler_factory.h" #include "network/network_io_wrapper_factory.h" #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ -#include "network/postgres_protocol_handler.h" namespace peloton { namespace test { diff --git a/test/network/simple_query_test.cpp b/test/network/simple_query_test.cpp index 8e2409f2621..94b1a735bd2 100644 --- a/test/network/simple_query_test.cpp +++ b/test/network/simple_query_test.cpp @@ -14,10 +14,8 @@ #include "gtest/gtest.h" #include "common/logger.h" #include "network/peloton_server.h" -#include "network/protocol_handler_factory.h" #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ -#include "network/postgres_protocol_handler.h" #include "network/network_io_wrapper_factory.h" #define NUM_THREADS 1 diff --git a/test/network/ssl_test.cpp b/test/network/ssl_test.cpp index b9399ce7757..033b641494a 100644 --- a/test/network/ssl_test.cpp +++ b/test/network/ssl_test.cpp @@ -16,8 +16,6 @@ #include "gtest/gtest.h" #include "network/network_io_wrapper_factory.h" #include "network/peloton_server.h" -#include "network/postgres_protocol_handler.h" -#include "network/protocol_handler_factory.h" #include "peloton_config.h" #include "util/string_util.h" diff --git a/test/optimizer/old_optimizer_test.cpp b/test/optimizer/old_optimizer_test.cpp index 92949cc8521..f2ffb0af31c 100644 --- a/test/optimizer/old_optimizer_test.cpp +++ b/test/optimizer/old_optimizer_test.cpp @@ -21,7 +21,6 @@ #include "planner/plan_util.h" #include "planner/update_plan.h" #include "sql/testing_sql_util.h" -#include "traffic_cop/traffic_cop.h" namespace peloton { namespace test { @@ -38,218 +37,218 @@ using namespace optimizer; class OldOptimizerTests : public PelotonTest {}; -// Test whether update stament will use index scan plan -// TODO: Split the tests into separate test cases. 
-TEST_F(OldOptimizerTests, UpdateDelWithIndexScanTest) { - LOG_TRACE("Bootstrapping..."); - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto txn = txn_manager.BeginTransaction(); - catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn); - txn_manager.CommitTransaction(txn); - - LOG_TRACE("Bootstrapping completed!"); - - optimizer::Optimizer optimizer; - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); - - // Create a table first - txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); - - LOG_TRACE("Creating table"); - LOG_TRACE( - "Query: CREATE TABLE department_table(dept_id INT PRIMARY KEY,student_id " - "INT, dept_name TEXT);"); - std::unique_ptr statement; - statement.reset(new Statement("CREATE", - "CREATE TABLE department_table(dept_id INT " - "PRIMARY KEY, student_id INT, dept_name " - "TEXT);")); - - auto &peloton_parser = parser::PostgresParser::GetInstance(); - - auto create_stmt = peloton_parser.BuildParseTree( - "CREATE TABLE department_table(dept_id INT PRIMARY KEY, student_id INT, " - "dept_name TEXT);"); - - auto parse_tree = create_stmt->GetStatement(0); - auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - statement->SetPlanTree(optimizer.BuildPelotonPlanTree(create_stmt, txn)); - - std::vector params; - std::vector result; - LOG_TRACE("Query Plan:\n%s", - planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - std::vector result_format; - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); - TestingSQLUtil::counter_.store(1); - executor::ExecutionResult status = traffic_cop.ExecuteHelper( - statement->GetPlanTree(), params, result, result_format); - if (traffic_cop.GetQueuing()) { - TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - 
status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); - } - LOG_TRACE("Statement executed. Result: %s", - ResultTypeToString(status.m_result).c_str()); - LOG_TRACE("Table Created"); - traffic_cop.CommitQueryHelper(); - - txn = txn_manager.BeginTransaction(); - // Inserting a tuple end-to-end - traffic_cop.SetTcopTxnState(txn); - LOG_TRACE("Inserting a tuple..."); - LOG_TRACE( - "Query: INSERT INTO department_table(dept_id,student_id ,dept_name) " - "VALUES (1,52,'hello_1');"); - statement.reset(new Statement("INSERT", - "INSERT INTO department_table(dept_id, " - "student_id, dept_name) VALUES " - "(1,52,'hello_1');")); - - auto insert_stmt = peloton_parser.BuildParseTree( - "INSERT INTO department_table(dept_id,student_id,dept_name) VALUES " - "(1,52,'hello_1');"); - - parse_tree = insert_stmt->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - statement->SetPlanTree(optimizer.BuildPelotonPlanTree(insert_stmt, txn)); - - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); - TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { - TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); - } - LOG_TRACE("Statement executed. 
Result: %s", - ResultTypeToString(status.m_result).c_str()); - LOG_TRACE("Tuple inserted!"); - traffic_cop.CommitQueryHelper(); - - // Now Create index - txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); - LOG_TRACE("Creating and Index"); - LOG_TRACE("Query: CREATE INDEX saif ON department_table (student_id);"); - statement.reset(new Statement( - "CREATE", "CREATE INDEX saif ON department_table (student_id);")); - - auto update_stmt = peloton_parser.BuildParseTree( - "CREATE INDEX saif ON department_table (student_id);"); - - parse_tree = update_stmt->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - statement->SetPlanTree(optimizer.BuildPelotonPlanTree(update_stmt, txn)); - - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); - TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { - TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); - } - LOG_TRACE("Statement executed. 
Result: %s", - ResultTypeToString(status.m_result).c_str()); - LOG_TRACE("INDEX CREATED!"); - traffic_cop.CommitQueryHelper(); - - txn = txn_manager.BeginTransaction(); - auto target_table_ = catalog::Catalog::GetInstance()->GetTableWithName( - DEFAULT_DB_NAME, DEFAULT_SCHEMA_NAME, "department_table", txn); - // Expected 1 , Primary key index + created index - EXPECT_EQ(target_table_->GetIndexCount(), 2); - txn_manager.CommitTransaction(txn); - - txn = txn_manager.BeginTransaction(); - // Test update tuple with index scan - LOG_TRACE("Updating a tuple..."); - LOG_TRACE( - "Query: UPDATE department_table SET dept_name = 'CS' WHERE student_id = " - "52"); - update_stmt = peloton_parser.BuildParseTree( - "UPDATE department_table SET dept_name = 'CS' WHERE student_id = 52"); - - parse_tree = update_stmt->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - auto update_plan = optimizer.BuildPelotonPlanTree(update_stmt, txn); - txn_manager.CommitTransaction(txn); - - txn = txn_manager.BeginTransaction(); - // Check scan plan - ASSERT_FALSE(update_plan == nullptr); - EXPECT_EQ(update_plan->GetPlanNodeType(), PlanNodeType::UPDATE); - auto &update_scan_plan = update_plan->GetChildren().front(); - EXPECT_EQ(update_scan_plan->GetPlanNodeType(), PlanNodeType::INDEXSCAN); - - update_stmt = peloton_parser.BuildParseTree( - "UPDATE department_table SET dept_name = 'CS' WHERE dept_name = 'CS'"); - - parse_tree = update_stmt->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - update_plan = optimizer.BuildPelotonPlanTree(update_stmt, txn); - EXPECT_EQ(update_plan->GetChildren().front()->GetPlanNodeType(), - PlanNodeType::SEQSCAN); - txn_manager.CommitTransaction(txn); - - txn = txn_manager.BeginTransaction(); - // Test delete tuple with index scan - LOG_TRACE("Deleting a tuple..."); - LOG_TRACE("Query: DELETE FROM 
department_table WHERE student_id = 52"); - auto delete_stmt = peloton_parser.BuildParseTree( - "DELETE FROM department_table WHERE student_id = 52"); - - parse_tree = delete_stmt->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - auto del_plan = optimizer.BuildPelotonPlanTree(delete_stmt, txn); - txn_manager.CommitTransaction(txn); - - // Check scan plan - EXPECT_EQ(del_plan->GetPlanNodeType(), PlanNodeType::DELETE); - auto &del_scan_plan = del_plan->GetChildren().front(); - EXPECT_EQ(del_scan_plan->GetPlanNodeType(), PlanNodeType::INDEXSCAN); - del_plan = nullptr; - - txn = txn_manager.BeginTransaction(); - // Test delete tuple with seq scan - auto delete_stmt_seq = peloton_parser.BuildParseTree( - "DELETE FROM department_table WHERE dept_name = 'CS'"); - - parse_tree = delete_stmt_seq->GetStatement(0); - bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); - bind_node_visitor.BindNameToNode(parse_tree); - - auto del_plan_seq = optimizer.BuildPelotonPlanTree(delete_stmt_seq, txn); - auto &del_scan_plan_seq = del_plan_seq->GetChildren().front(); - txn_manager.CommitTransaction(txn); - EXPECT_EQ(del_scan_plan_seq->GetPlanNodeType(), PlanNodeType::SEQSCAN); - - // free the database just created - txn = txn_manager.BeginTransaction(); - catalog::Catalog::GetInstance()->DropDatabaseWithName(DEFAULT_DB_NAME, txn); - txn_manager.CommitTransaction(txn); -} +//// Test whether update stament will use index scan plan +//// TODO: Split the tests into separate test cases. 
+//TEST_F(OldOptimizerTests, UpdateDelWithIndexScanTest) { +// LOG_TRACE("Bootstrapping..."); +// auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); +// auto txn = txn_manager.BeginTransaction(); +// catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn); +// txn_manager.CommitTransaction(txn); +// +// LOG_TRACE("Bootstrapping completed!"); +// +// optimizer::Optimizer optimizer; +// auto &traffic_cop = tcop::TrafficCop::GetInstance(); +// traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, +// &TestingSQLUtil::counter_); +// +// // Create a table first +// txn = txn_manager.BeginTransaction(); +// traffic_cop.SetTcopTxnState(txn); +// +// LOG_TRACE("Creating table"); +// LOG_TRACE( +// "Query: CREATE TABLE department_table(dept_id INT PRIMARY KEY,student_id " +// "INT, dept_name TEXT);"); +// std::unique_ptr statement; +// statement.reset(new Statement("CREATE", +// "CREATE TABLE department_table(dept_id INT " +// "PRIMARY KEY, student_id INT, dept_name " +// "TEXT);")); +// +// auto &peloton_parser = parser::PostgresParser::GetInstance(); +// +// auto create_stmt = peloton_parser.BuildParseTree( +// "CREATE TABLE department_table(dept_id INT PRIMARY KEY, student_id INT, " +// "dept_name TEXT);"); +// +// auto parse_tree = create_stmt->GetStatement(0); +// auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// statement->SetPlanTree(optimizer.BuildPelotonPlanTree(create_stmt, txn)); +// +// std::vector params; +// std::vector result; +// LOG_TRACE("Query Plan:\n%s", +// planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); +// std::vector result_format; +// result_format = std::vector(statement->GetTupleDescriptor().size(), 0); +// TestingSQLUtil::counter_.store(1); +// executor::ExecutionResult status = traffic_cop.ExecuteHelper( +// statement->GetPlanTree(), params, result, result_format); +// if 
(traffic_cop.GetQueuing()) { +// TestingSQLUtil::ContinueAfterComplete(); +// traffic_cop.ExecuteStatementPlanGetResult(); +// status = traffic_cop.p_status_; +// traffic_cop.SetQueuing(false); +// } +// LOG_TRACE("Statement executed. Result: %s", +// ResultTypeToString(status.m_result).c_str()); +// LOG_TRACE("Table Created"); +// traffic_cop.CommitQueryHelper(); +// +// txn = txn_manager.BeginTransaction(); +// // Inserting a tuple end-to-end +// traffic_cop.SetTcopTxnState(txn); +// LOG_TRACE("Inserting a tuple..."); +// LOG_TRACE( +// "Query: INSERT INTO department_table(dept_id,student_id ,dept_name) " +// "VALUES (1,52,'hello_1');"); +// statement.reset(new Statement("INSERT", +// "INSERT INTO department_table(dept_id, " +// "student_id, dept_name) VALUES " +// "(1,52,'hello_1');")); +// +// auto insert_stmt = peloton_parser.BuildParseTree( +// "INSERT INTO department_table(dept_id,student_id,dept_name) VALUES " +// "(1,52,'hello_1');"); +// +// parse_tree = insert_stmt->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// statement->SetPlanTree(optimizer.BuildPelotonPlanTree(insert_stmt, txn)); +// +// result_format = std::vector(statement->GetTupleDescriptor().size(), 0); +// TestingSQLUtil::counter_.store(1); +// status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, +// result_format); +// if (traffic_cop.GetQueuing()) { +// TestingSQLUtil::ContinueAfterComplete(); +// traffic_cop.ExecuteStatementPlanGetResult(); +// status = traffic_cop.p_status_; +// traffic_cop.SetQueuing(false); +// } +// LOG_TRACE("Statement executed. 
Result: %s", +// ResultTypeToString(status.m_result).c_str()); +// LOG_TRACE("Tuple inserted!"); +// traffic_cop.CommitQueryHelper(); +// +// // Now Create index +// txn = txn_manager.BeginTransaction(); +// traffic_cop.SetTcopTxnState(txn); +// LOG_TRACE("Creating and Index"); +// LOG_TRACE("Query: CREATE INDEX saif ON department_table (student_id);"); +// statement.reset(new Statement( +// "CREATE", "CREATE INDEX saif ON department_table (student_id);")); +// +// auto update_stmt = peloton_parser.BuildParseTree( +// "CREATE INDEX saif ON department_table (student_id);"); +// +// parse_tree = update_stmt->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// statement->SetPlanTree(optimizer.BuildPelotonPlanTree(update_stmt, txn)); +// +// result_format = std::vector(statement->GetTupleDescriptor().size(), 0); +// TestingSQLUtil::counter_.store(1); +// status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, +// result_format); +// if (traffic_cop.GetQueuing()) { +// TestingSQLUtil::ContinueAfterComplete(); +// traffic_cop.ExecuteStatementPlanGetResult(); +// status = traffic_cop.p_status_; +// traffic_cop.SetQueuing(false); +// } +// LOG_TRACE("Statement executed. 
Result: %s", +// ResultTypeToString(status.m_result).c_str()); +// LOG_TRACE("INDEX CREATED!"); +// traffic_cop.CommitQueryHelper(); +// +// txn = txn_manager.BeginTransaction(); +// auto target_table_ = catalog::Catalog::GetInstance()->GetTableWithName( +// DEFAULT_DB_NAME, DEFAULT_SCHEMA_NAME, "department_table", txn); +// // Expected 1 , Primary key index + created index +// EXPECT_EQ(target_table_->GetIndexCount(), 2); +// txn_manager.CommitTransaction(txn); +// +// txn = txn_manager.BeginTransaction(); +// // Test update tuple with index scan +// LOG_TRACE("Updating a tuple..."); +// LOG_TRACE( +// "Query: UPDATE department_table SET dept_name = 'CS' WHERE student_id = " +// "52"); +// update_stmt = peloton_parser.BuildParseTree( +// "UPDATE department_table SET dept_name = 'CS' WHERE student_id = 52"); +// +// parse_tree = update_stmt->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// auto update_plan = optimizer.BuildPelotonPlanTree(update_stmt, txn); +// txn_manager.CommitTransaction(txn); +// +// txn = txn_manager.BeginTransaction(); +// // Check scan plan +// ASSERT_FALSE(update_plan == nullptr); +// EXPECT_EQ(update_plan->GetPlanNodeType(), PlanNodeType::UPDATE); +// auto &update_scan_plan = update_plan->GetChildren().front(); +// EXPECT_EQ(update_scan_plan->GetPlanNodeType(), PlanNodeType::INDEXSCAN); +// +// update_stmt = peloton_parser.BuildParseTree( +// "UPDATE department_table SET dept_name = 'CS' WHERE dept_name = 'CS'"); +// +// parse_tree = update_stmt->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// update_plan = optimizer.BuildPelotonPlanTree(update_stmt, txn); +// EXPECT_EQ(update_plan->GetChildren().front()->GetPlanNodeType(), +// PlanNodeType::SEQSCAN); +// txn_manager.CommitTransaction(txn); +// +// txn = txn_manager.BeginTransaction(); +// // Test 
delete tuple with index scan +// LOG_TRACE("Deleting a tuple..."); +// LOG_TRACE("Query: DELETE FROM department_table WHERE student_id = 52"); +// auto delete_stmt = peloton_parser.BuildParseTree( +// "DELETE FROM department_table WHERE student_id = 52"); +// +// parse_tree = delete_stmt->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// auto del_plan = optimizer.BuildPelotonPlanTree(delete_stmt, txn); +// txn_manager.CommitTransaction(txn); +// +// // Check scan plan +// EXPECT_EQ(del_plan->GetPlanNodeType(), PlanNodeType::DELETE); +// auto &del_scan_plan = del_plan->GetChildren().front(); +// EXPECT_EQ(del_scan_plan->GetPlanNodeType(), PlanNodeType::INDEXSCAN); +// del_plan = nullptr; +// +// txn = txn_manager.BeginTransaction(); +// // Test delete tuple with seq scan +// auto delete_stmt_seq = peloton_parser.BuildParseTree( +// "DELETE FROM department_table WHERE dept_name = 'CS'"); +// +// parse_tree = delete_stmt_seq->GetStatement(0); +// bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); +// bind_node_visitor.BindNameToNode(parse_tree); +// +// auto del_plan_seq = optimizer.BuildPelotonPlanTree(delete_stmt_seq, txn); +// auto &del_scan_plan_seq = del_plan_seq->GetChildren().front(); +// txn_manager.CommitTransaction(txn); +// EXPECT_EQ(del_scan_plan_seq->GetPlanNodeType(), PlanNodeType::SEQSCAN); +// +// // free the database just created +// txn = txn_manager.BeginTransaction(); +// catalog::Catalog::GetInstance()->DropDatabaseWithName(DEFAULT_DB_NAME, txn); +// txn_manager.CommitTransaction(txn); +//} } // namespace test } // namespace peloton diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index 50696017bb5..29ccc3ba049 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// 
+#include #include "common/harness.h" #include "binder/bind_node_visitor.h" @@ -36,7 +37,6 @@ #include "planner/seq_scan_plan.h" #include "planner/update_plan.h" #include "sql/testing_sql_util.h" -#include "traffic_cop/traffic_cop.h" namespace peloton { namespace test { @@ -80,18 +80,17 @@ TEST_F(OptimizerTests, HashJoinTest) { LOG_INFO("Bootstrapping completed!"); optimizer::Optimizer optimizer; - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - traffic_cop.SetTaskCallback(TestingSQLUtil::UtilTestTaskCallback, - &TestingSQLUtil::counter_); + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; // Create a table first txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + LOG_INFO("Creating table"); LOG_INFO("Query: CREATE TABLE table_a(aid INT PRIMARY KEY,value INT);"); - std::unique_ptr statement; - statement.reset(new Statement( - "CREATE", "CREATE TABLE table_a(aid INT PRIMARY KEY,value INT);")); + auto statement = std::make_shared( + "CREATE", "CREATE TABLE table_a(aid INT PRIMARY KEY,value INT);"); auto &peloton_parser = parser::PostgresParser::GetInstance(); @@ -103,22 +102,25 @@ TEST_F(OptimizerTests, HashJoinTest) { statement->SetPlanTree(optimizer.BuildPelotonPlanTree(create_stmt, txn)); std::vector params; - std::vector result; - std::vector result_format; - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + std::vector result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - executor::ExecutionResult status = traffic_cop.ExecuteHelper( - statement->GetPlanTree(), params, result, result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + executor::ExecutionResult status = traffic_cop.ExecuteHelper(state, [] { +
TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Table Created"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // NOTE: everytime we create a database, there will be 9 catalog tables inside // Additionally, we also created a table for the test. @@ -129,7 +131,8 @@ TEST_F(OptimizerTests, HashJoinTest) { ->GetTableCount(), expected_table_count); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Creating table"); LOG_INFO("Query: CREATE TABLE table_b(bid INT PRIMARY KEY,value INT);"); statement.reset(new Statement( @@ -143,20 +146,25 @@ TEST_F(OptimizerTests, HashJoinTest) { statement->SetPlanTree(optimizer.BuildPelotonPlanTree(create_stmt, txn)); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + 
state.is_queuing_ = false; } LOG_INFO("Statement executed. Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Table Created"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // Account for table created. expected_table_count++; @@ -166,8 +174,9 @@ TEST_F(OptimizerTests, HashJoinTest) { ->GetTableCount(), expected_table_count); + state.Reset(); // Inserting a tuple to table_a - traffic_cop.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Inserting a tuple..."); LOG_INFO("Query: INSERT INTO table_a(aid, value) VALUES (1,1);"); statement.reset(new Statement( @@ -181,24 +190,30 @@ TEST_F(OptimizerTests, HashJoinTest) { statement->SetPlanTree(optimizer.BuildPelotonPlanTree(insert_stmt, txn)); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple inserted to table_a!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); // Inserting a tuple to table_b txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Inserting a tuple..."); LOG_INFO("Query: INSERT INTO table_b(bid, value) VALUES (1,2);"); statement.reset(new Statement( @@ -211,23 +226,29 @@ TEST_F(OptimizerTests, HashJoinTest) { statement->SetPlanTree(optimizer.BuildPelotonPlanTree(insert_stmt, txn)); - result_format = std::vector(statement->GetTupleDescriptor().size(), 0); + result_format = std::vector(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = params; + state.result_format_ = result_format; + status = traffic_cop.ExecuteHelper(state, [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Tuple inserted to table_b!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); txn = txn_manager.BeginTransaction(); - traffic_cop.SetTcopTxnState(txn); + state.Reset(); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); LOG_INFO("Join ..."); LOG_INFO("Query: SELECT * FROM table_a INNER JOIN table_b ON aid = bid;"); statement.reset(new Statement( @@ -240,20 +261,22 @@ TEST_F(OptimizerTests, HashJoinTest) { statement->SetPlanTree(optimizer.BuildPelotonPlanTree(select_stmt, txn)); - result_format = std::vector(4, 0); + result_format = std::vector(4, + PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - status = traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - if (traffic_cop.GetQueuing()) { + status = traffic_cop.ExecuteHelper(state, [] { + TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(); - status = traffic_cop.p_status_; - traffic_cop.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } LOG_INFO("Statement executed. 
Result: %s", ResultTypeToString(status.m_result).c_str()); LOG_INFO("Join completed!"); - traffic_cop.CommitQueryHelper(); + traffic_cop.CommitQueryHelper(state); LOG_INFO("After Join..."); } diff --git a/test/sql/aggregate_sql_test.cpp b/test/sql/aggregate_sql_test.cpp index 62fcaddea8a..ae9624024b1 100644 --- a/test/sql/aggregate_sql_test.cpp +++ b/test/sql/aggregate_sql_test.cpp @@ -26,7 +26,6 @@ class AggregateSQLTests : public PelotonTest {}; TEST_F(AggregateSQLTests, EmptyTableTest) { PELOTON_ASSERT(&TestingSQLUtil::counter_); - PELOTON_ASSERT(&TestingSQLUtil::traffic_cop_); auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); auto txn = txn_manager.BeginTransaction(); catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn); diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index 234eabf4374..e6f59127f61 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -53,6 +53,8 @@ void TestingSQLUtil::ShowTable(std::string database_name, ExecuteSQLQuery("SELECT * FROM " + database_name + "." + table_name); } +// TODO(Tianyu): These testing code look copy-and-pasted. Should probably consider +// rewriting them. 
// Execute a SQL query end-to-end ResultType TestingSQLUtil::ExecuteSQLQuery( const std::string query, std::vector &result, @@ -61,40 +63,52 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( LOG_TRACE("Query: %s", query.c_str()); // prepareStatement std::string unnamed_statement = "unnamed"; + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; auto &peloton_parser = parser::PostgresParser::GetInstance(); auto sql_stmt_list = peloton_parser.BuildParseTree(query); PELOTON_ASSERT(sql_stmt_list); if (!sql_stmt_list->is_valid) { return ResultType::FAILURE; } - auto statement = traffic_cop_.PrepareStatement(unnamed_statement, query, - std::move(sql_stmt_list)); + auto statement = traffic_cop.PrepareStatement(state, + unnamed_statement, + query, + std::move(sql_stmt_list)); if (statement.get() == nullptr) { - traffic_cop_.setRowsAffected(0); + state.rows_affected_ = 0; rows_changed = 0; - error_message = traffic_cop_.GetErrorMessage(); + error_message = state.error_message_; return ResultType::FAILURE; } // ExecuteStatment std::vector param_values; - bool unnamed = false; - std::vector result_format(statement->GetTupleDescriptor().size(), 0); + std::vector + result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - auto status = traffic_cop_.ExecuteStatement(statement, param_values, unnamed, - nullptr, result_format, result); - if (traffic_cop_.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = param_values; + state.result_format_ = result_format; + state.result_ = result; + auto status = traffic_cop.ExecuteStatement(state, [] { + UtilTestTaskCallback(&counter_); + }); + if (state.is_queuing_) { ContinueAfterComplete(); - traffic_cop_.ExecuteStatementPlanGetResult(); - status = traffic_cop_.ExecuteStatementGetResult(); - traffic_cop_.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status =
traffic_cop.ExecuteStatementGetResult(state); + state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { tuple_descriptor = statement->GetTupleDescriptor(); } LOG_TRACE("Statement executed. Result: %s", ResultTypeToString(status).c_str()); - rows_changed = traffic_cop_.getRowsAffected(); + rows_changed = state.rows_affected_; + // TODO(Tianyu): This is a refactor in progress. This copy can be eliminated. + result = state.result_; return status; } @@ -107,8 +121,10 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( auto &peloton_parser = parser::PostgresParser::GetInstance(); std::vector params; auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; auto txn = txn_manager.BeginTransaction(); - traffic_cop_.SetTcopTxnState(txn); + state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); auto parsed_stmt = peloton_parser.BuildParseTree(query); @@ -117,24 +133,33 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( auto plan = optimizer->BuildPelotonPlanTree(parsed_stmt, txn); tuple_descriptor = - traffic_cop_.GenerateTupleDescriptor(parsed_stmt->GetStatement(0)); - auto result_format = std::vector(tuple_descriptor.size(), 0); + traffic_cop.GenerateTupleDescriptor(state, parsed_stmt->GetStatement(0)); + auto result_format = std::vector(tuple_descriptor.size(), + PostgresDataFormat::TEXT); try { LOG_TRACE("\n%s", planner::PlanUtil::GetInfo(plan.get()).c_str()); // SetTrafficCopCounter(); counter_.store(1); + QueryType query_type = StatementTypeToQueryType(parsed_stmt->GetStatement(0)->GetType(), + parsed_stmt->GetStatement(0)); + state.statement_ = std::make_shared("unnamed", query_type, query, std::move(parsed_stmt)); + state.param_values_ = params; + state.result_ = result; + state.result_format_ = result_format; auto status = - traffic_cop_.ExecuteHelper(plan, params, result, result_format); - if (traffic_cop_.GetQueuing()) { + 
traffic_cop.ExecuteHelper(state, [] { + UtilTestTaskCallback(&counter_); + }); + if (state.is_queuing_) { TestingSQLUtil::ContinueAfterComplete(); - traffic_cop_.ExecuteStatementPlanGetResult(); - status = traffic_cop_.p_status_; - traffic_cop_.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = state.p_status_; + state.is_queuing_ = false; } rows_changed = status.m_processed; LOG_TRACE("Statement executed. Result: %s", - ResultTypeToString(status.m_result).c_str()); + ResultTypeToString(status.m_result).c_str()); return status.m_result; } catch (Exception &e) { error_message = e.what(); @@ -170,29 +195,41 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query, if (!sql_stmt_list->is_valid) { return ResultType::FAILURE; } - auto statement = traffic_cop_.PrepareStatement(unnamed_statement, query, - std::move(sql_stmt_list)); - if (statement.get() == nullptr) { - traffic_cop_.setRowsAffected(0); + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; + auto statement = traffic_cop.PrepareStatement(state, + unnamed_statement, + query, + std::move(sql_stmt_list)); + if (statement == nullptr) { + state.rows_affected_ = 0; return ResultType::FAILURE; } // ExecuteStatment std::vector param_values; - bool unnamed = false; - std::vector result_format(statement->GetTupleDescriptor().size(), 0); + std::vector + result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - auto status = traffic_cop_.ExecuteStatement(statement, param_values, unnamed, - nullptr, result_format, result); - if (traffic_cop_.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = param_values; + state.result_format_ = result_format; + state.result_ = result; + auto status = traffic_cop.ExecuteStatement(state, [] { + UtilTestTaskCallback(&counter_); + }); + if (state.is_queuing_) { ContinueAfterComplete(); -
traffic_cop_.ExecuteStatementPlanGetResult(); - status = traffic_cop_.ExecuteStatementGetResult(); - traffic_cop_.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); + status = traffic_cop.ExecuteStatementGetResult(state); + state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { tuple_descriptor = statement->GetTupleDescriptor(); } + // TODO(Tianyu) Same as above. + result = state.result_; return status; } @@ -210,31 +247,40 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query) { if (!sql_stmt_list->is_valid) { return ResultType::FAILURE; } - auto statement = traffic_cop_.PrepareStatement(unnamed_statement, query, - std::move(sql_stmt_list)); - if (statement.get() == nullptr) { - traffic_cop_.setRowsAffected(0); + auto &traffic_cop = tcop::Tcop::GetInstance(); + tcop::ClientProcessState state; + auto statement = traffic_cop.PrepareStatement(state, + unnamed_statement, + query, + std::move(sql_stmt_list)); + if (statement == nullptr) { + state.rows_affected_ = 0; return ResultType::FAILURE; } - // ExecuteStatment + // ExecuteStatement std::vector param_values; - bool unnamed = false; - std::vector result_format(statement->GetTupleDescriptor().size(), 0); + std::vector result_format(statement->GetTupleDescriptor().size(), + PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - auto status = traffic_cop_.ExecuteStatement(statement, param_values, unnamed, - nullptr, result_format, result); - if (traffic_cop_.GetQueuing()) { + state.statement_ = statement; + state.param_values_ = param_values; + state.result_format_ = result_format; + state.result_ = result; + auto status = traffic_cop.ExecuteStatement(state, []{ + UtilTestTaskCallback(&counter_); + }); + if (state.is_queuing_) { ContinueAfterComplete(); - traffic_cop_.ExecuteStatementPlanGetResult(); - status = traffic_cop_.ExecuteStatementGetResult(); - traffic_cop_.SetQueuing(false); + traffic_cop.ExecuteStatementPlanGetResult(state); +
status = traffic_cop.ExecuteStatementGetResult(state); + state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { tuple_descriptor = statement->GetTupleDescriptor(); } LOG_TRACE("Statement executed. Result: %s", - ResultTypeToString(status).c_str()); + ResultTypeToString(status).c_str()); return status; } @@ -313,7 +359,6 @@ void TestingSQLUtil::UtilTestTaskCallback(void *arg) { } std::atomic_int TestingSQLUtil::counter_; -tcop::TrafficCop TestingSQLUtil::traffic_cop_(UtilTestTaskCallback, &counter_); } // namespace test } // namespace peloton diff --git a/test/statistics/stats_test.cpp b/test/statistics/stats_test.cpp index ef3c7da6cba..d83af374b3d 100644 --- a/test/statistics/stats_test.cpp +++ b/test/statistics/stats_test.cpp @@ -24,7 +24,6 @@ #include "executor/insert_executor.h" #include "statistics/backend_stats_context.h" #include "statistics/stats_aggregator.h" -#include "traffic_cop/traffic_cop.h" #define NUM_ITERATION 50 #define NUM_TABLE_INSERT 1 diff --git a/test/statistics/testing_stats_util.cpp b/test/statistics/testing_stats_util.cpp index 5c087e4aba4..873b2c9d087 100644 --- a/test/statistics/testing_stats_util.cpp +++ b/test/statistics/testing_stats_util.cpp @@ -27,84 +27,87 @@ #include "planner/insert_plan.h" #include "planner/plan_util.h" #include "storage/tile.h" -#include "traffic_cop/traffic_cop.h" namespace peloton { namespace test { -void TestingStatsUtil::ShowTable(std::string database_name, - std::string table_name) { - std::unique_ptr statement; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto &peloton_parser = parser::PostgresParser::GetInstance(); - auto &traffic_cop = tcop::TrafficCop::GetInstance(); - - std::vector params; - std::vector result; - std::string sql = "SELECT * FROM " + database_name + "." 
+ table_name; - statement.reset(new Statement("SELECT", sql)); - // using transaction to optimize - auto txn = txn_manager.BeginTransaction(); - auto select_stmt = peloton_parser.BuildParseTree(sql); - statement->SetPlanTree( - optimizer::Optimizer().BuildPelotonPlanTree(select_stmt, txn)); - LOG_DEBUG("%s", - planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - std::vector result_format(statement->GetTupleDescriptor().size(), 0); - traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, - result_format); - txn_manager.CommitTransaction(txn); -} - -storage::Tuple TestingStatsUtil::PopulateTuple(const catalog::Schema *schema, - int first_col_val, - int second_col_val, - int third_col_val, - int fourth_col_val) { - auto testing_pool = TestingHarness::GetInstance().GetTestingPool(); - storage::Tuple tuple(schema, true); - tuple.SetValue(0, type::ValueFactory::GetIntegerValue(first_col_val), - testing_pool); - - tuple.SetValue(1, type::ValueFactory::GetIntegerValue(second_col_val), - testing_pool); - - tuple.SetValue(2, type::ValueFactory::GetDecimalValue(third_col_val), - testing_pool); - - type::Value string_value = - type::ValueFactory::GetVarcharValue(std::to_string(fourth_col_val)); - tuple.SetValue(3, string_value, testing_pool); - return tuple; -} - -std::shared_ptr -TestingStatsUtil::GetQueryParams(std::shared_ptr &type_buf, - std::shared_ptr &format_buf, - std::shared_ptr &val_buf) { - // Type - uchar *type_buf_data = new uchar[1]; - type_buf_data[0] = 'x'; - type_buf.reset(type_buf_data); - stats::QueryMetric::QueryParamBuf type(type_buf_data, 1); - - // Format - uchar *format_buf_data = new uchar[1]; - format_buf_data[0] = 'y'; - format_buf.reset(format_buf_data); - stats::QueryMetric::QueryParamBuf format(format_buf_data, 1); - - // Value - uchar *val_buf_data = new uchar[1]; - val_buf_data[0] = 'z'; - val_buf.reset(val_buf_data); - stats::QueryMetric::QueryParamBuf val(val_buf_data, 1); - - // Construct a query param object - 
std::shared_ptr query_params( - new stats::QueryMetric::QueryParams(format, type, val, 1)); - return query_params; -} +// TODO(Tianyu): These functions are not actually called anywhere, and the way +// they are written is deeply broken (ignoring async callbacks, meaning the caller +// will have to dream up a number in the test to sleep for). We should rewrite all +// this testing code. +//void TestingStatsUtil::ShowTable(std::string database_name, +// std::string table_name) { +// std::unique_ptr statement; +// auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); +// auto &peloton_parser = parser::PostgresParser::GetInstance(); +// auto &traffic_cop = tcop::TrafficCop::GetInstance(); +// +// std::vector params; +// std::vector result; +// std::string sql = "SELECT * FROM " + database_name + "." + table_name; +// statement.reset(new Statement("SELECT", sql)); +// // using transaction to optimize +// auto txn = txn_manager.BeginTransaction(); +// auto select_stmt = peloton_parser.BuildParseTree(sql); +// statement->SetPlanTree( +// optimizer::Optimizer().BuildPelotonPlanTree(select_stmt, txn)); +// LOG_DEBUG("%s", +// planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); +// std::vector result_format(statement->GetTupleDescriptor().size(), 0); +// traffic_cop.ExecuteHelper(statement->GetPlanTree(), params, result, +// result_format); +// txn_manager.CommitTransaction(txn); +//} +// +//storage::Tuple TestingStatsUtil::PopulateTuple(const catalog::Schema *schema, +// int first_col_val, +// int second_col_val, +// int third_col_val, +// int fourth_col_val) { +// auto testing_pool = TestingHarness::GetInstance().GetTestingPool(); +// storage::Tuple tuple(schema, true); +// tuple.SetValue(0, type::ValueFactory::GetIntegerValue(first_col_val), +// testing_pool); +// +// tuple.SetValue(1, type::ValueFactory::GetIntegerValue(second_col_val), +// testing_pool); +// +// tuple.SetValue(2, type::ValueFactory::GetDecimalValue(third_col_val), 
+// testing_pool); +// +// type::Value string_value = +// type::ValueFactory::GetVarcharValue(std::to_string(fourth_col_val)); +// tuple.SetValue(3, string_value, testing_pool); +// return tuple; +//} +// +//std::shared_ptr +//TestingStatsUtil::GetQueryParams(std::shared_ptr &type_buf, +// std::shared_ptr &format_buf, +// std::shared_ptr &val_buf) { +// // Type +// uchar *type_buf_data = new uchar[1]; +// type_buf_data[0] = 'x'; +// type_buf.reset(type_buf_data); +// stats::QueryMetric::QueryParamBuf type(type_buf_data, 1); +// +// // Format +// uchar *format_buf_data = new uchar[1]; +// format_buf_data[0] = 'y'; +// format_buf.reset(format_buf_data); +// stats::QueryMetric::QueryParamBuf format(format_buf_data, 1); +// +// // Value +// uchar *val_buf_data = new uchar[1]; +// val_buf_data[0] = 'z'; +// val_buf.reset(val_buf_data); +// stats::QueryMetric::QueryParamBuf val(val_buf_data, 1); +// +// // Construct a query param object +// std::shared_ptr query_params( +// new stats::QueryMetric::QueryParams(format, type, val, 1)); +// return query_params; +//} void TestingStatsUtil::CreateTable(bool has_primary_key) { LOG_INFO("Creating a table..."); From 6fc54836fc7b9a29bed0d33521f2d9ba5b37abc9 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 29 Jun 2018 17:30:34 -0400 Subject: [PATCH 30/48] Remove stale traffic_cop.cpp --- src/traffic_cop/traffic_cop.cpp | 620 -------------------------------- 1 file changed, 620 deletions(-) delete mode 100644 src/traffic_cop/traffic_cop.cpp diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp deleted file mode 100644 index bbf0846ac9a..00000000000 --- a/src/traffic_cop/traffic_cop.cpp +++ /dev/null @@ -1,620 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// traffic_cop.cpp -// -// Identification: src/traffic_cop/traffic_cop.cpp -// -// Copyright (c) 2015-17, Carnegie Mellon University Database Group -// 
-//===----------------------------------------------------------------------===// - -#include "traffic_cop/traffic_cop.h" - -#include - -#include "binder/bind_node_visitor.h" -#include "common/internal_types.h" -#include "concurrency/transaction_context.h" -#include "concurrency/transaction_manager_factory.h" -#include "expression/expression_util.h" -#include "optimizer/optimizer.h" -#include "planner/plan_util.h" -#include "settings/settings_manager.h" -#include "threadpool/mono_queue_pool.h" - -namespace peloton { -namespace tcop { - -TrafficCop::TrafficCop() - : is_queuing_(false), - rows_affected_(0), - optimizer_(new optimizer::Optimizer()), - single_statement_txn_(true) {} - -TrafficCop::TrafficCop(void (*task_callback)(void *), void *task_callback_arg) - : optimizer_(new optimizer::Optimizer()), - single_statement_txn_(true), - task_callback_(task_callback), - task_callback_arg_(task_callback_arg) {} - -void TrafficCop::Reset() { - std::stack new_tcop_txn_state; - // clear out the stack - swap(tcop_txn_state_, new_tcop_txn_state); - optimizer_->Reset(); - results_.clear(); - param_values_.clear(); - setRowsAffected(0); -} - -TrafficCop::~TrafficCop() { - // Abort all running transactions - while (!tcop_txn_state_.empty()) { - AbortQueryHelper(); - } -} - -/* Singleton accessor - * NOTE: Used by in unit tests ONLY - */ -TrafficCop &TrafficCop::GetInstance() { - static TrafficCop tcop; - tcop.Reset(); - return tcop; -} - -TrafficCop::TcopTxnState &TrafficCop::GetDefaultTxnState() { - static TcopTxnState default_state; - default_state = std::make_pair(nullptr, ResultType::INVALID); - return default_state; -} - -TrafficCop::TcopTxnState &TrafficCop::GetCurrentTxnState() { - if (tcop_txn_state_.empty()) { - return GetDefaultTxnState(); - } - return tcop_txn_state_.top(); -} - -ResultType TrafficCop::BeginQueryHelper(size_t thread_id) { - if (tcop_txn_state_.empty()) { - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto txn = 
txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) { - LOG_DEBUG("Begin txn failed"); - return ResultType::FAILURE; - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - return ResultType::SUCCESS; -} - -ResultType TrafficCop::CommitQueryHelper() { - // do nothing if we have no active txns - if (tcop_txn_state_.empty()) return ResultType::NOOP; - auto &curr_state = tcop_txn_state_.top(); - tcop_txn_state_.pop(); - auto txn = curr_state.first; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // I catch the exception (ex. table not found) explicitly, - // If this exception is caused by a query in a transaction, - // I will block following queries in that transaction until 'COMMIT' or - // 'ROLLBACK' After receive 'COMMIT', see if it is rollback or really commit. - if (curr_state.second != ResultType::ABORTED) { - // txn committed - return txn_manager.CommitTransaction(txn); - } else { - // otherwise, rollback - return txn_manager.AbortTransaction(txn); - } -} - -ResultType TrafficCop::AbortQueryHelper() { - // do nothing if we have no active txns - if (tcop_txn_state_.empty()) return ResultType::NOOP; - auto &curr_state = tcop_txn_state_.top(); - tcop_txn_state_.pop(); - // explicitly abort the txn only if it has not aborted already - if (curr_state.second != ResultType::ABORTED) { - auto txn = curr_state.first; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto result = txn_manager.AbortTransaction(txn); - return result; - } else { - delete curr_state.first; - // otherwise, the txn has already been aborted - return ResultType::ABORTED; - } -} - -ResultType TrafficCop::ExecuteStatementGetResult() { - LOG_TRACE("Statement executed. 
Result: %s", - ResultTypeToString(p_status_.m_result).c_str()); - setRowsAffected(p_status_.m_processed); - LOG_TRACE("rows_changed %d", p_status_.m_processed); - is_queuing_ = false; - return p_status_.m_result; -} - -/* - * Execute a statement that needs a plan(so, BEGIN, COMMIT, ROLLBACK does not - * come here). - * Begin a new transaction if necessary. - * If the current transaction is already broken(for example due to previous - * invalid - * queries), directly return - * Otherwise, call ExecutePlan() - */ -executor::ExecutionResult TrafficCop::ExecuteHelper( - std::shared_ptr plan, - const std::vector &params, std::vector &result, - const std::vector &result_format, size_t thread_id) { - auto &curr_state = GetCurrentTxnState(); - - concurrency::TransactionContext *txn; - if (!tcop_txn_state_.empty()) { - txn = curr_state.first; - } else { - // No active txn, single-statement txn - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // new txn, reset result status - curr_state.second = ResultType::SUCCESS; - single_statement_txn_ = true; - txn = txn_manager.BeginTransaction(thread_id); - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - - // skip if already aborted - if (curr_state.second == ResultType::ABORTED) { - // If the transaction state is ABORTED, the transaction should be aborted - // but Peloton didn't explicitly abort it yet since it didn't receive a - // COMMIT/ROLLBACK. - // Here, it receive queries other than COMMIT/ROLLBACK in an broken - // transaction, - // it should tell the client that these queries will not be executed.
- p_status_.m_result = ResultType::TO_ABORT; - return p_status_; - } - - auto on_complete = [&result, this](executor::ExecutionResult p_status, - std::vector &&values) { - this->p_status_ = p_status; - // TODO (Tianyi) I would make a decision on keeping one of p_status or - // error_message in my next PR - this->error_message_ = std::move(p_status.m_error_message); - result = std::move(values); - task_callback_(task_callback_arg_); - }; - - auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([plan, txn, &params, &result_format, on_complete] { - executor::PlanExecutor::ExecutePlan(plan, txn, params, result_format, - on_complete); - }); - - is_queuing_ = true; - - LOG_TRACE("Check Tcop_txn_state Size After ExecuteHelper %lu", - tcop_txn_state_.size()); - return p_status_; -} - -void TrafficCop::ExecuteStatementPlanGetResult() { - if (p_status_.m_result == ResultType::FAILURE) return; - - auto txn_result = GetCurrentTxnState().first->GetResult(); - if (single_statement_txn_ || txn_result == ResultType::FAILURE) { - LOG_TRACE("About to commit/abort: single stmt: %d,txn_result: %s", - single_statement_txn_, ResultTypeToString(txn_result).c_str()); - switch (txn_result) { - case ResultType::SUCCESS: - // Commit single statement - LOG_TRACE("Commit Transaction"); - p_status_.m_result = CommitQueryHelper(); - break; - - case ResultType::FAILURE: - default: - // Abort - LOG_TRACE("Abort Transaction"); - if (single_statement_txn_) { - LOG_TRACE("Tcop_txn_state size: %lu", tcop_txn_state_.size()); - p_status_.m_result = AbortQueryHelper(); - } else { - tcop_txn_state_.top().second = ResultType::ABORTED; - p_status_.m_result = ResultType::ABORTED; - } - } - } -} - -/* - * Prepare a statement based on parse tree. Begin a transaction if necessary. - * If the query is not issued in a transaction (if txn_stack is empty and it's - * not - * BEGIN query), Peloton will create a new transation for it. single_stmt - * transaction.
- * Otherwise, it's a multi_stmt transaction. - * TODO(Yuchen): We do not need a query string to prepare a statement and the - * query string may - * contain the information of multiple statements rather than the single one. - * Hack here. We store - * the query string inside Statement objects for printing infomation. - */ -std::shared_ptr TrafficCop::PrepareStatement( - const std::string &stmt_name, const std::string &query_string, - std::unique_ptr sql_stmt_list, - const size_t thread_id UNUSED_ATTRIBUTE) { - LOG_TRACE("Prepare Statement query: %s", query_string.c_str()); - - // Empty statement - // TODO (Tianyi) Read through the parser code to see if this is appropriate - if (sql_stmt_list.get() == nullptr || - sql_stmt_list->GetNumStatements() == 0) { - // TODO (Tianyi) Do we need another query type called QUERY_EMPTY? - std::shared_ptr statement = - std::make_shared(stmt_name, QueryType::QUERY_INVALID, - query_string, std::move(sql_stmt_list)); - return statement; - } - - StatementType stmt_type = sql_stmt_list->GetStatement(0)->GetType(); - QueryType query_type = - StatementTypeToQueryType(stmt_type, sql_stmt_list->GetStatement(0)); - std::shared_ptr statement = std::make_shared( - stmt_name, query_type, query_string, std::move(sql_stmt_list)); - - // We can learn transaction's states, BEGIN, COMMIT, ABORT, or ROLLBACK from - // member variables, tcop_txn_state_. We can also get single-statement txn or - // multi-statement txn from member variable single_statement_txn_ - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - // --multi-statements except BEGIN in a transaction - if (!tcop_txn_state_.empty()) { - single_statement_txn_ = false; - // multi-statment txn has been aborted, just skip this query, - // and do not need to parse or execute this query anymore. 
- // Do not return nullptr in case that 'COMMIT' cannot be execute, - // because nullptr will directly return ResultType::FAILURE to - // packet_manager - if (tcop_txn_state_.top().second == ResultType::ABORTED) { - return statement; - } - } else { - // Begin new transaction when received single-statement query or "BEGIN" - // from multi-statement query - if (statement->GetQueryType() == - QueryType::QUERY_BEGIN) { // only begin a new transaction - // note this transaction is not single-statement transaction - LOG_TRACE("BEGIN"); - single_statement_txn_ = false; - } else { - // single statement - LOG_TRACE("SINGLE TXN"); - single_statement_txn_ = true; - } - auto txn = txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) { - LOG_TRACE("Begin txn failed"); - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - - if (settings::SettingsManager::GetBool(settings::SettingId::brain)) { - tcop_txn_state_.top().first->AddQueryString(query_string.c_str()); - } - - // TODO(Tianyi) Move Statement Planing into Statement's method - // to increase coherence - try { - // Run binder - auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, default_database_name_); - bind_node_visitor.BindNameToNode( - statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = optimizer_->BuildPelotonPlanTree( - statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); - statement->SetPlanTree(plan); - // Get the tables that our plan references so that we know how to - // invalidate it at a later point when the catalog changes - const std::set table_oids = - planner::PlanUtil::GetTablesReferenced(plan.get()); - statement->SetReferencedTables(table_oids); - - if (query_type == QueryType::QUERY_SELECT) { - auto tuple_descriptor = GenerateTupleDescriptor( - statement->GetStmtParseTreeList()->GetStatement(0)); - statement->SetTupleDescriptor(tuple_descriptor); - 
LOG_TRACE("select query, finish setting"); - } - } catch (Exception &e) { - error_message_ = e.what(); - ProcessInvalidStatement(); - return nullptr; - } - -#ifdef LOG_DEBUG_ENABLED - if (statement->GetPlanTree().get() != nullptr) { - LOG_TRACE("Statement Prepared: %s", statement->GetInfo().c_str()); - LOG_TRACE("%s", statement->GetPlanTree().get()->GetInfo().c_str()); - } -#endif - return statement; -} - -/* - * Do nothing if there is no active transaction; - * If single-stmt transaction, abort it; - * If multi-stmt transaction, just set transaction state to 'ABORTED'. - * The multi-stmt txn will be explicitly aborted when receiving 'Commit' or - * 'Rollback'. - */ -void TrafficCop::ProcessInvalidStatement() { - if (single_statement_txn_) { - LOG_TRACE("SINGLE ABORT!"); - AbortQueryHelper(); - } else { // multi-statment txn - if (tcop_txn_state_.top().second != ResultType::ABORTED) { - tcop_txn_state_.top().second = ResultType::ABORTED; - } - } -} - -bool TrafficCop::BindParamsForCachePlan( - const std::vector> - &parameters, - const size_t thread_id UNUSED_ATTRIBUTE) { - if (tcop_txn_state_.empty()) { - single_statement_txn_ = true; - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); - auto txn = txn_manager.BeginTransaction(thread_id); - // this shouldn't happen - if (txn == nullptr) { - LOG_ERROR("Begin txn failed"); - } - // initialize the current result as success - tcop_txn_state_.emplace(txn, ResultType::SUCCESS); - } - // Run binder - auto bind_node_visitor = binder::BindNodeVisitor(tcop_txn_state_.top().first, - default_database_name_); - - std::vector param_values; - for (const std::unique_ptr &param : - parameters) { - if (!expression::ExpressionUtil::IsValidStaticExpression(param.get())) { - error_message_ = "Invalid Expression Type"; - return false; - } - param->Accept(&bind_node_visitor); - // TODO(Yuchen): need better check for nullptr argument - param_values.push_back(param->Evaluate(nullptr, nullptr, nullptr)); - } - if
(param_values.size() > 0) { - statement_->GetPlanTree()->SetParameterValues(&param_values); - } - SetParamVal(param_values); - return true; -} - -void TrafficCop::GetTableColumns(parser::TableRef *from_table, - std::vector &target_columns) { - if (from_table == nullptr) return; - - // Query derived table - if (from_table->select != NULL) { - for (auto &expr : from_table->select->select_list) { - if (expr->GetExpressionType() == ExpressionType::STAR) - GetTableColumns(from_table->select->from_table.get(), target_columns); - else - target_columns.push_back(catalog::Column(expr->GetValueType(), 0, - expr->GetExpressionName())); - } - } else if (from_table->list.empty()) { - if (from_table->join == NULL) { - auto columns = - static_cast( - catalog::Catalog::GetInstance()->GetTableWithName( - GetCurrentTxnState().first, - from_table->GetDatabaseName(), - from_table->GetSchemaName(), - from_table->GetTableName())) - ->GetSchema() - ->GetColumns(); - target_columns.insert(target_columns.end(), columns.begin(), - columns.end()); - } else { - GetTableColumns(from_table->join->left.get(), target_columns); - GetTableColumns(from_table->join->right.get(), target_columns); - } - } - // Query has multiple tables.
Recursively add all tables - else { - for (auto &table : from_table->list) { - GetTableColumns(table.get(), target_columns); - } - } -} - -std::vector TrafficCop::GenerateTupleDescriptor( - parser::SQLStatement *sql_stmt) { - std::vector tuple_descriptor; - if (sql_stmt->GetType() != StatementType::SELECT) return tuple_descriptor; - auto select_stmt = (parser::SelectStatement *)sql_stmt; - - // TODO: this is a hack which I don't have time to fix now - // but it replaces a worse hack that was here before - // What should happen here is that plan nodes should store - // the schema of their expected results and here we should just read - // it and put it in the tuple descriptor - - // Get the columns information and set up - // the columns description for the returned results - // Set up the table - std::vector all_columns; - - // Check if query only has one Table - // Example : SELECT * FROM A; - GetTableColumns(select_stmt->from_table.get(), all_columns); - - int count = 0; - for (auto &expr : select_stmt->select_list) { - count++; - if (expr->GetExpressionType() == ExpressionType::STAR) { - for (auto column : all_columns) { - tuple_descriptor.push_back( - GetColumnFieldForValueType(column.GetName(), column.GetType())); - } - } else { - std::string col_name; - if (expr->alias.empty()) { - col_name = expr->expr_name_.empty() - ? 
std::string("expr") + std::to_string(count) - : expr->expr_name_; - } else { - col_name = expr->alias; - } - tuple_descriptor.push_back( - GetColumnFieldForValueType(col_name, expr->GetValueType())); - } - } - - return tuple_descriptor; -} - -// TODO: move it to postgres_protocal_handler.cpp -FieldInfo TrafficCop::GetColumnFieldForValueType(std::string column_name, - type::TypeId column_type) { - PostgresValueType field_type; - size_t field_size; - switch (column_type) { - case type::TypeId::BOOLEAN: - case type::TypeId::TINYINT: { - field_type = PostgresValueType::BOOLEAN; - field_size = 1; - break; - } - case type::TypeId::SMALLINT: { - field_type = PostgresValueType::SMALLINT; - field_size = 2; - break; - } - case type::TypeId::INTEGER: { - field_type = PostgresValueType::INTEGER; - field_size = 4; - break; - } - case type::TypeId::BIGINT: { - field_type = PostgresValueType::BIGINT; - field_size = 8; - break; - } - case type::TypeId::DECIMAL: { - field_type = PostgresValueType::DOUBLE; - field_size = 8; - break; - } - case type::TypeId::VARCHAR: - case type::TypeId::VARBINARY: { - field_type = PostgresValueType::TEXT; - field_size = 255; - break; - } - case type::TypeId::DATE: { - field_type = PostgresValueType::DATE; - field_size = 4; - break; - } - case type::TypeId::TIMESTAMP: { - field_type = PostgresValueType::TIMESTAMPS; - field_size = 64; // FIXME: Bytes??? - break; - } - default: { - // Type not Identified - LOG_ERROR("Unrecognized field type '%s' for field '%s'", - TypeIdToString(column_type).c_str(), column_name.c_str()); - field_type = PostgresValueType::TEXT; - field_size = 255; - break; - } - } - // HACK: Convert the type into a oid_t - // This ugly and I don't like it one bit... 
- return std::make_tuple(column_name, static_cast(field_type), - field_size); -} - -ResultType TrafficCop::ExecuteStatement( - const std::shared_ptr &statement, - const std::vector &params, UNUSED_ATTRIBUTE bool unnamed, - std::shared_ptr param_stats, - const std::vector &result_format, std::vector &result, - size_t thread_id) { - // TODO(Tianyi) Further simplify this API - if (static_cast(settings::SettingsManager::GetInt( - settings::SettingId::stats_mode)) != StatsType::INVALID) { - stats::BackendStatsContext::GetInstance()->InitQueryMetric( - statement, std::move(param_stats)); - } - - LOG_TRACE("Execute Statement of name: %s", - statement->GetStatementName().c_str()); - LOG_TRACE("Execute Statement of query: %s", - statement->GetQueryString().c_str()); - LOG_TRACE("Execute Statement Plan:\n%s", - planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - LOG_TRACE("Execute Statement Query Type: %s", - statement->GetQueryTypeString().c_str()); - LOG_TRACE("----QueryType: %d--------", - static_cast(statement->GetQueryType())); - - try { - switch (statement->GetQueryType()) { - case QueryType::QUERY_BEGIN: { - return BeginQueryHelper(thread_id); - } - case QueryType::QUERY_COMMIT: { - return CommitQueryHelper(); - } - case QueryType::QUERY_ROLLBACK: { - return AbortQueryHelper(); - } - default: - // The statement may be out of date - // It needs to be replan - if (statement->GetNeedsReplan()) { - // TODO(Tianyi) Move Statement Replan into Statement's method - // to increase coherence - auto bind_node_visitor = binder::BindNodeVisitor( - tcop_txn_state_.top().first, default_database_name_); - bind_node_visitor.BindNameToNode( - statement->GetStmtParseTreeList()->GetStatement(0)); - auto plan = optimizer_->BuildPelotonPlanTree( - statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); - statement->SetPlanTree(plan); - statement->SetNeedsReplan(true); - } - - ExecuteHelper(statement->GetPlanTree(), params, result, result_format, - thread_id); - if
(GetQueuing()) { - return ResultType::QUEUING; - } else { - return ExecuteStatementGetResult(); - } - } - - } catch (Exception &e) { - error_message_ = e.what(); - return ResultType::FAILURE; - } -} - -} // namespace tcop -} // namespace peloton From 169338b213b098a35acb5bd60dcc732d9097381d Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Sat, 30 Jun 2018 22:40:32 -0400 Subject: [PATCH 31/48] Fix some memory issues in unit tests. --- src/include/network/postgres_protocol_interpreter.h | 5 ++++- src/traffic_cop/tcop.cpp | 2 +- test/sql/testing_sql_util.cpp | 8 ++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 15e197351cc..1d2587d5933 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -32,7 +32,9 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { std::shared_ptr out, CallbackFunc callback) override; - inline void GetResult() override {} + inline void GetResult() override { + + } inline void AddCmdlineOption(const std::string &key, std::string value) { cmdline_options_[key] = std::move(value); @@ -50,6 +52,7 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { void ExecExecuteMessageGetResult(PostgresPacketWriter &out, ResultType status); ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); + std::unordered_map> portals_; private: bool startup_ = true; diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index 5eabc4b9bb2..c2212f2eca7 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -499,7 +499,7 @@ executor::ExecutionResult Tcop::ExecuteHelper(ClientProcessState &state, formats.push_back((int) format); auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([on_complete, &state, &txn, &formats] { + pool.SubmitTask([on_complete, txn, formats, 
&state] { executor::PlanExecutor::ExecutePlan(state.statement_->GetPlanTree(), txn, state.param_values_, diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index e6f59127f61..def7936ad3e 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -88,7 +88,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - state.statement_.reset(statement.get()); + statement.swap(state.statement_); state.param_values_ = param_values; state.result_format_ = result_format; state.result_ = result; @@ -102,7 +102,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = statement->GetTupleDescriptor(); + tuple_descriptor = state.statement_->GetTupleDescriptor(); } LOG_TRACE("Statement executed. Result: %s", ResultTypeToString(status).c_str()); @@ -263,7 +263,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query) { PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - state.statement_.reset(statement.get()); + statement.swap(state.statement_); state.param_values_ = param_values; state.result_format_ = result_format; state.result_ = result; @@ -277,7 +277,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query) { state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = statement->GetTupleDescriptor(); + tuple_descriptor = state.statement_->GetTupleDescriptor(); } LOG_TRACE("Statement executed. 
Result: %s", ResultTypeToString(status).c_str()); From 763b5b93fd1dd91f175e3b69ea2df01d4fe4b351 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Sun, 1 Jul 2018 18:13:43 -0400 Subject: [PATCH 32/48] Fix some tests --- src/include/network/network_io_utils.h | 2 ++ .../network/postgres_protocol_interpreter.h | 19 +++++++++++++++++-- src/include/network/postgres_protocol_utils.h | 3 ++- src/include/network/protocol_interpreter.h | 2 +- src/network/connection_handle.cpp | 2 +- src/network/postgres_network_commands.cpp | 14 ++++++++++---- src/network/postgres_protocol_interpreter.cpp | 6 ++++-- test/sql/testing_sql_util.cpp | 7 ++++--- 8 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index c6581298215..fb377045ab4 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -45,6 +45,8 @@ class Buffer { offset_ = 0; } + inline void Skip(size_t bytes) { offset_ += bytes; } + /** * @param bytes The amount of bytes to check between the cursor and the end * of the buffer (defaults to any) diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index 1d2587d5933..c2b73fd3bc2 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -32,7 +32,22 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { std::shared_ptr out, CallbackFunc callback) override; - inline void GetResult() override { + inline void GetResult(std::shared_ptr out) override { + + auto tcop = tcop::Tcop::GetInstance(); + // TODO(Tianyu): The difference between these two methods are unclear to me + tcop.ExecuteStatementPlanGetResult(state_); + auto status = tcop.ExecuteStatementGetResult(state_); + PostgresPacketWriter writer(*out); + switch (protocol_type_) { + case NetworkProtocolType::POSTGRES_JDBC: + LOG_TRACE("JDBC result"); + 
ExecExecuteMessageGetResult(writer, status); + break; + case NetworkProtocolType::POSTGRES_PSQL: + LOG_TRACE("PSQL result"); + ExecQueryMessageGetResult(writer, status); + } } @@ -52,7 +67,7 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { void ExecExecuteMessageGetResult(PostgresPacketWriter &out, ResultType status); ResultType ExecQueryExplain(const std::string &query, parser::ExplainStatement &explain_stmt); - + NetworkProtocolType protocol_type_; std::unordered_map> portals_; private: bool startup_ = true; diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index cda1216df97..6772a8b7414 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -189,6 +189,7 @@ class PostgresPacketWriter { inline void WriteStartupResponse() { BeginPacket(NetworkMessageType::AUTHENTICATION_REQUEST) + .AppendValue(0) .EndPacket(); for (auto &entry : parameter_status_map) @@ -238,7 +239,7 @@ class PostgresPacketWriter { AppendValue(NULL_CONTENT_SIZE); else AppendValue(content.size()) - .AppendString(content); + .AppendString(content, false); } EndPacket(); diff --git a/src/include/network/protocol_interpreter.h b/src/include/network/protocol_interpreter.h index 4b2f5bafdf9..d1eaf36442c 100644 --- a/src/include/network/protocol_interpreter.h +++ b/src/include/network/protocol_interpreter.h @@ -25,7 +25,7 @@ class ProtocolInterpreter { CallbackFunc callback) = 0; // TODO(Tianyu): Do we really need this crap? 
- virtual void GetResult() = 0; + virtual void GetResult(std::shared_ptr out) = 0; }; } // namespace network diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index 1b453ac702d..5bb9d27d913 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -172,7 +172,7 @@ ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) Transition ConnectionHandle::GetResult() { EventUtil::EventAdd(network_event_, nullptr); - protocol_interpreter_->GetResult(); + protocol_interpreter_->GetResult(io_wrapper_->GetWriteQueue()); return Transition::PROCEED; } diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 80aff5406d5..2523d0c324a 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -72,7 +72,8 @@ void PostgresNetworkCommand::ReadParamValues(std::vector &bind_pa param_types[i], param_len); break; - default:throw NetworkProcessException("Unexpected format code"); + default: + throw NetworkProcessException("Unexpected format code"); } } } @@ -195,7 +196,9 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::TERMINATE; } - while (in_->HasMore()) { + // The last bit of the packet will be nul. This is not a valid field. When there + // is less than 2 bytes of data remaining we can already exit early. + while (in_->HasMore(2)) { // TODO(Tianyu): We don't seem to really handle the other flags? std::string key = in_->ReadString(), value = in_->ReadString(); LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); @@ -203,7 +206,8 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, state.db_name_ = value; interpreter.AddCmdlineOption(key, std::move(value)); } - + // skip the last nul byte + in_->Skip(1); // TODO(Tianyu): Implement authentication. 
For now we always send AuthOK out.WriteStartupResponse(); interpreter.FinishStartup(); @@ -213,6 +217,7 @@ Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc callback) { + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string query = in_->ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); @@ -312,7 +317,7 @@ Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } default: { - std::string stmt_name = "unamed"; + std::string stmt_name = "unnamed"; std::unique_ptr unnamed_sql_stmt_list( new parser::SQLStatementList()); unnamed_sql_stmt_list->PassInStatement(std::move(sql_stmt)); @@ -517,6 +522,7 @@ Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, Transition ExecuteCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc callback) { + interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_JDBC; tcop::ClientProcessState &state = interpreter.ClientProcessState(); std::string portal_name = in_->ReadString(); diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index 96938f628fc..d1439fa0926 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -130,7 +130,8 @@ void PostgresProtocolInterpreter::CompleteCommand(PostgresPacketWriter &out, tag += " " + std::to_string(rows); } out.BeginPacket(NetworkMessageType::COMMAND_COMPLETE) - .AppendString(tag); + .AppendString(tag) + .EndPacket(); } void PostgresProtocolInterpreter::ExecQueryMessageGetResult(PostgresPacketWriter &out, @@ -157,7 +158,8 @@ void PostgresProtocolInterpreter::ExecQueryMessageGetResult(PostgresPacketWriter 
out.WriteTupleDescriptor(tuple_descriptor); out.WriteDataRows(state_.result_, tuple_descriptor.size()); // TODO(Tianyu): WTF? - state_.rows_affected_ = state_.result_.size() / tuple_descriptor.size(); + if (!tuple_descriptor.empty()) + state_.rows_affected_ = state_.result_.size() / tuple_descriptor.size(); CompleteCommand(out, state_.statement_->GetQueryType(), diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index def7936ad3e..49c53b9acaf 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -144,8 +144,8 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( QueryType query_type = StatementTypeToQueryType(parsed_stmt->GetStatement(0)->GetType(), parsed_stmt->GetStatement(0)); state.statement_ = std::make_shared("unnamed", query_type, query, std::move(parsed_stmt)); + state.statement_->SetPlanTree(plan); state.param_values_ = params; - state.result_ = result; state.result_format_ = result_format; auto status = traffic_cop.ExecuteHelper(state, [] { @@ -158,6 +158,7 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( state.is_queuing_ = false; } rows_changed = status.m_processed; + result = state.result_; LOG_TRACE("Statement executed. Result: %s", ResultTypeToString(status.m_result).c_str()); return status.m_result; @@ -212,7 +213,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query, PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - state.statement_.reset(statement.get()); + statement.swap(state.statement_); state.param_values_ = param_values; state.result_format_ = result_format; state.result_ = result; @@ -226,7 +227,7 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query, state.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = statement->GetTupleDescriptor(); + tuple_descriptor = state.statement_->GetTupleDescriptor(); } // TODO(Tianyu) Same as above. 
result = state.result_; From 685b8bc7be5f1d84228d6ec56122e93a4e9cbbed Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Sun, 1 Jul 2018 21:23:09 -0400 Subject: [PATCH 33/48] fix more tests --- src/common/internal_types.cpp | 3 +++ test/optimizer/optimizer_test.cpp | 11 +++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/common/internal_types.cpp b/src/common/internal_types.cpp index b6e52105ae6..855f7ef2d9b 100644 --- a/src/common/internal_types.cpp +++ b/src/common/internal_types.cpp @@ -2030,6 +2030,9 @@ std::string ResultTypeToString(ResultType type) { case ResultType::UNKNOWN: { return ("UNKNOWN"); } + case ResultType::QUEUING: { + return ("QUEUING"); + } case ResultType::TO_ABORT: { return ("TO_ABORT"); } diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index 0a3ffcedb4d..f0aaeaa80cb 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -105,7 +105,7 @@ TEST_F(OptimizerTests, HashJoinTest) { std::vector result_format(statement->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - state.statement_.reset(statement.get()); + state.statement_.swap(statement); state.param_values_ = params; state.result_format_ = result_format; executor::ExecutionResult status = traffic_cop.ExecuteHelper(state, [] { @@ -149,7 +149,7 @@ TEST_F(OptimizerTests, HashJoinTest) { result_format = std::vector(statement->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - state.statement_ = statement; + state.statement_.swap(statement); state.param_values_ = params; state.result_format_ = result_format; status = traffic_cop.ExecuteHelper(state, [] { @@ -193,7 +193,7 @@ TEST_F(OptimizerTests, HashJoinTest) { result_format = std::vector(statement->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - state.statement_ = statement; + state.statement_.swap(statement); 
state.param_values_ = params; state.result_format_ = result_format; status = traffic_cop.ExecuteHelper(state, [] { @@ -229,7 +229,7 @@ TEST_F(OptimizerTests, HashJoinTest) { result_format = std::vector(statement->GetTupleDescriptor().size(), PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); - state.statement_ = statement; + state.statement_.swap(statement); state.param_values_ = params; state.result_format_ = result_format; status = traffic_cop.ExecuteHelper(state, [] { @@ -264,6 +264,9 @@ TEST_F(OptimizerTests, HashJoinTest) { result_format = std::vector(4, PostgresDataFormat::TEXT); TestingSQLUtil::counter_.store(1); + state.statement_.swap(statement); + state.param_values_ = params; + state.result_format_ = result_format; status = traffic_cop.ExecuteHelper(state, [] { TestingSQLUtil::UtilTestTaskCallback(&TestingSQLUtil::counter_); }); From fa583f25eeb7a1d1a1868512d7f604a2292875d5 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Mon, 2 Jul 2018 11:27:38 -0400 Subject: [PATCH 34/48] Fix Compile failure by importing library and using different function of endian conversion. 
--- src/include/network/network_io_utils.h | 3 ++- src/include/network/postgres_protocol_utils.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index fb377045ab4..2d3437062bd 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "common/internal_types.h" #include "common/exception.h" @@ -183,7 +184,7 @@ class ReadBuffer : public Buffer { case 4: return _CAST(T, ntohl(_CAST(uint32_t, val))); case 8: - return _CAST(T, ntohll(_CAST(uint64_t, val))); + return _CAST(T, be64toh(_CAST(uint64_t, val))); // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index 6772a8b7414..8d546894462 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -151,7 +151,7 @@ class PostgresPacketWriter { case 1: return AppendRawValue(val); case 2: return AppendRawValue(_CAST(T, ntohs(_CAST(uint16_t, val)))); case 4: return AppendRawValue(_CAST(T, ntohl(_CAST(uint32_t, val)))); - case 8: return AppendRawValue(_CAST(T, ntohll(_CAST(uint64_t, val)))); + case 8: return AppendRawValue(_CAST(T, be64toh(_CAST(uint64_t, val)))); // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } From dc96b75f8d480509a07594558d9753816a5e2812 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Mon, 2 Jul 2018 11:28:04 -0400 Subject: [PATCH 35/48] One line fix for initial value change. 
--- src/include/traffic_cop/tcop.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index 12a94bd6d58..a9e712c0312 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -38,7 +38,7 @@ struct ClientProcessState { // The optimizer used for this connection std::unique_ptr optimizer_{new optimizer::Optimizer()}; // flag of single statement txn - bool single_statement_txn_ = false; + bool single_statement_txn_ = true; std::vector result_format_; // flag of single statement txn std::vector result_; From c0a7fe419ce2d51f6d09bdecd137836ff5071052 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Mon, 2 Jul 2018 15:03:47 -0400 Subject: [PATCH 36/48] Convert byte swaps to portable ones --- src/include/network/network_io_utils.h | 5 +++-- src/include/network/postgres_protocol_utils.h | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index 2d3437062bd..9c1da5e88b1 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -16,6 +16,7 @@ #include #include #include +#include "util/portable_endian.h" #include "common/internal_types.h" #include "common/exception.h" @@ -180,9 +181,9 @@ class ReadBuffer : public Buffer { switch (sizeof(T)) { case 1: return val; case 2: - return _CAST(T, ntohs(_CAST(uint16_t, val))); + return _CAST(T, be16toh(_CAST(uint16_t, val))); case 4: - return _CAST(T, ntohl(_CAST(uint32_t, val))); + return _CAST(T, be32toh(_CAST(uint32_t, val))); case 8: return _CAST(T, be64toh(_CAST(uint64_t, val))); // Will never be here due to compiler optimization diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index 8d546894462..22fa47b4870 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -149,9 +149,9 
@@ class PostgresPacketWriter { switch (sizeof(T)) { case 1: return AppendRawValue(val); - case 2: return AppendRawValue(_CAST(T, ntohs(_CAST(uint16_t, val)))); - case 4: return AppendRawValue(_CAST(T, ntohl(_CAST(uint32_t, val)))); - case 8: return AppendRawValue(_CAST(T, be64toh(_CAST(uint64_t, val)))); + case 2: return AppendRawValue(_CAST(T, htobe16(_CAST(uint16_t, val)))); + case 4: return AppendRawValue(_CAST(T, htobe32(_CAST(uint32_t, val)))); + case 8: return AppendRawValue(_CAST(T, htobe64(_CAST(uint64_t, val)))); // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } From bf8ebb1b9f1b94e87c632f96e7affa1109da12c8 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Mon, 2 Jul 2018 16:41:41 -0400 Subject: [PATCH 37/48] fix catalog test --- test/include/sql/testing_sql_util.h | 1 + test/sql/testing_sql_util.cpp | 110 ++++++++++++++-------------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/test/include/sql/testing_sql_util.h b/test/include/sql/testing_sql_util.h index 87f2a97e0a8..13f3d2834ad 100644 --- a/test/include/sql/testing_sql_util.h +++ b/test/include/sql/testing_sql_util.h @@ -95,6 +95,7 @@ class TestingSQLUtil { static int GetRandomInteger(const int lower_bound, const int upper_bound); static void UtilTestTaskCallback(void *arg); + static tcop::ClientProcessState state_; static std::atomic_int counter_; // inline static void SetTrafficCopCounter() { // counter_.store(1); diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index 49c53b9acaf..1cfedf63d4c 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -36,6 +36,8 @@ namespace test { std::random_device rd; std::mt19937 rng(rd()); +tcop::ClientProcessState TestingSQLUtil::state_; + // Create a uniform random number int TestingSQLUtil::GetRandomInteger(const int lower_bound, const int upper_bound) { @@ -64,21 +66,20 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( // prepareStatement std::string 
unnamed_statement = "unnamed"; auto &traffic_cop = tcop::Tcop::GetInstance(); - tcop::ClientProcessState state; auto &peloton_parser = parser::PostgresParser::GetInstance(); auto sql_stmt_list = peloton_parser.BuildParseTree(query); PELOTON_ASSERT(sql_stmt_list); if (!sql_stmt_list->is_valid) { return ResultType::FAILURE; } - auto statement = traffic_cop.PrepareStatement(state, + auto statement = traffic_cop.PrepareStatement(state_, unnamed_statement, query, std::move(sql_stmt_list)); if (statement.get() == nullptr) { - state.rows_affected_ = 0; + state_.rows_affected_ = 0; rows_changed = 0; - error_message = state.error_message_; + error_message = state_.error_message_; return ResultType::FAILURE; } // ExecuteStatment @@ -88,27 +89,27 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - statement.swap(state.statement_); - state.param_values_ = param_values; - state.result_format_ = result_format; - state.result_ = result; - auto status = traffic_cop.ExecuteStatement(state, [] { + statement.swap(state_.statement_); + state_.param_values_ = param_values; + state_.result_format_ = result_format; + state_.result_ = result; + auto status = traffic_cop.ExecuteStatement(state_, [] { UtilTestTaskCallback(&counter_); }); - if (state.is_queuing_) { + if (state_.is_queuing_) { ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(state); - status = traffic_cop.ExecuteStatementGetResult(state); - state.is_queuing_ = false; + traffic_cop.ExecuteStatementPlanGetResult(state_); + status = traffic_cop.ExecuteStatementGetResult(state_); + state_.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = state.statement_->GetTupleDescriptor(); + tuple_descriptor = state_.statement_->GetTupleDescriptor(); } LOG_TRACE("Statement executed. 
Result: %s", ResultTypeToString(status).c_str()); - rows_changed = state.rows_affected_; + rows_changed = state_.rows_affected_; // TODO(Tianyu): This is a refactor in progress. This copy can be eliminated. - result = state.result_; + result = state_.result_; return status; } @@ -122,9 +123,8 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( std::vector params; auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); auto &traffic_cop = tcop::Tcop::GetInstance(); - tcop::ClientProcessState state; auto txn = txn_manager.BeginTransaction(); - state.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); + state_.tcop_txn_state_.emplace(txn, ResultType::SUCCESS); auto parsed_stmt = peloton_parser.BuildParseTree(query); @@ -133,7 +133,7 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( auto plan = optimizer->BuildPelotonPlanTree(parsed_stmt, txn); tuple_descriptor = - traffic_cop.GenerateTupleDescriptor(state, parsed_stmt->GetStatement(0)); + traffic_cop.GenerateTupleDescriptor(state_, parsed_stmt->GetStatement(0)); auto result_format = std::vector(tuple_descriptor.size(), PostgresDataFormat::TEXT); @@ -143,22 +143,22 @@ ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( counter_.store(1); QueryType query_type = StatementTypeToQueryType(parsed_stmt->GetStatement(0)->GetType(), parsed_stmt->GetStatement(0)); - state.statement_ = std::make_shared("unnamed", query_type, query, std::move(parsed_stmt)); - state.statement_->SetPlanTree(plan); - state.param_values_ = params; - state.result_format_ = result_format; + state_.statement_ = std::make_shared("unnamed", query_type, query, std::move(parsed_stmt)); + state_.statement_->SetPlanTree(plan); + state_.param_values_ = params; + state_.result_format_ = result_format; auto status = - traffic_cop.ExecuteHelper(state, [] { + traffic_cop.ExecuteHelper(state_, [] { UtilTestTaskCallback(&counter_); }); - if (state.is_queuing_) { + if (state_.is_queuing_) { 
TestingSQLUtil::ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(state); - status = state.p_status_; - state.is_queuing_ = false; + traffic_cop.ExecuteStatementPlanGetResult(state_); + status = state_.p_status_; + state_.is_queuing_ = false; } rows_changed = status.m_processed; - result = state.result_; + result = state_.result_; LOG_TRACE("Statement executed. Result: %s", ResultTypeToString(status.m_result).c_str()); return status.m_result; @@ -197,13 +197,12 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query, return ResultType::FAILURE; } auto &traffic_cop = tcop::Tcop::GetInstance(); - tcop::ClientProcessState state; - auto statement = traffic_cop.PrepareStatement(state, + auto statement = traffic_cop.PrepareStatement(state_, unnamed_statement, query, std::move(sql_stmt_list)); if (statement == nullptr) { - state.rows_affected_ = 0; + state_.rows_affected_ = 0; return ResultType::FAILURE; } // ExecuteStatment @@ -213,24 +212,24 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query, PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - statement.swap(state.statement_); - state.param_values_ = param_values; - state.result_format_ = result_format; - state.result_ = result; - auto status = traffic_cop.ExecuteStatement(state, [] { + statement.swap(state_.statement_); + state_.param_values_ = param_values; + state_.result_format_ = result_format; + state_.result_ = result; + auto status = traffic_cop.ExecuteStatement(state_, [] { UtilTestTaskCallback(&counter_); }); - if (state.is_queuing_) { + if (state_.is_queuing_) { ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(state); - status = traffic_cop.ExecuteStatementGetResult(state); - state.is_queuing_ = false; + traffic_cop.ExecuteStatementPlanGetResult(state_); + status = traffic_cop.ExecuteStatementGetResult(state_); + state_.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = 
state.statement_->GetTupleDescriptor(); + tuple_descriptor = state_.statement_->GetTupleDescriptor(); } // TODO(Tianyu) Same as above. - result = state.result_; + result = state_.result_; return status; } @@ -249,13 +248,12 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query) { return ResultType::FAILURE; } auto &traffic_cop = tcop::Tcop::GetInstance(); - tcop::ClientProcessState state; - auto statement = traffic_cop.PrepareStatement(state, + auto statement = traffic_cop.PrepareStatement(state_, unnamed_statement, query, std::move(sql_stmt_list)); if (statement == nullptr) { - state.rows_affected_ = 0; + state_.rows_affected_ = 0; return ResultType::FAILURE; } // ExecuteStatement @@ -264,21 +262,21 @@ ResultType TestingSQLUtil::ExecuteSQLQuery(const std::string query) { PostgresDataFormat::TEXT); // SetTrafficCopCounter(); counter_.store(1); - statement.swap(state.statement_); - state.param_values_ = param_values; - state.result_format_ = result_format; - state.result_ = result; - auto status = traffic_cop.ExecuteStatement(state, []{ + statement.swap(state_.statement_); + state_.param_values_ = param_values; + state_.result_format_ = result_format; + state_.result_ = result; + auto status = traffic_cop.ExecuteStatement(state_, []{ UtilTestTaskCallback(&counter_); }); - if (state.is_queuing_) { + if (state_.is_queuing_) { ContinueAfterComplete(); - traffic_cop.ExecuteStatementPlanGetResult(state); - status = traffic_cop.ExecuteStatementGetResult(state); - state.is_queuing_ = false; + traffic_cop.ExecuteStatementPlanGetResult(state_); + status = traffic_cop.ExecuteStatementGetResult(state_); + state_.is_queuing_ = false; } if (status == ResultType::SUCCESS) { - tuple_descriptor = state.statement_->GetTupleDescriptor(); + tuple_descriptor = state_.statement_->GetTupleDescriptor(); } LOG_TRACE("Statement executed. 
Result: %s", ResultTypeToString(status).c_str()); From 0e253c256fc229ed919a872a4e21f31689070bf9 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Mon, 2 Jul 2018 21:01:26 -0400 Subject: [PATCH 38/48] Remove stale enum class --- src/include/network/network_types.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/include/network/network_types.h b/src/include/network/network_types.h index 372f626a51a..99f57ec46e3 100644 --- a/src/include/network/network_types.h +++ b/src/include/network/network_types.h @@ -42,13 +42,5 @@ enum class Transition { NEED_WRITE }; -enum class ResponseProtocol { - // No response required (for intermediate messgaes such as parse, bind, etc.) - NO, - // PSQL - SIMPLE, - // JDBC, PQXX, etc. - EXTENDED -}; } // namespace network } // namespace peloton From f8d32c45f32ca5bc6a301ab55301c1fb6987821b Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Tue, 3 Jul 2018 01:54:05 -0400 Subject: [PATCH 39/48] Fix zero denominator problem of row affected calculation --- src/network/postgres_protocol_interpreter.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index d1439fa0926..22ed4bca839 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -202,7 +202,8 @@ void PostgresProtocolInterpreter::ExecExecuteMessageGetResult(PostgresPacketWrit auto tuple_descriptor = state_.statement_->GetTupleDescriptor(); out.WriteDataRows(state_.result_, tuple_descriptor.size()); - state_.rows_affected_ = state_.result_.size() / tuple_descriptor.size(); + state_.rows_affected_ = tuple_descriptor.size() == 0 ? 
+ 0 : (state_.result_.size() / tuple_descriptor.size()); CompleteCommand(out, query_type, state_.rows_affected_); return; } From e72ed4edde03c99e04f9f55d6d2874e520c6b252 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Tue, 3 Jul 2018 18:45:34 -0400 Subject: [PATCH 40/48] Save work --- .../testing/{junit => }/InsertTPCCTest.java | 0 script/testing/{junit => }/InsertTest.java | 0 script/testing/{junit => }/PLTestBase.java | 0 script/testing/{junit => }/UpdateTest.java | 0 script/testing/junit/run_junit.py | 6 +- src/include/common/internal_types.h | 11 - src/include/network/network_io_utils.h | 204 +++++++++++------- .../network/postgres_network_commands.h | 13 +- .../network/postgres_protocol_interpreter.h | 8 +- src/include/network/postgres_protocol_utils.h | 18 +- src/network/network_io_wrappers.cpp | 19 +- src/network/postgres_network_commands.cpp | 104 +++------ src/network/postgres_protocol_interpreter.cpp | 56 ++++- 13 files changed, 243 insertions(+), 196 deletions(-) rename script/testing/{junit => }/InsertTPCCTest.java (100%) rename script/testing/{junit => }/InsertTest.java (100%) rename script/testing/{junit => }/PLTestBase.java (100%) rename script/testing/{junit => }/UpdateTest.java (100%) diff --git a/script/testing/junit/InsertTPCCTest.java b/script/testing/InsertTPCCTest.java similarity index 100% rename from script/testing/junit/InsertTPCCTest.java rename to script/testing/InsertTPCCTest.java diff --git a/script/testing/junit/InsertTest.java b/script/testing/InsertTest.java similarity index 100% rename from script/testing/junit/InsertTest.java rename to script/testing/InsertTest.java diff --git a/script/testing/junit/PLTestBase.java b/script/testing/PLTestBase.java similarity index 100% rename from script/testing/junit/PLTestBase.java rename to script/testing/PLTestBase.java diff --git a/script/testing/junit/UpdateTest.java b/script/testing/UpdateTest.java similarity index 100% rename from script/testing/junit/UpdateTest.java rename to 
script/testing/UpdateTest.java diff --git a/script/testing/junit/run_junit.py b/script/testing/junit/run_junit.py index b628cfe2d6b..fd03fd74cb9 100755 --- a/script/testing/junit/run_junit.py +++ b/script/testing/junit/run_junit.py @@ -101,12 +101,12 @@ def _run_junit(self): def run(self): """ Orchestrate the overall JUnit test execution """ - self._check_peloton_binary() - self._run_peloton() + # self._check_peloton_binary() + # self._run_peloton() ret_val = self._run_junit() self._print_output(self.junit_output_file) - self._stop_peloton() + # self._stop_peloton() if ret_val: # print the peloton log file, only if we had a failure self._print_output(self.peloton_output_file) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index f0d46447e12..03edefcb347 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1427,17 +1427,6 @@ typedef unsigned char uchar; /* type for buffer of bytes */ typedef std::vector ByteBuf; -//===--------------------------------------------------------------------===// -// Packet Manager: ProcessResult -//===--------------------------------------------------------------------===// -enum class ProcessResult { - COMPLETE, - TERMINATE, - PROCESSING, - MORE_DATA_REQUIRED, - NEED_SSL_HANDSHAKE, -}; - enum class NetworkProtocolType { POSTGRES_JDBC, POSTGRES_PSQL, diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index 9c1da5e88b1..e85ab57fccb 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -83,69 +83,41 @@ class Buffer { size_ = unprocessed_len; offset_ = 0; } +// +// void PrintBuffer(int offset, int len) { +// for (int i = offset; i < offset + len; i++) +// printf("%02X ", buf_[i]); +// printf("\n"); +// } - // TODO(Tianyu): Fix this after protocol refactor -// protected: + protected: size_t size_ = 0, offset_ = 0, capacity_; ByteBuf buf_; + private: friend class WriteQueue; + 
friend class PostgresPacketWriter; }; +namespace { +// Helper method for reading nul-terminated string for the read buffer +inline std::string ReadCString(ByteBuf::const_iterator begin, + ByteBuf::const_iterator end) { + // search for the nul terminator + for (ByteBuf::const_iterator head = begin; head != end; ++head) + if (*head == 0) return std::string(begin, head); + // No nul terminator found + throw NetworkProcessException("Expected nil in read buffer, none found"); +} +} + /** - * A buffer specialize for read + * A view of the read buffer that has its own read head. */ -class ReadBuffer : public Buffer { +class ReadBufferView { public: - /** - * Instantiates a new buffer and reserve capacity many bytes. - */ - inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) - : Buffer(capacity) {} - /** - * Read as many bytes as possible using SSL read - * @param context SSL context to read from - * @return the return value of ssl read - */ - inline int FillBufferFrom(SSL *context) { - ERR_clear_error(); - ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); - int err = SSL_get_error(context, bytes_read); - if (err == SSL_ERROR_NONE) size_ += bytes_read; - return err; - }; - - /** - * Read as many bytes as possible using Posix from an fd - * @param fd the file descriptor to read from - * @return the return value of posix read - */ - inline int FillBufferFrom(int fd) { - ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); - if (bytes_read > 0) size_ += bytes_read; - return (int) bytes_read; - } - - /** - * Read the specified amount of bytes off from another read buffer. 
The bytes - * will be consumed (cursor moved) on the other buffer and appended to the end - * of this buffer - * @param other The other buffer to read from - * @param size Number of bytes to read - */ - inline void FillBufferFrom(ReadBuffer &other, size_t size) { - other.Read(size, &buf_[size_]); - size_ += size; - } - - /** - * The number of bytes available to be consumed (i.e. meaningful bytes after - * current read cursor) - * @return The number of bytes available to be consumed - */ - inline size_t BytesAvailable() { return size_ - offset_; } - - + inline ReadBufferView(size_t size, ByteBuf::const_iterator begin) + : size_(size), begin_(begin) {} /** * Read the given number of bytes into destination, advancing cursor by that * number. It is up to the caller to ensure that there are enough bytes @@ -154,7 +126,7 @@ class ReadBuffer : public Buffer { * @param dest Desired memory location to read into */ inline void Read(size_t bytes, void *dest) { - std::copy(buf_.begin() + offset_, buf_.begin() + offset_ + bytes, + std::copy(begin_ + offset_, begin_ + offset_ + bytes, reinterpret_cast(dest)); offset_ += bytes; } @@ -168,7 +140,7 @@ class ReadBuffer : public Buffer { * @tparam T type of value to read off. Has to be size 1, 2, 4, or 8. 
* @return value of integer switched from network byte order */ - template + template inline T ReadValue() { // We only want to allow for certain type sizes to be used // After the static assert, the compiler should be smart enough to throw @@ -180,13 +152,10 @@ class ReadBuffer : public Buffer { auto val = ReadRawValue(); switch (sizeof(T)) { case 1: return val; - case 2: - return _CAST(T, be16toh(_CAST(uint16_t, val))); - case 4: - return _CAST(T, be32toh(_CAST(uint32_t, val))); - case 8: - return _CAST(T, be64toh(_CAST(uint64_t, val))); - // Will never be here due to compiler optimization + case 2:return _CAST(T, be16toh(_CAST(uint16_t, val))); + case 4:return _CAST(T, be32toh(_CAST(uint32_t, val))); + case 8:return _CAST(T, be64toh(_CAST(uint64_t, val))); + // Will never be here due to compiler optimization default: throw NetworkProcessException(""); } } @@ -196,19 +165,11 @@ class ReadBuffer : public Buffer { * if no nul-terminator is found within packet range. * @return string at head of read buffer */ - std::string ReadString() { - // search for the nul terminator - for (size_t i = offset_; i < size_; i++) { - if (buf_[i] == 0) { - auto result = std::string(buf_.begin() + offset_, - buf_.begin() + i); - // +1 because we want to skip nul - offset_ = i + 1; - return result; - } - } - // No nul terminator found - throw NetworkProcessException("Expected nil in read buffer, none found"); + inline std::string ReadString() { + std::string result = ReadCString(begin_ + offset_, begin_ + size_); + // extra byte of nul-terminator + offset_ += result.size() + 1; + return result; } /** @@ -216,7 +177,7 @@ class ReadBuffer : public Buffer { * @return string at head of read buffer */ inline std::string ReadString(size_t len) { - std::string result(buf_.begin() + offset_, buf_.begin() + offset_ + len); + std::string result(begin_ + offset_, begin_ + offset_ + len); offset_ += len; return result; } @@ -234,6 +195,94 @@ class ReadBuffer : public Buffer { Read(sizeof(result), 
&result); return result; } + + private: + size_t offset_ = 0, size_; + ByteBuf::const_iterator begin_; +}; + +/** + * A buffer specialize for read + */ +class ReadBuffer : public Buffer { + public: + /** + * Instantiates a new buffer and reserve capacity many bytes. + */ + inline ReadBuffer(size_t capacity = SOCKET_BUFFER_CAPACITY) + : Buffer(capacity) {} + /** + * Read as many bytes as possible using SSL read + * @param context SSL context to read from + * @return the return value of ssl read + */ + inline int FillBufferFrom(SSL *context) { + ERR_clear_error(); + ssize_t bytes_read = SSL_read(context, &buf_[size_], Capacity() - size_); + int err = SSL_get_error(context, bytes_read); + if (err == SSL_ERROR_NONE) size_ += bytes_read; + return err; + }; + + /** + * Read as many bytes as possible using Posix from an fd + * @param fd the file descriptor to read from + * @return the return value of posix read + */ + inline int FillBufferFrom(int fd) { + ssize_t bytes_read = read(fd, &buf_[size_], Capacity() - size_); + if (bytes_read > 0) size_ += bytes_read; + return (int) bytes_read; + } + + /** + * Read the specified amount of bytes off from another read buffer. The bytes + * will be consumed (cursor moved) on the other buffer and appended to the end + * of this buffer + * @param other The other buffer to read from + * @param size Number of bytes to read + */ + inline void FillBufferFrom(ReadBuffer &other, size_t size) { + other.ReadIntoView(size).Read(size, &buf_[size_]); + size_ += size; + } + + /** + * The number of bytes available to be consumed (i.e. meaningful bytes after + * current read cursor) + * @return The number of bytes available to be consumed + */ + inline size_t BytesAvailable() { return size_ - offset_; } + + /** + * Mark a chunk of bytes as read and return a view to the bytes read. + * + * This is necessary because a caller may not read all the bytes in a packet + * before exiting (exception occurs, etc.). 
Reserving a view of the bytes in + * a packet makes sure that the remaining bytes in a buffer is not malformed. + * + * No copying is performed in this process, however, so modifying the read buffer + * when a view is in scope will cause undefined behavior on the view's methods + * + * @param bytes number of butes to read + * @return a view of the bytes read. + */ + inline ReadBufferView ReadIntoView(size_t bytes) { + ReadBufferView result = ReadBufferView(bytes, buf_.begin() + offset_); + offset_ += bytes; + return result; + } + + template + inline T ReadValue() { + return ReadIntoView(sizeof(T)).ReadValue(); + } + + inline std::string ReadString() { + std::string result = ReadCString(buf_.begin() + offset_, buf_.begin() + size_); + offset_ += result.size() + 1; + return result; + } }; /** @@ -383,7 +432,8 @@ class WriteQueue { size_t written = breakup ? tail.RemainingCapacity() : 0; tail.AppendRaw(src, written); buffers_.push_back(std::make_shared()); - BufferWriteRaw(reinterpret_cast(src) + written, len - written); + BufferWriteRaw(reinterpret_cast(src) + written, + len - written); } } diff --git a/src/include/network/postgres_network_commands.h b/src/include/network/postgres_network_commands.h index 72348581ee0..59ba3a4000d 100644 --- a/src/include/network/postgres_network_commands.h +++ b/src/include/network/postgres_network_commands.h @@ -22,11 +22,11 @@ #define DEFINE_COMMAND(name, flush) \ class name : public PostgresNetworkCommand { \ public: \ - explicit name(std::shared_ptr in) \ - : PostgresNetworkCommand(std::move(in), flush) {} \ + explicit name(PostgresInputPacket &in) \ + : PostgresNetworkCommand(in, flush) {} \ virtual Transition Exec(PostgresProtocolInterpreter &, \ PostgresPacketWriter &, \ - CallbackFunc) override; \ + CallbackFunc) override; \ } namespace peloton { @@ -43,8 +43,8 @@ class PostgresNetworkCommand { inline bool FlushOnComplete() { return flush_on_complete_; } protected: - explicit PostgresNetworkCommand(std::shared_ptr in, bool 
flush) - : in_(std::move(in)), flush_on_complete_(flush) {} + explicit PostgresNetworkCommand(PostgresInputPacket &in, bool flush) + : in_(in.buf_->ReadIntoView(in.len_)), flush_on_complete_(flush) {} std::vector ReadParamTypes(); @@ -68,12 +68,11 @@ class PostgresNetworkCommand { std::vector ReadResultFormats(size_t tuple_size); - std::shared_ptr in_; + ReadBufferView in_; private: bool flush_on_complete_; }; -DEFINE_COMMAND(StartupCommand, true); DEFINE_COMMAND(SimpleQueryCommand, true); DEFINE_COMMAND(ParseCommand, false); DEFINE_COMMAND(BindCommand, false); diff --git a/src/include/network/postgres_protocol_interpreter.h b/src/include/network/postgres_protocol_interpreter.h index c2b73fd3bc2..db38e7cca38 100644 --- a/src/include/network/postgres_protocol_interpreter.h +++ b/src/include/network/postgres_protocol_interpreter.h @@ -48,14 +48,10 @@ class PostgresProtocolInterpreter : public ProtocolInterpreter { LOG_TRACE("PSQL result"); ExecQueryMessageGetResult(writer, status); } - - } - - inline void AddCmdlineOption(const std::string &key, std::string value) { - cmdline_options_[key] = std::move(value); } - inline void FinishStartup() { startup_ = false; } + Transition ProcessStartup(std::shared_ptr in, + std::shared_ptr out); inline tcop::ClientProcessState &ClientProcessState() { return state_; } diff --git a/src/include/network/postgres_protocol_utils.h b/src/include/network/postgres_protocol_utils.h index 22fa47b4870..b9f17b7c60b 100644 --- a/src/include/network/postgres_protocol_utils.h +++ b/src/include/network/postgres_protocol_utils.h @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include #include "network/network_io_utils.h" #include "common/statement.h" @@ -19,9 +20,22 @@ namespace peloton { namespace network { // TODO(Tianyu): It looks very broken that this never changes. -// TODO(Tianyu): Also, Initialize. 
+// clang-format off const std::unordered_map - parameter_status_map; + parameter_status_map = { + {"application_name", "psql"}, + {"client_encoding", "UTF8"}, + {"DateStyle", "ISO, MDY"}, + {"integer_datetimes", "on"}, + {"IntervalStyle", "postgres"}, + {"is_superuser", "on"}, + {"server_encoding", "UTF8"}, + {"server_version", "9.5devel"}, + {"session_authorization", "postgres"}, + {"standard_conforming_strings", "on"}, + {"TimeZone", "US/Eastern"} + }; +// clang-format on /** * Encapsulates an input packet diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index 9233d649f31..bc590293bc9 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -20,12 +20,9 @@ namespace peloton { namespace network { Transition NetworkIoWrapper::FlushAllWrites() { - for (auto buffer = out_->FlushHead(); - buffer != nullptr; - buffer = out_->FlushHead()) { - auto result = FlushWriteBuffer(*buffer); + for (; out_->FlushHead() != nullptr; out_->MarkHeadFlushed()) { + auto result = FlushWriteBuffer(*out_->FlushHead()); if (result != Transition::PROCEED) return result; - out_->MarkHeadFlushed(); } out_->Reset(); return Transition::PROCEED; @@ -63,12 +60,10 @@ Transition PosixSocketIoWrapper::FillReadBuffer() { case EAGAIN: // Equal to EWOULDBLOCK return result; - case EINTR: - continue; - default: - LOG_ERROR("Error writing: %s", strerror(errno)); + case EINTR:continue; + default:LOG_ERROR("Error writing: %s", strerror(errno)); throw NetworkProcessException("Error when filling read buffer " + - std::to_string(errno)); + std::to_string(errno)); } } return result; @@ -99,8 +94,8 @@ Transition SslSocketIoWrapper::FillReadBuffer() { case SSL_ERROR_NONE:result = Transition::PROCEED; break; case SSL_ERROR_ZERO_RETURN: return Transition::TERMINATE; - // The SSL packet is partially loaded to the SSL buffer only, - // More data is required in order to decode the wh`ole packet. 
+ // The SSL packet is partially loaded to the SSL buffer only, + // More data is required in order to decode the wh`ole packet. case SSL_ERROR_WANT_READ: return result; case SSL_ERROR_WANT_WRITE: return Transition::NEED_WRITE; case SSL_ERROR_SYSCALL: diff --git a/src/network/postgres_network_commands.cpp b/src/network/postgres_network_commands.cpp index 2523d0c324a..db542055075 100644 --- a/src/network/postgres_network_commands.cpp +++ b/src/network/postgres_network_commands.cpp @@ -17,9 +17,6 @@ #include "settings/settings_manager.h" #include "planner/abstract_plan.h" -#define SSL_MESSAGE_VERNO 80877103 -#define PROTO_MAJOR_VERSION(x) ((x) >> 16) - namespace peloton { namespace network { @@ -29,17 +26,17 @@ namespace network { // project though, so I want to do the architectural refactor first. std::vector PostgresNetworkCommand::ReadParamTypes() { std::vector result; - auto num_params = in_->ReadValue(); + auto num_params = in_.ReadValue(); for (uint16_t i = 0; i < num_params; i++) - result.push_back(in_->ReadValue()); + result.push_back(in_.ReadValue()); return result; } std::vector PostgresNetworkCommand::ReadParamFormats() { std::vector result; - auto num_formats = in_->ReadValue(); + auto num_formats = in_.ReadValue(); for (uint16_t i = 0; i < num_formats; i++) - result.push_back(in_->ReadValue()); + result.push_back(in_.ReadValue()); return result; } @@ -48,9 +45,9 @@ void PostgresNetworkCommand::ReadParamValues(std::vector &bind_pa const std::vector ¶m_types, const std::vector< PostgresDataFormat> &formats) { - auto num_params = in_->ReadValue(); + auto num_params = in_.ReadValue(); for (uint16_t i = 0; i < num_params; i++) { - auto param_len = in_->ReadValue(); + auto param_len = in_.ReadValue(); if (param_len == -1) { // NULL auto peloton_type = PostgresValueTypeToPelotonValueType(param_types[i]); @@ -82,7 +79,7 @@ void PostgresNetworkCommand::ProcessTextParamValue(std::vector &b std::vector ¶m_values, PostgresValueType type, int32_t len) { - std::string 
val = in_->ReadString((size_t) len); + std::string val = in_.ReadString((size_t) len); bind_parameters.emplace_back(type::TypeId::VARCHAR, val); param_values.push_back( PostgresValueTypeToPelotonValueType(type) == type::TypeId::VARCHAR @@ -98,7 +95,7 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector switch (type) { case PostgresValueType::TINYINT: { PELOTON_ASSERT(len == sizeof(int8_t)); - auto val = in_->ReadValue(); + auto val = in_.ReadValue(); bind_parameters.emplace_back(type::TypeId::TINYINT, std::to_string(val)); param_values.push_back( type::ValueFactory::GetTinyIntValue(val).Copy()); @@ -106,7 +103,7 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector } case PostgresValueType::SMALLINT: { PELOTON_ASSERT(len == sizeof(int16_t)); - auto int_val = in_->ReadValue(); + auto int_val = in_.ReadValue(); bind_parameters.emplace_back(type::TypeId::SMALLINT, std::to_string(int_val)); param_values.push_back( @@ -115,7 +112,7 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector } case PostgresValueType::INTEGER: { PELOTON_ASSERT(len == sizeof(int32_t)); - auto val = in_->ReadValue(); + auto val = in_.ReadValue(); bind_parameters.emplace_back(type::TypeId::INTEGER, std::to_string(val)); param_values.push_back( type::ValueFactory::GetIntegerValue(val).Copy()); @@ -123,7 +120,7 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector } case PostgresValueType::BIGINT: { PELOTON_ASSERT(len == sizeof(int64_t)); - auto val = in_->ReadValue(); + auto val = in_.ReadValue(); bind_parameters.emplace_back(type::TypeId::BIGINT, std::to_string(val)); param_values.push_back( type::ValueFactory::GetBigIntValue(val).Copy()); @@ -131,14 +128,14 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector } case PostgresValueType::DOUBLE: { PELOTON_ASSERT(len == sizeof(double)); - auto val = in_->ReadValue(); + auto val = in_.ReadValue(); bind_parameters.emplace_back(type::TypeId::DECIMAL, std::to_string(val)); 
param_values.push_back( type::ValueFactory::GetDecimalValue(val).Copy()); break; } case PostgresValueType::VARBINARY: { - auto val = in_->ReadString((size_t) len); + auto val = in_.ReadString((size_t) len); bind_parameters.emplace_back(type::TypeId::VARBINARY, val); param_values.push_back( type::ValueFactory::GetVarbinaryValue( @@ -155,7 +152,7 @@ void PostgresNetworkCommand::ProcessBinaryParamValue(std::vector } std::vector PostgresNetworkCommand::ReadResultFormats(size_t tuple_size) { - auto num_format_codes = in_->ReadValue(); + auto num_format_codes = in_.ReadValue(); switch (num_format_codes) { case 0: // Default text mode @@ -163,63 +160,20 @@ std::vector PostgresNetworkCommand::ReadResultFormats(size_t PostgresDataFormat::TEXT); case 1: return std::vector(tuple_size, - in_->ReadValue()); + in_.ReadValue()); default:std::vector result; for (auto i = 0; i < num_format_codes; i++) - result.push_back(in_->ReadValue()); + result.push_back(in_.ReadValue()); return result; } } -Transition StartupCommand::Exec(PostgresProtocolInterpreter &interpreter, - PostgresPacketWriter &out, - CallbackFunc) { - tcop::ClientProcessState &state = interpreter.ClientProcessState(); - auto proto_version = in_->ReadValue(); - LOG_INFO("protocol version: %d", proto_version); - // SSL initialization - if (proto_version == SSL_MESSAGE_VERNO) { - // TODO(Tianyu): Should this be moved from PelotonServer into settings? 
- if (PelotonServer::GetSSLLevel() == SSLLevel::SSL_DISABLE) { - out.WriteSingleTypePacket(NetworkMessageType::SSL_NO); - return Transition::PROCEED; - } - out.WriteSingleTypePacket(NetworkMessageType::SSL_YES); - return Transition::NEED_SSL_HANDSHAKE; - } - - // Process startup packet - if (PROTO_MAJOR_VERSION(proto_version) != 3) { - LOG_ERROR("Protocol error: only protocol version 3 is supported"); - out.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, - "Protocol Version Not Supported"}}); - return Transition::TERMINATE; - } - - // The last bit of the packet will be nul. This is not a valid field. When there - // is less than 2 bytes of data remaining we can already exit early. - while (in_->HasMore(2)) { - // TODO(Tianyu): We don't seem to really handle the other flags? - std::string key = in_->ReadString(), value = in_->ReadString(); - LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); - if (key == std::string("database")) - state.db_name_ = value; - interpreter.AddCmdlineOption(key, std::move(value)); - } - // skip the last nul byte - in_->Skip(1); - // TODO(Tianyu): Implement authentication. 
For now we always send AuthOK - out.WriteStartupResponse(); - interpreter.FinishStartup(); - return Transition::PROCEED; -} - Transition SimpleQueryCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc callback) { interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_PSQL; tcop::ClientProcessState &state = interpreter.ClientProcessState(); - std::string query = in_->ReadString(); + std::string query = in_.ReadString(); LOG_TRACE("Execute query: %s", query.c_str()); std::unique_ptr sql_stmt_list; try { @@ -350,8 +304,7 @@ Transition ParseCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); - std::string statement_name = in_->ReadString(), query = in_->ReadString(); - + std::string statement_name = in_.ReadString(), query = in_.ReadString(); // In JDBC, one query starts with parsing stage. // Reset skipped_stmt_ to false for the new query. state.skipped_stmt_ = false; @@ -415,15 +368,15 @@ Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); - std::string portal_name = in_->ReadString(), - statement_name = in_->ReadString(); + std::string portal_name = in_.ReadString(), + statement_name = in_.ReadString(); if (state.skipped_stmt_) { out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); return Transition::PROCEED; } - std::vector formats = ReadParamFormats(); + // Get statement info generated in PARSE message std::shared_ptr statement = state.statement_cache_.GetStatement(statement_name); @@ -473,8 +426,6 @@ Transition BindCommand::Exec(PostgresProtocolInterpreter &interpreter, // Instead of tree traversal, we should put param values in the // executor context. 
- - interpreter.portals_[portal_name] = std::make_shared(portal_name, statement, std::move(param_values)); out.WriteSingleTypePacket(NetworkMessageType::BIND_COMPLETE); @@ -491,8 +442,8 @@ Transition DescribeCommand::Exec(PostgresProtocolInterpreter &interpreter, return Transition::PROCEED; } - auto mode = in_->ReadValue(); - std::string portal_name = in_->ReadString(); + auto mode = in_.ReadValue(); + std::string portal_name = in_.ReadString(); switch (mode) { case PostgresNetworkObjectType::PORTAL: { LOG_TRACE("Describe a portal"); @@ -524,7 +475,10 @@ Transition ExecuteCommand::Exec(PostgresProtocolInterpreter &interpreter, CallbackFunc callback) { interpreter.protocol_type_ = NetworkProtocolType::POSTGRES_JDBC; tcop::ClientProcessState &state = interpreter.ClientProcessState(); - std::string portal_name = in_->ReadString(); + std::string portal_name = in_.ReadString(); + // We never seem to use this row limit field in the message? + auto row_limit = in_.ReadValue(); + (void) row_limit; // covers weird JDBC edge case of sending double BEGIN statements. 
Don't // execute them @@ -569,8 +523,8 @@ Transition CloseCommand::Exec(PostgresProtocolInterpreter &interpreter, PostgresPacketWriter &out, CallbackFunc) { tcop::ClientProcessState &state = interpreter.ClientProcessState(); - auto close_type = in_->ReadValue(); - std::string name = in_->ReadString(); + auto close_type = in_.ReadValue(); + std::string name = in_.ReadString(); switch (close_type) { case PostgresNetworkObjectType::STATEMENT: { LOG_TRACE("Deleting statement %s from cache", name.c_str()); diff --git a/src/network/postgres_protocol_interpreter.cpp b/src/network/postgres_protocol_interpreter.cpp index 22ed4bca839..00cc732bd8f 100644 --- a/src/network/postgres_protocol_interpreter.cpp +++ b/src/network/postgres_protocol_interpreter.cpp @@ -12,10 +12,13 @@ #include "planner/plan_util.h" #include "network/postgres_protocol_interpreter.h" +#include "network/peloton_server.h" #define MAKE_COMMAND(type) \ std::static_pointer_cast( \ - std::make_shared(std::move(curr_input_packet_.buf_))) + std::make_shared(curr_input_packet_)) +#define SSL_MESSAGE_VERNO 80877103 +#define PROTO_MAJOR_VERSION(x) ((x) >> 16) namespace peloton { namespace network { @@ -23,6 +26,12 @@ Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, std::shared_ptr out, CallbackFunc callback) { if (!TryBuildPacket(in)) return Transition::NEED_READ; + if (startup_) { + // Always flush startup packet response + out->ForceFlush(); + curr_input_packet_.Clear(); + return ProcessStartup(in, out); + } std::shared_ptr command = PacketToCommand(); curr_input_packet_.Clear(); PostgresPacketWriter writer(*out); @@ -30,6 +39,48 @@ Transition PostgresProtocolInterpreter::Process(std::shared_ptr in, return command->Exec(*this, writer, callback); } +Transition PostgresProtocolInterpreter::ProcessStartup(std::shared_ptr in, + std::shared_ptr out) { + PostgresPacketWriter writer(*out); + auto proto_version = in->ReadValue(); + LOG_INFO("protocol version: %d", proto_version); + // SSL 
initialization + if (proto_version == SSL_MESSAGE_VERNO) { + // TODO(Tianyu): Should this be moved from PelotonServer into settings? + if (PelotonServer::GetSSLLevel() == SSLLevel::SSL_DISABLE) { + writer.WriteSingleTypePacket(NetworkMessageType::SSL_NO); + return Transition::PROCEED; + } + writer.WriteSingleTypePacket(NetworkMessageType::SSL_YES); + return Transition::NEED_SSL_HANDSHAKE; + } + + // Process startup packet + if (PROTO_MAJOR_VERSION(proto_version) != 3) { + LOG_ERROR("Protocol error: only protocol version 3 is supported"); + writer.WriteErrorResponse({{NetworkMessageType::HUMAN_READABLE_ERROR, + "Protocol Version Not Supported"}}); + return Transition::TERMINATE; + } + + // The last bit of the packet will be nul. This is not a valid field. When there + // is less than 2 bytes of data remaining we can already exit early. + while (in->HasMore(2)) { + // TODO(Tianyu): We don't seem to really handle the other flags? + std::string key = in->ReadString(), value = in->ReadString(); + LOG_TRACE("Option key %s, value %s", key.c_str(), value.c_str()); + if (key == std::string("database")) + state_.db_name_ = value; + cmdline_options_[key] = std::move(value); + } + // skip the last nul byte + in->Skip(1); + // TODO(Tianyu): Implement authentication. 
For now we always send AuthOK + writer.WriteStartupResponse(); + startup_ = false; + return Transition::PROCEED; +} + bool PostgresProtocolInterpreter::TryBuildPacket(std::shared_ptr &in) { if (!TryReadPacketHeader(in)) return false; @@ -57,7 +108,7 @@ bool PostgresProtocolInterpreter::TryReadPacketHeader(std::shared_ptrReadRawValue(); + curr_input_packet_.msg_type_ = in->ReadValue(); curr_input_packet_.len_ = in->ReadValue() - sizeof(uint32_t); // Extend the buffer as needed @@ -77,7 +128,6 @@ bool PostgresProtocolInterpreter::TryReadPacketHeader(std::shared_ptr PostgresProtocolInterpreter::PacketToCommand() { - if (startup_) return MAKE_COMMAND(StartupCommand); switch (curr_input_packet_.msg_type_) { case NetworkMessageType::SIMPLE_QUERY_COMMAND: return MAKE_COMMAND(SimpleQueryCommand); From d508fb3ca8be38a02d4ada088a3ed7ad1ce80e6d Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Wed, 4 Jul 2018 04:59:40 -0400 Subject: [PATCH 41/48] Remove dead code --- src/include/network/network_io_utils.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/include/network/network_io_utils.h b/src/include/network/network_io_utils.h index e85ab57fccb..d6544771121 100644 --- a/src/include/network/network_io_utils.h +++ b/src/include/network/network_io_utils.h @@ -83,12 +83,6 @@ class Buffer { size_ = unprocessed_len; offset_ = 0; } -// -// void PrintBuffer(int offset, int len) { -// for (int i = offset; i < offset + len; i++) -// printf("%02X ", buf_[i]); -// printf("\n"); -// } protected: size_t size_ = 0, offset_ = 0, capacity_; From 75b2a5579eef769f252403a3bb0edfc490ccef24 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Wed, 4 Jul 2018 06:25:50 -0400 Subject: [PATCH 42/48] Fix bug of invalid state of Sync command by changing the initial state of txn_state_ to IDLE. 
--- src/include/traffic_cop/tcop.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index a9e712c0312..b22fe49ba13 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -43,7 +43,7 @@ struct ClientProcessState { // flag of single statement txn std::vector result_; std::stack tcop_txn_state_; - NetworkTransactionStateType txn_state_ = NetworkTransactionStateType::INVALID; + NetworkTransactionStateType txn_state_ = NetworkTransactionStateType::IDLE; bool skipped_stmt_ = false; std::string skipped_query_string_; QueryType skipped_query_type_ = QueryType::QUERY_INVALID; @@ -74,7 +74,7 @@ struct ClientProcessState { result_format_.clear(); result_.clear(); tcop_txn_state_ = std::stack(); - txn_state_ = NetworkTransactionStateType::INVALID; + txn_state_ = NetworkTransactionStateType::IDLE; skipped_stmt_ = false; skipped_query_string_ = ""; skipped_query_type_ = QueryType::QUERY_INVALID; From 0ba72d4431cee20827414669e5a3aa03e2831829 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 4 Jul 2018 08:58:28 -0400 Subject: [PATCH 43/48] Restore Junit tests --- script/testing/{ => junit}/InsertTPCCTest.java | 0 script/testing/{ => junit}/InsertTest.java | 0 script/testing/{ => junit}/PLTestBase.java | 0 script/testing/{ => junit}/UpdateTest.java | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename script/testing/{ => junit}/InsertTPCCTest.java (100%) rename script/testing/{ => junit}/InsertTest.java (100%) rename script/testing/{ => junit}/PLTestBase.java (100%) rename script/testing/{ => junit}/UpdateTest.java (100%) diff --git a/script/testing/InsertTPCCTest.java b/script/testing/junit/InsertTPCCTest.java similarity index 100% rename from script/testing/InsertTPCCTest.java rename to script/testing/junit/InsertTPCCTest.java diff --git a/script/testing/InsertTest.java b/script/testing/junit/InsertTest.java similarity index 100% rename from 
script/testing/InsertTest.java rename to script/testing/junit/InsertTest.java diff --git a/script/testing/PLTestBase.java b/script/testing/junit/PLTestBase.java similarity index 100% rename from script/testing/PLTestBase.java rename to script/testing/junit/PLTestBase.java diff --git a/script/testing/UpdateTest.java b/script/testing/junit/UpdateTest.java similarity index 100% rename from script/testing/UpdateTest.java rename to script/testing/junit/UpdateTest.java From a384dc34e6c7961f0b2e1ba1b026fecd8754e4c7 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 4 Jul 2018 09:19:56 -0400 Subject: [PATCH 44/48] Restore peloton startup in junit --- script/testing/junit/run_junit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/script/testing/junit/run_junit.py b/script/testing/junit/run_junit.py index fd03fd74cb9..b628cfe2d6b 100755 --- a/script/testing/junit/run_junit.py +++ b/script/testing/junit/run_junit.py @@ -101,12 +101,12 @@ def _run_junit(self): def run(self): """ Orchestrate the overall JUnit test execution """ - # self._check_peloton_binary() - # self._run_peloton() + self._check_peloton_binary() + self._run_peloton() ret_val = self._run_junit() self._print_output(self.junit_output_file) - # self._stop_peloton() + self._stop_peloton() if ret_val: # print the peloton log file, only if we had a failure self._print_output(self.peloton_output_file) From a4c17d883e5ddace40162bb7c62911a59285246d Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 4 Jul 2018 15:35:51 -0400 Subject: [PATCH 45/48] Fix LOG_TRACE compilation --- src/traffic_cop/tcop.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index c2212f2eca7..be9194b703c 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -372,7 +372,7 @@ void Tcop::ExecuteStatementPlanGetResult(ClientProcessState &state) { // Abort LOG_TRACE("Abort Transaction"); if (state.single_statement_txn_) { - 
LOG_TRACE("Tcop_txn_state size: %lu", tcop_txn_state_.size()); + LOG_TRACE("Tcop_txn_state size: %lu", state.tcop_txn_state_.size()); state.p_status_.m_result = AbortQueryHelper(state); } else { state.tcop_txn_state_.top().second = ResultType::ABORTED; From 6d1d86b63a6c4a5d85dbb4695d689b152498a482 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Wed, 4 Jul 2018 16:40:11 -0400 Subject: [PATCH 46/48] Actually fix LOG_TRACE --- src/traffic_cop/tcop.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traffic_cop/tcop.cpp b/src/traffic_cop/tcop.cpp index be9194b703c..20af77b4d15 100644 --- a/src/traffic_cop/tcop.cpp +++ b/src/traffic_cop/tcop.cpp @@ -384,7 +384,7 @@ void Tcop::ExecuteStatementPlanGetResult(ClientProcessState &state) { ResultType Tcop::ExecuteStatementGetResult(ClientProcessState &state) { LOG_TRACE("Statement executed. Result: %s", - ResultTypeToString(p_status_.m_result).c_str()); + ResultTypeToString(state.p_status_.m_result).c_str()); state.rows_affected_ = state.p_status_.m_processed; LOG_TRACE("rows_changed %d", state.p_status_.m_processed); state.is_queuing_ = false; From cd16c71812262517cb215d9da26f24e9ce650b78 Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Thu, 5 Jul 2018 20:51:44 -0400 Subject: [PATCH 47/48] Reset Optimizer instead of create a new one --- src/include/traffic_cop/tcop.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/traffic_cop/tcop.h b/src/include/traffic_cop/tcop.h index b22fe49ba13..ef14179a159 100644 --- a/src/include/traffic_cop/tcop.h +++ b/src/include/traffic_cop/tcop.h @@ -69,7 +69,7 @@ struct ClientProcessState { db_name_ = DEFAULT_DB_NAME; param_values_.clear(); statement_.reset(); - optimizer_.reset(new optimizer::Optimizer()); + optimizer_->Reset(); single_statement_txn_ = false; result_format_.clear(); result_.clear(); From 2dca58d8964b318d9b9029096ad9b6a545feeaf4 Mon Sep 17 00:00:00 2001 From: Tianyu Li Date: Fri, 6 Jul 2018 16:02:16 -0400 Subject: [PATCH 48/48] 
Revert to ConnectionHandleFactory to solve occasional memory leak on exit. --- src/include/network/connection_handle.h | 13 ++-- .../network/connection_handle_factory.h | 53 +++++++++++++ .../network/network_io_wrapper_factory.h | 66 ---------------- src/include/network/network_io_wrappers.h | 32 +++++--- src/network/connection_handle.cpp | 38 +++++++-- src/network/connection_handle_factory.cpp | 40 ++++++++++ src/network/connection_handler_task.cpp | 13 +--- src/network/network_io_wrapper_factory.cpp | 78 ------------------- src/network/network_io_wrappers.cpp | 2 +- test/network/exception_test.cpp | 2 +- test/network/prepare_stmt_test.cpp | 2 +- test/network/select_all_test.cpp | 2 +- test/network/simple_query_test.cpp | 2 +- test/network/ssl_test.cpp | 2 +- 14 files changed, 164 insertions(+), 181 deletions(-) create mode 100644 src/include/network/connection_handle_factory.h delete mode 100644 src/include/network/network_io_wrapper_factory.h create mode 100644 src/network/connection_handle_factory.cpp delete mode 100644 src/network/network_io_wrapper_factory.cpp diff --git a/src/include/network/connection_handle.h b/src/include/network/connection_handle.h index 700e1100164..dbc3605ee5f 100644 --- a/src/include/network/connection_handle.h +++ b/src/include/network/connection_handle.h @@ -51,6 +51,7 @@ namespace network { */ class ConnectionHandle { public: + /** * Constructs a new ConnectionHandle * @param sock_fd Client's connection fd @@ -58,6 +59,8 @@ class ConnectionHandle { */ ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler); + DISALLOW_COPY_AND_MOVE(ConnectionHandle); + /** * @brief Signal to libevent that this ConnectionHandle is ready to handle * events @@ -179,15 +182,15 @@ class ConnectionHandle { }; friend class StateMachine; - friend class NetworkIoWrapperFactory; + friend class ConnectionHandleFactory; + // A raw pointer is used here because references cannot be rebound. 
ConnectionHandlerTask *conn_handler_; - std::shared_ptr io_wrapper_; - StateMachine state_machine_; - struct event *network_event_ = nullptr, *workpool_event_ = nullptr; - + std::unique_ptr io_wrapper_; // TODO(Tianyu): Probably use a factory for this std::unique_ptr protocol_interpreter_; + StateMachine state_machine_{}; + struct event *network_event_ = nullptr, *workpool_event_ = nullptr; }; } // namespace network } // namespace peloton diff --git a/src/include/network/connection_handle_factory.h b/src/include/network/connection_handle_factory.h new file mode 100644 index 00000000000..a1c5d438872 --- /dev/null +++ b/src/include/network/connection_handle_factory.h @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// connection_handle_factory.h +// +// Identification: src/include/network/connection_handle_factory.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "network/connection_handle.h" +#include "network/peloton_server.h" + +namespace peloton { +namespace network { + +/** + * @brief Factory class for constructing ConnectionHandle objects + * Each ConnectionHandle is associated with read and write buffers that are + * expensive to reallocate on the fly. Thus, instead of destroying these wrapper + * objects when they are out of scope, we save them until we can transfer their + * buffers to other wrappers. + */ +// TODO(Tianyu): Additionally, it is hard to make sure the ConnectionHandles +// don't leak without this factory since they are essentially managed by +// libevent if nothing in our system holds reference to them, and libevent +// doesn't cleanup raw pointers. 
+class ConnectionHandleFactory { + public: + static inline ConnectionHandleFactory &GetInstance() { + static ConnectionHandleFactory factory; + return factory; + } + + /** + * @brief Creates or re-purpose a NetworkIoWrapper object for new use. + * The returned value always uses Posix I/O methods unles explicitly + * converted. + * @see NetworkIoWrapper for details + * @param conn_fd Client connection fd + * @return A new NetworkIoWrapper object + */ + ConnectionHandle &NewConnectionHandle(int conn_fd, ConnectionHandlerTask *task); + + private: + std::unordered_map reusable_handles_; +}; +} // namespace network +} // namespace peloton diff --git a/src/include/network/network_io_wrapper_factory.h b/src/include/network/network_io_wrapper_factory.h deleted file mode 100644 index d5170fe4202..00000000000 --- a/src/include/network/network_io_wrapper_factory.h +++ /dev/null @@ -1,66 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// network_io_wrapper_factory.h -// -// Identification: src/include/network/network_io_wrapper_factory.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "network/network_io_wrappers.h" -#include "network/peloton_server.h" - -namespace peloton { -namespace network { - -/** - * @brief Factory class for constructing NetworkIoWrapper objects - * Each NetworkIoWrapper is associated with read and write buffers that are - * expensive to reallocate on the fly. Thus, instead of destroying these wrapper - * objects when they are out of scope, we save them until we can transfer their - * buffers to other wrappers. - */ -// TODO(Tianyu): Make reuse more fine-grained and adjustable -// Currently there is no limit on the number of wrappers we save. This means -// that we never deallocated wrappers unless we shut down. 
Obviously this will -// be a memory overhead if we had a lot of connections at one point and dropped -// down after a while. Relying on OS fd values for reuse also can backfire. It -// shouldn't be hard to keep a pool of buffers with a size limit instead of a -// bunch of old wrapper objects. -class NetworkIoWrapperFactory { - public: - static inline NetworkIoWrapperFactory &GetInstance() { - static NetworkIoWrapperFactory factory; - return factory; - } - - /** - * @brief Creates or re-purpose a NetworkIoWrapper object for new use. - * The returned value always uses Posix I/O methods unles explicitly - * converted. - * @see NetworkIoWrapper for details - * @param conn_fd Client connection fd - * @return A new NetworkIoWrapper object - */ - std::shared_ptr NewNetworkIoWrapper(int conn_fd); - - /** - * @brief: process SSL handshake to generate valid SSL - * connection context for further communications - * @return FINISH when the SSL handshake failed - * PROCEED when the SSL handshake success - * NEED_DATA when the SSL handshake is partially done due to network - * latency - */ - Transition TryUseSsl(std::shared_ptr &io_wrapper); - - private: - std::unordered_map> reusable_wrappers_; -}; -} // namespace network -} // namespace peloton diff --git a/src/include/network/network_io_wrappers.h b/src/include/network/network_io_wrappers.h index 661002c5979..4fe8107cfa1 100644 --- a/src/include/network/network_io_wrappers.h +++ b/src/include/network/network_io_wrappers.h @@ -36,7 +36,6 @@ namespace network { * class. 
@see NetworkIoWrapperFactory */ class NetworkIoWrapper { - friend class NetworkIoWrapperFactory; public: virtual bool SslAble() const = 0; // TODO(Tianyu): Change and document after we refactor protocol handler @@ -51,8 +50,9 @@ class NetworkIoWrapper { inline bool ShouldFlush() { return out_->ShouldFlush(); } // TODO(Tianyu): Make these protected when protocol handler refactor is // complete - NetworkIoWrapper(int sock_fd, std::shared_ptr &in, - std::shared_ptr &out) + NetworkIoWrapper(int sock_fd, + std::shared_ptr in, + std::shared_ptr out) : sock_fd_(sock_fd), in_(std::move(in)), out_(std::move(out)) { @@ -60,9 +60,12 @@ class NetworkIoWrapper { out_->Reset(); } - DISALLOW_COPY(NetworkIoWrapper) + DISALLOW_COPY(NetworkIoWrapper); - NetworkIoWrapper(NetworkIoWrapper &&other) = default; + NetworkIoWrapper(NetworkIoWrapper &&other) noexcept + : NetworkIoWrapper(other.sock_fd_, + std::move(other.in_), + std::move(other.out_)) {} int sock_fd_; std::shared_ptr in_; @@ -74,9 +77,18 @@ class NetworkIoWrapper { */ class PosixSocketIoWrapper : public NetworkIoWrapper { public: - PosixSocketIoWrapper(int sock_fd, std::shared_ptr in, - std::shared_ptr out); + explicit PosixSocketIoWrapper(int sock_fd, + std::shared_ptr in = + std::make_shared(), + std::shared_ptr out = + std::make_shared()); + explicit PosixSocketIoWrapper(NetworkIoWrapper &&other) + : PosixSocketIoWrapper(other.sock_fd_, + std::move(other.in_), + std::move(other.out_)) {} + + DISALLOW_COPY_AND_MOVE(PosixSocketIoWrapper); inline bool SslAble() const override { return false; } Transition FillReadBuffer() override; @@ -95,7 +107,9 @@ class SslSocketIoWrapper : public NetworkIoWrapper { // Realistically, an SslSocketIoWrapper is always derived from a // PosixSocketIoWrapper, as the handshake process happens over posix sockets. 
SslSocketIoWrapper(NetworkIoWrapper &&other, SSL *ssl) - : NetworkIoWrapper(std::move(other)), conn_ssl_context_(ssl) {} + : NetworkIoWrapper(std::move(other)), conn_ssl_context_(ssl) {} + + DISALLOW_COPY_AND_MOVE(SslSocketIoWrapper); inline bool SslAble() const override { return true; } Transition FillReadBuffer() override; @@ -103,7 +117,7 @@ class SslSocketIoWrapper : public NetworkIoWrapper { Transition Close() override; private: - friend class NetworkIoWrapperFactory; + friend class ConnectionHandle; SSL *conn_ssl_context_; }; } // namespace network diff --git a/src/network/connection_handle.cpp b/src/network/connection_handle.cpp index 5bb9d27d913..7421db2c3af 100644 --- a/src/network/connection_handle.cpp +++ b/src/network/connection_handle.cpp @@ -15,7 +15,7 @@ #include "network/connection_dispatcher_task.h" #include "network/connection_handle.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" #include "network/peloton_server.h" #include "common/utility.h" @@ -167,9 +167,10 @@ void ConnectionHandle::StateMachine::Accept(Transition action, // TODO(Tianyu): Maybe use a factory to initialize protocol_interpreter here ConnectionHandle::ConnectionHandle(int sock_fd, ConnectionHandlerTask *handler) : conn_handler_(handler), - io_wrapper_(NetworkIoWrapperFactory::GetInstance().NewNetworkIoWrapper(sock_fd)), + io_wrapper_{new PosixSocketIoWrapper(sock_fd)}, protocol_interpreter_{new PostgresProtocolInterpreter(conn_handler_->Id())} {} + Transition ConnectionHandle::GetResult() { EventUtil::EventAdd(network_event_, nullptr); protocol_interpreter_->GetResult(io_wrapper_->GetWriteQueue()); @@ -180,8 +181,33 @@ Transition ConnectionHandle::TrySslHandshake() { // TODO(Tianyu): Do we really need to flush here? 
auto ret = io_wrapper_->FlushAllWrites(); if (ret != Transition::PROCEED) return ret; - return NetworkIoWrapperFactory::GetInstance().TryUseSsl( - io_wrapper_); + SSL *context; + if (!io_wrapper_->SslAble()) { + context = SSL_new(PelotonServer::ssl_context); + if (context == nullptr) + throw NetworkProcessException("ssl context for conn failed"); + SSL_set_session_id_context(context, nullptr, 0); + if (SSL_set_fd(context, io_wrapper_->sock_fd_) == 0) + throw NetworkProcessException("Failed to set ssl fd"); + io_wrapper_.reset(new SslSocketIoWrapper(std::move(*io_wrapper_), context)); + } else + context = dynamic_cast(io_wrapper_.get())->conn_ssl_context_; + + // The wrapper already uses SSL methods. + // Yuchen: "Post-connection verification?" + ERR_clear_error(); + int ssl_accept_ret = SSL_accept(context); + if (ssl_accept_ret > 0) return Transition::PROCEED; + + int err = SSL_get_error(context, ssl_accept_ret); + switch (err) { + case SSL_ERROR_WANT_READ: + return Transition::NEED_READ; + case SSL_ERROR_WANT_WRITE: + return Transition::NEED_WRITE; + default: + throw NetworkProcessException("SSL Error, error code" + std::to_string(err)); + } } Transition ConnectionHandle::TryCloseConnection() { @@ -195,10 +221,6 @@ Transition ConnectionHandle::TryCloseConnection() { // connection handle and we will need to destruct and exit. conn_handler_->UnregisterEvent(network_event_); conn_handler_->UnregisterEvent(workpool_event_); - // This object is essentially managed by libevent (which unfortunately does - // not accept shared_ptrs.) and thus as we shut down we need to manually - // deallocate this object. 
- delete this; return Transition::NONE; } } // namespace network diff --git a/src/network/connection_handle_factory.cpp b/src/network/connection_handle_factory.cpp new file mode 100644 index 00000000000..5eafcfdf7e0 --- /dev/null +++ b/src/network/connection_handle_factory.cpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// connection_handle_factory.cpp +// +// Identification: src/network/connection_handle_factory.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include +#include "network/connection_handle_factory.h" + +namespace peloton { +namespace network { +ConnectionHandle &ConnectionHandleFactory::NewConnectionHandle(int conn_fd, ConnectionHandlerTask *task) { + auto it = reusable_handles_.find(conn_fd); + if (it == reusable_handles_.end()) { + auto ret = reusable_handles_.emplace(std::piecewise_construct, + std::forward_as_tuple(conn_fd), + std::forward_as_tuple(conn_fd, task)); + PELOTON_ASSERT(ret.second); + return ret.first->second; + } + + auto &reused_handle= it->second; + reused_handle.conn_handler_ = task; + reused_handle.io_wrapper_.reset(new PosixSocketIoWrapper(std::move( + *reused_handle.io_wrapper_.release()))); + reused_handle.protocol_interpreter_.reset(new PostgresProtocolInterpreter(task->Id())); + reused_handle.state_machine_= ConnectionHandle::StateMachine(); + PELOTON_ASSERT(reused_handle.network_event_ == nullptr); + PELOTON_ASSERT(reused_handle.workpool_event_ == nullptr); + return reused_handle; +} +} // namespace network +} // namespace peloton diff --git a/src/network/connection_handler_task.cpp b/src/network/connection_handler_task.cpp index 7d5a5114c78..f2e01dc66bb 100644 --- a/src/network/connection_handler_task.cpp +++ b/src/network/connection_handler_task.cpp @@ -12,7 +12,7 @@ #include "network/connection_handler_task.h" 
#include "network/connection_handle.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" namespace peloton { namespace network { @@ -52,14 +52,9 @@ void ConnectionHandlerTask::HandleDispatch(int new_conn_recv_fd, short) { } bytes_read += (size_t)result; } - - // Smart pointers are not used here because libevent does not take smart - // pointers. During the life time of this object, the pointer to it will be - // maintained by libevent rather than by our own code. The object will have to - // be cleaned up by one of its methods (i.e. we call a method with "delete - // this" and have the object commit suicide from libevent. ) - (new ConnectionHandle(*reinterpret_cast(client_fd), this)) - ->RegisterToReceiveEvents(); + ConnectionHandleFactory::GetInstance() + .NewConnectionHandle(*reinterpret_cast(client_fd), this) + .RegisterToReceiveEvents(); } } // namespace network diff --git a/src/network/network_io_wrapper_factory.cpp b/src/network/network_io_wrapper_factory.cpp deleted file mode 100644 index 2a1edd0a6f5..00000000000 --- a/src/network/network_io_wrapper_factory.cpp +++ /dev/null @@ -1,78 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// network_io_wrapper_factory.cpp -// -// Identification: src/network/network_io_wrapper_factory.cpp -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - -#include -#include "network/network_io_wrapper_factory.h" - -namespace peloton { -namespace network { -std::shared_ptr NetworkIoWrapperFactory::NewNetworkIoWrapper( - int conn_fd) { - auto it = reusable_wrappers_.find(conn_fd); - if (it == reusable_wrappers_.end()) { - // No reusable wrappers - auto wrapper = std::make_shared( - conn_fd, std::make_shared(), - std::make_shared()); - reusable_wrappers_[conn_fd] = - std::static_pointer_cast(wrapper); - return 
wrapper; - } - - // Construct new wrapper by reusing buffers from the old one. - // The old one will be deallocated as we replace the last reference to it - // in the reusable_wrappers_ map. We still need to explicitly call the - // constructor so the flags are set properly on the new file descriptor. - auto &reused_wrapper = it->second; - reused_wrapper = std::make_shared(conn_fd, - reused_wrapper->in_, - reused_wrapper->out_); - return reused_wrapper; -} - -Transition NetworkIoWrapperFactory::TryUseSsl( - std::shared_ptr &io_wrapper) { - SSL *context; - if (!io_wrapper->SslAble()) { - context = SSL_new(PelotonServer::ssl_context); - if (context == nullptr) - throw NetworkProcessException("ssl context for conn failed"); - SSL_set_session_id_context(context, nullptr, 0); - if (SSL_set_fd(context, io_wrapper->sock_fd_) == 0) - throw NetworkProcessException("Failed to set ssl fd"); - io_wrapper = - std::make_shared(std::move(*io_wrapper), context); - reusable_wrappers_[io_wrapper->sock_fd_] = io_wrapper; - } else { - auto ptr = std::dynamic_pointer_cast( - io_wrapper); - context = ptr->conn_ssl_context_; - } - - // The wrapper already uses SSL methods. - // Yuchen: "Post-connection verification?" 
- ERR_clear_error(); - int ssl_accept_ret = SSL_accept(context); - if (ssl_accept_ret > 0) return Transition::PROCEED; - - int err = SSL_get_error(context, ssl_accept_ret); - switch (err) { - case SSL_ERROR_WANT_READ: - return Transition::NEED_READ; - case SSL_ERROR_WANT_WRITE: - return Transition::NEED_WRITE; - default: - throw NetworkProcessException("SSL Error, error code" + std::to_string(err)); - } -} -} // namespace network -} // namespace peloton diff --git a/src/network/network_io_wrappers.cpp b/src/network/network_io_wrappers.cpp index bc590293bc9..4dcaa76150d 100644 --- a/src/network/network_io_wrappers.cpp +++ b/src/network/network_io_wrappers.cpp @@ -31,7 +31,7 @@ Transition NetworkIoWrapper::FlushAllWrites() { PosixSocketIoWrapper::PosixSocketIoWrapper(int sock_fd, std::shared_ptr in, std::shared_ptr out) - : NetworkIoWrapper(sock_fd, in, out) { + : NetworkIoWrapper(sock_fd, std::move(in), std::move(out)) { // Set Non Blocking auto flags = fcntl(sock_fd_, F_GETFL); diff --git a/test/network/exception_test.cpp b/test/network/exception_test.cpp index 3120f79e063..5bdef00ce87 100644 --- a/test/network/exception_test.cpp +++ b/test/network/exception_test.cpp @@ -16,7 +16,7 @@ #include "common/harness.h" #include "common/logger.h" #include "gtest/gtest.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" #include "network/peloton_server.h" #include "util/string_util.h" diff --git a/test/network/prepare_stmt_test.cpp b/test/network/prepare_stmt_test.cpp index cccd19abd78..0fb0a247d13 100644 --- a/test/network/prepare_stmt_test.cpp +++ b/test/network/prepare_stmt_test.cpp @@ -16,7 +16,7 @@ #include "gtest/gtest.h" #include "network/peloton_server.h" #include "util/string_util.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" namespace peloton { namespace test { diff --git a/test/network/select_all_test.cpp b/test/network/select_all_test.cpp index 
9c50b650425..7d6dfd187df 100644 --- a/test/network/select_all_test.cpp +++ b/test/network/select_all_test.cpp @@ -14,7 +14,7 @@ #include "gtest/gtest.h" #include "common/logger.h" #include "network/peloton_server.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ diff --git a/test/network/simple_query_test.cpp b/test/network/simple_query_test.cpp index 94b1a735bd2..97a2eb374bf 100644 --- a/test/network/simple_query_test.cpp +++ b/test/network/simple_query_test.cpp @@ -16,7 +16,7 @@ #include "network/peloton_server.h" #include "util/string_util.h" #include /* libpqxx is used to instantiate C++ client */ -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" #define NUM_THREADS 1 diff --git a/test/network/ssl_test.cpp b/test/network/ssl_test.cpp index 033b641494a..aeee2e61579 100644 --- a/test/network/ssl_test.cpp +++ b/test/network/ssl_test.cpp @@ -14,7 +14,7 @@ #include "common/harness.h" #include "common/logger.h" #include "gtest/gtest.h" -#include "network/network_io_wrapper_factory.h" +#include "network/connection_handle_factory.h" #include "network/peloton_server.h" #include "peloton_config.h" #include "util/string_util.h"