diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 254d23233..a08763d89 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -244,6 +244,7 @@ add_executable(katago neuralnet/modelversion.cpp neuralnet/nneval.cpp neuralnet/desc.cpp + neuralnet/onnxprotoreader.cpp ${NEURALNET_BACKEND_SOURCES} book/book.cpp book/bookcssjs.cpp @@ -397,6 +398,8 @@ elseif(USE_BACKEND STREQUAL "TENSORRT") endif() include_directories(SYSTEM ${CUDAToolkit_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR}) #SYSTEM is for suppressing some compiler warnings in thrust libraries target_link_libraries(katago CUDA::cudart_static ${TENSORRT_LIBRARY}) + find_library(TENSORRT_ONNX_LIBRARY nvonnxparser HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib) + target_link_libraries(katago ${TENSORRT_ONNX_LIBRARY}) elseif(USE_BACKEND STREQUAL "METAL") target_compile_definitions(katago PRIVATE USE_METAL_BACKEND) target_link_libraries(katago KataGoSwift) diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 5ae003cec..480075687 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -1851,7 +1851,24 @@ void ModelDesc::loadFromFileMaybeGZipped(const string& fileName, ModelDesc& desc throw StringError("Error loading or parsing model file " + fileName + ": " + e.what()); } } - +void ModelDesc::loadFromONNX(const string& onnxFile, ModelDesc& descBuf) { + descBuf.onnxHeader.load(onnxFile); + + descBuf.modelVersion = descBuf.onnxHeader.modelVersion; + descBuf.name = descBuf.onnxHeader.modelName; + descBuf.numInputChannels = descBuf.onnxHeader.num_spatial_inputs; + descBuf.numInputGlobalChannels = descBuf.onnxHeader.num_global_inputs; + descBuf.numInputMetaChannels = 0; // not supported + if(descBuf.numInputChannels != NNModelVersion::getNumSpatialFeatures(descBuf.modelVersion)) + throw StringError("ONNX model requires num_spatial_inputs metadata field to match modelVersion"); + if(descBuf.numInputGlobalChannels != NNModelVersion::getNumGlobalFeatures(descBuf.modelVersion)) + throw StringError("ONNX model requires num_global_inputs metadata field to match modelVersion"); + + descBuf.numPolicyChannels = 0; // will not be used + descBuf.numValueChannels = 0; // will not be used + descBuf.numScoreValueChannels = 0; + descBuf.numOwnershipChannels = 0; +} Rules ModelDesc::getSupportedRules(const Rules& desiredRules, bool& supported) const { Rules rules = desiredRules; diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index b44d00c59..728601ffa 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -10,6 +10,7 @@ #include "../game/rules.h" #include "../neuralnet/activations.h" +#include "../neuralnet/onnxprotoreader.h" struct ConvLayerDesc { std::string name; @@ -357,6 +358,9 @@ struct ModelDesc { int metaEncoderVersion; + //std::map onnxMetadata; //only non-empty when loading from ONNX + ONNXModelHeader onnxHeader; + ModelPostProcessParams postProcessParams; TrunkDesc trunk; @@ -383,6 +387,7 @@ struct ModelDesc { //Loads a model from a file that may or may not be gzipped, storing it in descBuf //If expectedSha256 is nonempty, will also verify sha256 of the loaded data. static void loadFromFileMaybeGZipped(const std::string& fileName, ModelDesc& descBuf, const std::string& expectedSha256); + static void loadFromONNX(const std::string& onnxFile, ModelDesc& descBuf); //Return the "nearest" supported ruleset to desiredRules by this model. //Fills supported with true if desiredRules itself was exactly supported, false if some modifications had to be made. 
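A minimal usage sketch (not part of the patch) of the new entry point introduced above: loading a `.onnx` file goes through `ModelDesc::loadFromONNX`, which parses only the `metadata_props` header into `onnxHeader` rather than the weights; the weights themselves are handed to TensorRT's ONNX parser later. The caller function and the printed fields below are illustrative only.

```cpp
#include <iostream>
#include "neuralnet/desc.h"

// Hypothetical caller: inspect the ONNX header metadata that this patch
// requires exporters to embed (modelVersion, num_spatial_inputs, ...).
void printOnnxHeaderInfo(const std::string& onnxPath) {
  ModelDesc desc;
  ModelDesc::loadFromONNX(onnxPath, desc);  // throws StringError if metadata is missing or invalid
  const ONNXModelHeader& h = desc.onnxHeader;
  std::cout << "name=" << h.modelName
            << " modelVersion=" << h.modelVersion
            << " spatialInputs=" << h.num_spatial_inputs
            << " globalInputs=" << h.num_global_inputs
            << " posLen=" << h.pos_len_x << "x" << h.pos_len_y
            << " hasMask=" << (h.has_mask ? "true" : "false") << std::endl;
}
```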
diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index 595fe78dc..67c606ee7 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -135,6 +135,14 @@ NNEvaluator::NNEvaluator( gpuIdxs.erase(last,gpuIdxs.end()); loadedModel = NeuralNet::loadModelFile(modelFileName,expectedSha256); const ModelDesc& desc = NeuralNet::getModelDesc(loadedModel); + if(desc.onnxHeader.isOnnx) + { + desc.onnxHeader.maybeChangeNNLen(*this); + if(nnXLen > NNPos::MAX_BOARD_LEN) + throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); + if(nnYLen > NNPos::MAX_BOARD_LEN) + throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); + } internalModelName = desc.name; modelVersion = desc.modelVersion; inputsVersion = NNModelVersion::getInputsVersion(modelVersion); diff --git a/cpp/neuralnet/nneval.h b/cpp/neuralnet/nneval.h index 04ff7506b..99d8f94db 100644 --- a/cpp/neuralnet/nneval.h +++ b/cpp/neuralnet/nneval.h @@ -75,7 +75,7 @@ struct NNServerBuf { NNServerBuf(const NNServerBuf& other) = delete; NNServerBuf& operator=(const NNServerBuf& other) = delete; }; - +class ONNXModelHeader; class NNEvaluator { public: NNEvaluator( @@ -210,10 +210,10 @@ class NNEvaluator { private: const std::string modelName; const std::string modelFileName; - const int nnXLen; - const int nnYLen; - const bool requireExactNNLen; - const int policySize; + int nnXLen; + int nnYLen; + bool requireExactNNLen; + int policySize; const bool inputsUseNHWC; const enabled_t usingFP16Mode; const enabled_t usingNHWCMode; @@ -268,7 +268,7 @@ class NNEvaluator { //Queued up requests ThreadSafeQueue queryQueue; - + friend class ONNXModelHeader; public: //Helper, for internal use only void serve(NNServerBuf& buf, Rand& rand, int gpuIdxForThisThread, int serverThreadIdx); diff --git a/cpp/neuralnet/onnxprotoreader.cpp b/cpp/neuralnet/onnxprotoreader.cpp new file mode 100644 index 000000000..44604e05c --- /dev/null +++ b/cpp/neuralnet/onnxprotoreader.cpp @@ -0,0 +1,208 @@ +#include "onnxprotoreader.h" +#include "modelversion.h" +#include "nneval.h" +#include "../core/sha2.h" +using namespace std; +// Minimal protobuf parser helper + +ProtoReader::ProtoReader(const uint8_t* p, size_t len) : ptr(p), end(p + len) {} + +bool ProtoReader::hasBytes() const { return ptr < end; } + +uint32_t ProtoReader::readVarint() { + uint32_t result = 0; + int shift = 0; + while(ptr < end) { + uint8_t byte = *ptr++; + result |= (uint32_t)(byte & 0x7F) << shift; + if(!(byte & 0x80)) + return result; + shift += 7; + if(shift >= 32) + break; // Overflow protection + } + return result; +} + +// Returns true if tag read, false if EOF +bool ProtoReader::readTag(uint32_t& fieldNum, uint32_t& wireType) { + if(ptr >= end) + return false; + uint32_t tag = readVarint(); + fieldNum = tag >> 3; + wireType = tag & 7; + return true; +} + +void ProtoReader::skipField(uint32_t wireType) { + if(wireType == 0) { // Varint + readVarint(); + } else if(wireType == 1) { // 64-bit + if(ptr + 8 <= end) + ptr += 8; + } else if(wireType == 2) { // Length delimited + uint32_t len = readVarint(); + if(ptr + len <= end) + ptr += len; + else + ptr = end; + } else if(wireType == 5) { // 32-bit + if(ptr + 4 <= end) + ptr += 4; + } + // Groups (3,4) deprecated/not supported here +} + +string ProtoReader::readString() { + uint32_t len = readVarint(); + if(ptr + len > end) + return ""; + string s((const char*)ptr, len); + ptr += len; + return s; +} + +ONNXModelHeader::ONNXModelHeader() { + 
clear();
}

void ONNXModelHeader::clear() {
  isOnnx = false;
  allmetadata.clear();
  modelVersion = 0;
  modelName = "";
  num_spatial_inputs = 0;
  num_global_inputs = 0;
  has_mask = false;
  pos_len_x = 0;
  pos_len_y = 0;
  model_config = "";
  model_config_sha256 = "";
}

void ONNXModelHeader::load(const std::string& onnxFile) {
  assert(Global::isSuffix(onnxFile, ".onnx"));
  clear();
  isOnnx = true;
  // Read entire file into memory
  ifstream in(onnxFile, ios::binary | ios::ate);
  if(!in)
    throw StringError("Could not open ONNX file: " + onnxFile);
  size_t fileSize = in.tellg();
  in.seekg(0, ios::beg);

  vector<uint8_t> buffer(fileSize);
  if(!in.read((char*)buffer.data(), fileSize))
    throw StringError("Failed to read ONNX file: " + onnxFile);

  ProtoReader reader(buffer.data(), fileSize);
  allmetadata.clear();

  uint32_t fieldNum, wireType;
  while(reader.readTag(fieldNum, wireType)) {
    if(fieldNum == 14 && wireType == 2) { // metadata_props (repeated)
      // Read nested message length
      uint32_t msgLen = reader.readVarint();
      const uint8_t* msgEnd = reader.ptr + msgLen;
      if(msgEnd > reader.end)
        break;

      // Parse StringStringEntryProto
      ProtoReader entryReader(reader.ptr, msgLen);
      reader.ptr += msgLen; // Advance main reader

      string key, value;
      uint32_t eField, eWire;
      while(entryReader.readTag(eField, eWire)) {
        if(eField == 1 && eWire == 2)
          key = entryReader.readString();
        else if(eField == 2 && eWire == 2)
          value = entryReader.readString();
        else
          entryReader.skipField(eWire);
      }
      if(!key.empty())
        allmetadata[key] = value;
    } else {
      reader.skipField(wireType);
    }
  }

  if(!allmetadata.count("modelVersion"))
    throw StringError("ONNX model requires a modelVersion metadata field");
  else if(!Global::tryStringToInt(allmetadata["modelVersion"], modelVersion))
    throw StringError(
      "ONNX model requires a valid modelVersion metadata field, but got: " + allmetadata["modelVersion"]);

  if(!allmetadata.count("name"))
    throw StringError("ONNX model requires a name metadata field");
  modelName = allmetadata["name"];

  if(!allmetadata.count("num_spatial_inputs"))
    throw StringError("ONNX model requires a num_spatial_inputs metadata field");
  else if(!Global::tryStringToInt(allmetadata["num_spatial_inputs"], num_spatial_inputs))
    throw StringError(
      "ONNX model requires a valid num_spatial_inputs metadata field, but got: " + allmetadata["num_spatial_inputs"]);
  if(num_spatial_inputs != NNModelVersion::getNumSpatialFeatures(modelVersion))
    throw StringError("ONNX model requires num_spatial_inputs metadata field to match modelVersion");

  if(!allmetadata.count("num_global_inputs"))
    throw StringError("ONNX model requires a num_global_inputs metadata field");
  else if(!Global::tryStringToInt(allmetadata["num_global_inputs"], num_global_inputs))
    throw StringError(
      "ONNX model requires a valid num_global_inputs metadata field, but got: " + allmetadata["num_global_inputs"]);
  if(num_global_inputs != NNModelVersion::getNumGlobalFeatures(modelVersion))
    throw StringError("ONNX model requires num_global_inputs metadata field to match modelVersion");

  if(!allmetadata.count("has_mask"))
    throw StringError("ONNX model requires a has_mask metadata field");
  else if(!Global::tryStringToBool(allmetadata["has_mask"], has_mask))
    throw StringError("ONNX model requires a valid has_mask metadata field, but got: " + allmetadata["has_mask"]);
  if(!allmetadata.count("pos_len_x"))
    throw StringError("ONNX model requires a pos_len_x metadata field");
  else if(!Global::tryStringToInt(allmetadata["pos_len_x"], pos_len_x))
    throw StringError("ONNX model requires a valid pos_len_x metadata field, but got: " + allmetadata["pos_len_x"]);
  if(!allmetadata.count("pos_len_y"))
    throw StringError("ONNX model requires a pos_len_y metadata field");
  else if(!Global::tryStringToInt(allmetadata["pos_len_y"], pos_len_y))
    throw StringError("ONNX model requires a valid pos_len_y metadata field, but got: " + allmetadata["pos_len_y"]);

  if(!allmetadata.count("model_config") || allmetadata["model_config"].empty())
    throw StringError("ONNX model requires a model_config metadata field");
  model_config = allmetadata["model_config"];

  {
    char hashResultBuf[65];
    SHA2::get256((const uint8_t*)model_config.data(), model_config.size(), hashResultBuf);
    string hashResult(hashResultBuf);
    model_config_sha256 = hashResult;
  }
}

void ONNXModelHeader::maybeChangeNNLen(NNEvaluator& nneval) const {
  if(!isOnnx)
    return; // not onnx, do nothing

  if(!has_mask) {
    if(nneval.nnXLen != pos_len_x || nneval.nnYLen != pos_len_y || !nneval.requireExactNNLen)
      throw StringError(
        "ONNX model requires pos_len_x and pos_len_y metadata fields to match nnXLen and nnYLen if has_mask is false");
  } else {
    nneval.requireExactNNLen = false;
    if(nneval.nnXLen > pos_len_x || nneval.nnYLen > pos_len_y)
      throw StringError(
        "ONNX model requires pos_len_x and pos_len_y metadata fields to be at least as large as nnXLen and nnYLen if "
        "has_mask is true");
  }
  nneval.nnXLen = pos_len_x;
  nneval.nnYLen = pos_len_y;
  nneval.policySize = NNPos::getPolicySize(nneval.nnXLen, nneval.nnYLen);
}
diff --git a/cpp/neuralnet/onnxprotoreader.h b/cpp/neuralnet/onnxprotoreader.h
new file mode 100644
index 000000000..d686f73f6
--- /dev/null
+++ b/cpp/neuralnet/onnxprotoreader.h
@@ -0,0 +1,49 @@
#ifndef ONNX_PROTO_READER_H
#define ONNX_PROTO_READER_H

#include <map>
#include <string>
#include "../core/global.h"
#include "../core/logger.h"

// Minimal protobuf parser helper
struct ProtoReader {
  const uint8_t* ptr;
  const uint8_t* end;

  ProtoReader(const uint8_t* p, size_t len);

  bool hasBytes() const;

  uint32_t readVarint();

  // Returns true if tag read, false if EOF
  bool readTag(uint32_t& fieldNum, uint32_t& wireType);

  void skipField(uint32_t wireType);

  std::string readString();
};

struct ModelDesc;
class NNEvaluator;

struct ONNXModelHeader {
  bool isOnnx; // True if the model is in ONNX format, false if in .bin.gz format and all fields are default values
  std::map<std::string, std::string> allmetadata;
  int modelVersion;
  std::string modelName;
  int num_spatial_inputs;
  int num_global_inputs;
  bool has_mask;
  int pos_len_x;
  int pos_len_y;
  std::string model_config;
  std::string model_config_sha256;
  ONNXModelHeader();
  void clear();
  void load(const std::string& onnxFile);
  void maybeChangeNNLen(NNEvaluator& nneval) const;
};

#endif // ONNX_PROTO_READER_H
\ No newline at end of file
diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp
index c6df9f251..224a589ee 100644
--- a/cpp/neuralnet/trtbackend.cpp
+++ b/cpp/neuralnet/trtbackend.cpp
@@ -3,6 +3,8 @@
#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include
#include +#include "NvOnnxConfig.h" +#include "NvOnnxParser.h" #include "../core/fileutils.h" #include "../core/makedir.h" @@ -21,6 +23,10 @@ using namespace nvinfer1; // Define this to print out some of the intermediate values of the neural net //#define DEBUG_INTERMEDIATE_VALUES +//#define CACHE_TENSORRT_PLAN + +const int TensorRT_BuilderOptimizationLevel = 2; //0 for fast init, 2 is default, 5 is max + static void checkCudaError(const cudaError_t status, const char* opName, const char* file, const char* func, int line) { if(status != cudaSuccess) throw StringError( @@ -43,6 +49,43 @@ struct ComputeContext { int nnYLen; enabled_t useFP16Mode; string homeDataDirOverride; + string onnxModelPath; + bool isOnnx; +}; + + + +void NeuralNet::freeComputeContext(ComputeContext* computeContext) { + delete computeContext; +} + + + +struct LoadedModel { + ModelDesc modelDesc; + string fileName; + bool isOnnx; + + LoadedModel(const string& fileName, const string& expectedSha256) { + this->fileName = fileName; + if (Global::isSuffix(fileName, ".onnx")) { + isOnnx = true; + try { + ModelDesc::loadFromONNX(fileName, modelDesc); + + } catch (const StringError& e) { + throw StringError("Failed to load ONNX model config: " + fileName + "\n" + e.what()); + } + } else { + isOnnx = false; + ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256); + modelDesc.applyScale8ToReduceActivations(); + } + } + + LoadedModel() = delete; + LoadedModel(const LoadedModel&) = delete; + LoadedModel& operator=(const LoadedModel&) = delete; }; ComputeContext* NeuralNet::createComputeContext( @@ -60,7 +103,6 @@ ComputeContext* NeuralNet::createComputeContext( (void)logger; (void)openCLTunerFile; (void)openCLReTunePerBoardSize; - (void)loadedModel; if(useNHWCMode == enabled_t::True) { throw StringError("TensorRT backend: useNHWC = false required, other configurations not supported"); @@ -71,26 +113,11 @@ ComputeContext* NeuralNet::createComputeContext( context->nnYLen = nnYLen; context->useFP16Mode = useFP16Mode; context->homeDataDirOverride = homeDataDirOverride; + context->isOnnx = loadedModel->isOnnx; + context->onnxModelPath = loadedModel->fileName; return context; } -void NeuralNet::freeComputeContext(ComputeContext* computeContext) { - delete computeContext; -} - -struct LoadedModel { - ModelDesc modelDesc; - - LoadedModel(const string& fileName, const string& expectedSha256) { - ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256); - modelDesc.applyScale8ToReduceActivations(); - } - - LoadedModel() = delete; - LoadedModel(const LoadedModel&) = delete; - LoadedModel& operator=(const LoadedModel&) = delete; -}; - LoadedModel* NeuralNet::loadModelFile(const string& file, const string& expectedSha256) { LoadedModel* loadedModel = new LoadedModel(file, expectedSha256); return loadedModel; @@ -146,7 +173,7 @@ struct ModelParser { ModelParser& operator=(const ModelParser&) = delete; // Bump this when between katago versions we want to forcibly drop old timing caches and plan caches. 
- static constexpr int tuneSalt = 7; + static constexpr int tuneSalt = 9; unique_ptr build( unique_ptr net, @@ -1140,11 +1167,53 @@ struct ComputeHandle { if(!profile) { throw StringError("TensorRT backend: failed to create optimization profile"); } - auto modelParser = make_unique(); - auto model = modelParser->build( - move(network), profile, loadedModel, ctx->nnXLen, ctx->nnYLen, maxBatchSize, requireExactNNLen); - debugOutputs = model->debugOutputs; - config->addOptimizationProfile(profile); + + unique_ptr model; + if (ctx->isOnnx) { + //check whether the pos_len matches + { + if( + loadedModel->modelDesc.onnxHeader.pos_len_y != ctx->nnYLen || + loadedModel->modelDesc.onnxHeader.pos_len_x != ctx->nnXLen) { + throw StringError( + "TensorRT backend: pos_len in model desc does not match nnYLen or nnXLen, " + "pos_len_y from model = " + + Global::intToString(loadedModel->modelDesc.onnxHeader.pos_len_y) + + "pos_len_x from model = " + Global::intToString(loadedModel->modelDesc.onnxHeader.pos_len_x) + + ", nnYLen=" + Global::intToString(ctx->nnYLen) + ", nnXLen=" + Global::intToString(ctx->nnXLen)); + } + + if((!requireExactNNLen) && (!loadedModel->modelDesc.onnxHeader.has_mask)) { + throw StringError("TensorRT backend: model does not have mask, but requireExactNNLen is false"); + } + + } + + auto parser = nvonnxparser::createParser(*network, trtLogger); + if(!parser) { + throw StringError("TensorRT backend: failed to create ONNX parser"); + } + if(!parser->parseFromFile(ctx->onnxModelPath.c_str(), static_cast(ILogger::Severity::kERROR))) { + throw StringError("TensorRT backend: failed to parse ONNX model"); + } + + int64_t spatialC = NNModelVersion::getNumSpatialFeatures(modelVersion); + int64_t globalC = NNModelVersion::getNumGlobalFeatures(modelVersion); + profile->setDimensions("input_spatial", OptProfileSelector::kMIN, Dims4(1, spatialC, ctx->nnYLen, ctx->nnXLen)); + profile->setDimensions("input_spatial", OptProfileSelector::kOPT, Dims4(maxBatchSize, spatialC, ctx->nnYLen, ctx->nnXLen)); + profile->setDimensions("input_spatial", OptProfileSelector::kMAX, Dims4(maxBatchSize, spatialC, ctx->nnYLen, ctx->nnXLen)); + profile->setDimensions("input_global", OptProfileSelector::kMIN, Dims2(1, globalC)); + profile->setDimensions("input_global", OptProfileSelector::kOPT, Dims2(maxBatchSize, globalC)); + profile->setDimensions("input_global", OptProfileSelector::kMAX, Dims2(maxBatchSize, globalC)); + config->addOptimizationProfile(profile); + } + else { + auto modelParser = make_unique(); + model = modelParser->build( + move(network), profile, loadedModel, ctx->nnXLen, ctx->nnYLen, maxBatchSize, requireExactNNLen); + debugOutputs = model->debugOutputs; + config->addOptimizationProfile(profile); + } #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 // This is to avoid external tactic sources and tactics that have shape switching overhead @@ -1159,16 +1228,16 @@ struct ComputeHandle { if(prop->major >= 8) { // This is to avoid tactics that have shape switching overhead config->setTacticSources(1U << static_cast(TacticSource::kJIT_CONVOLUTIONS)); - config->setBuilderOptimizationLevel(2); } #endif + config->setBuilderOptimizationLevel(TensorRT_BuilderOptimizationLevel); // So that there are no concurrent kernel executions probably from other parts of code while profiling // See CUDA Runtime API document for more details related to NULL stream and synchronization behaviors config->setProfileStream(cudaStreamLegacy); - - // Typical runtime allocation is much less than the 1 GiB specified below - 
config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1U << 30); + + // Typical runtime allocation is much less than the 4 GiB specified below + config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1ULL << 32); string plan; { @@ -1190,28 +1259,70 @@ struct ComputeHandle { deviceIdent[sizeof(deviceIdent) - 1] = 0; #ifdef CACHE_TENSORRT_PLAN - auto planCacheFile = Global::strprintf( - "%s/trt-%d_gpu-%s_net-%s_%d_%s%dx%d_batch%d_fp%d", - cacheDir.c_str(), - getInferLibVersion(), - deviceIdent, - loadedModel->modelDesc.name.c_str(), - ModelParser::tuneSalt, - requireExactNNLen ? "exact" : "max", - ctx->nnYLen, - ctx->nnXLen, - maxBatchSize, - usingFP16 ? 16 : 32); - string paramStr = Global::strprintf( - "_%d_%s_%d_%s_%d_%d_%d_%d", - getInferLibVersion(), - deviceIdent, - ModelParser::tuneSalt, - requireExactNNLen ? "exact" : "max", - ctx->nnYLen, - ctx->nnXLen, - maxBatchSize, - usingFP16 ? 16 : 32); + string modelHashStr; + if (ctx->isOnnx) { + string tmp; + FileUtils::loadFileIntoString(ctx->onnxModelPath, "", tmp, &modelHashStr); + } else { + modelHashStr = loadedModel->modelDesc.sha256; + } + + string planCacheFile = ""; + string paramStr = ""; + + if(ctx->isOnnx) { + + + planCacheFile = Global::strprintf( + "%s/trt-onnx-%d_olv-%d_gpu-%s_net-%s_%d_%s%dx%d_batch%d_fp%d", + cacheDir.c_str(), + getInferLibVersion(), + TensorRT_BuilderOptimizationLevel, + deviceIdent, + modelHashStr.substr(0, 12).c_str(), + ModelParser::tuneSalt, + (!loadedModel->modelDesc.onnxHeader.has_mask) ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); + string paramStr = Global::strprintf( + "_%d_%s_%d_%s_%d_%d_%d_%d", + getInferLibVersion(), + deviceIdent, + ModelParser::tuneSalt, + (!loadedModel->modelDesc.onnxHeader.has_mask) ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); + } + else { + planCacheFile = Global::strprintf( + "%s/trt-%d_olv-%d_gpu-%s_net-%s_%d_%s%dx%d_batch%d_fp%d", + cacheDir.c_str(), + getInferLibVersion(), + TensorRT_BuilderOptimizationLevel, + deviceIdent, + loadedModel->modelDesc.name.c_str(), + ModelParser::tuneSalt, + requireExactNNLen ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); + string paramStr = Global::strprintf( + "_%d_%s_%d_%s_%d_%d_%d_%d", + getInferLibVersion(), + deviceIdent, + ModelParser::tuneSalt, + requireExactNNLen ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); + } + try { plan = FileUtils::readFileBinary(planCacheFile); } catch(const StringError& e) { @@ -1225,7 +1336,7 @@ struct ComputeHandle { } else { string cachedParamStr = plan.substr(plan.size() - paramStr.size()); string modelHash = plan.substr(plan.size() - 64 - paramStr.size(), 64); - if(modelHash != loadedModel->modelDesc.sha256) { + if(modelHash != modelHashStr) { logger->write("Plan cache is corrupted or is for the wrong model in " + planCacheFile); plan.clear(); } else if(cachedParamStr != paramStr) { @@ -1239,7 +1350,9 @@ struct ComputeHandle { if(plan.size() <= 0) { logger->write("Creating new plan cache"); - auto planBuffer = unique_ptr(builder->buildSerializedNetwork(*model->network, *config)); + // network is moved into model if !isOnnx, but for isOnnx network is still valid (not moved) + INetworkDefinition* netPtr = ctx->isOnnx ? 
network.get() : model->network.get(); + auto planBuffer = unique_ptr(builder->buildSerializedNetwork(*netPtr, *config)); if(!planBuffer) { throw StringError("TensorRT backend: failed to create plan"); } @@ -1247,10 +1360,10 @@ struct ComputeHandle { plan.end(), static_cast(planBuffer->data()), static_cast(planBuffer->data()) + planBuffer->size()); - if(loadedModel->modelDesc.sha256.size() != 64) { + if(modelHashStr.size() != 64) { throw StringError("Unexpected model hash size"); } - plan.insert(plan.end(), loadedModel->modelDesc.sha256.begin(), loadedModel->modelDesc.sha256.end()); + plan.insert(plan.end(), modelHashStr.begin(), modelHashStr.end()); plan.insert(plan.end(), paramStr.begin(), paramStr.end()); ofstream ofs; FileUtils::open(ofs, planCacheFile, ios::out | ios::binary); @@ -1264,24 +1377,44 @@ struct ComputeHandle { logger->write("Using existing plan cache at " + planCacheFile); } #else - // Truncated to 6 bytes - char tuneIdent[6 * 2 + 1]; - for(int i = 0; i < 6; i++) { - sprintf(tuneIdent + i * 2, "%02x", static_cast(model->tuneHash[i])); + string timingCacheFile = ""; + + if (ctx->isOnnx) { + + timingCacheFile = Global::strprintf( + "%s/trt-onnx-%d_gpu-%s_mc-%s_ts-%d_%s%dx%d_batch%d_fp%d", + cacheDir.c_str(), + getInferLibVersion(), + deviceIdent, + loadedModel->modelDesc.onnxHeader.model_config_sha256.substr(0, 12).c_str(), + ModelParser::tuneSalt, + (!loadedModel->modelDesc.onnxHeader.has_mask) ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); + + } else { + + // Truncated to 6 bytes + char tuneIdent[6 * 2 + 1]; + for(int i = 0; i < 6; i++) { + sprintf(tuneIdent + i * 2, "%02x", static_cast(model->tuneHash[i])); + } + tuneIdent[sizeof(tuneIdent) - 1] = 0; + timingCacheFile = Global::strprintf( + "%s/trt-%d_gpu-%s_tune-%s_%s%dx%d_batch%d_fp%d", + cacheDir.c_str(), + getInferLibVersion(), + deviceIdent, + tuneIdent, + requireExactNNLen ? "exact" : "max", + ctx->nnYLen, + ctx->nnXLen, + maxBatchSize, + usingFP16 ? 16 : 32); } - tuneIdent[sizeof(tuneIdent) - 1] = 0; - - auto timingCacheFile = Global::strprintf( - "%s/trt-%d_gpu-%s_tune-%s_%s%dx%d_batch%d_fp%d", - cacheDir.c_str(), - getInferLibVersion(), - deviceIdent, - tuneIdent, - requireExactNNLen ? "exact" : "max", - ctx->nnYLen, - ctx->nnXLen, - maxBatchSize, - usingFP16 ? 16 : 32); + string timingCacheBlob; try { @@ -1305,7 +1438,8 @@ struct ComputeHandle { unique_ptr planBuffer; if(invalidTimingCache || !timingCacheBlob.size()) { - planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config)); + INetworkDefinition* netPtr = ctx->isOnnx ? network.get() : model->network.get(); + planBuffer.reset(builder->buildSerializedNetwork(*netPtr, *config)); if(!planBuffer) { throw StringError("TensorRT backend: failed to create plan"); } @@ -1318,7 +1452,8 @@ struct ComputeHandle { tuneMutex.unlock(); } else { tuneMutex.unlock(); - planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config)); + INetworkDefinition* netPtr = ctx->isOnnx ? 
network.get() : model->network.get(); + planBuffer.reset(builder->buildSerializedNetwork(*netPtr, *config)); if(!planBuffer) { throw StringError("TensorRT backend: failed to create plan"); } @@ -1516,6 +1651,7 @@ void NeuralNet::printDevices() { struct InputBuffers { int maxBatchSize; + bool isOnnx; size_t singleMaskElts; size_t singleMaskBytes; @@ -1536,6 +1672,18 @@ struct InputBuffers { size_t singleOwnershipResultElts; size_t singleOwnershipResultBytes; + // ONNX specific + size_t singleout_policyElts; + size_t singleout_policyBytes; + size_t singleout_valueElts; + size_t singleout_valueBytes; + size_t singleout_miscvalueElts; + size_t singleout_miscvalueBytes; + size_t singleout_moremiscvalueElts; + size_t singleout_moremiscvalueBytes; + size_t singleout_ownershipElts; + size_t singleout_ownershipBytes; + size_t inputMaskBufferBytes; size_t inputSpatialBufferBytes; size_t inputGlobalBufferBytes; @@ -1546,6 +1694,12 @@ struct InputBuffers { size_t scoreValueResultBufferBytes; size_t ownershipResultBufferBytes; + size_t out_policyBufferBytes; + size_t out_valueBufferBytes; + size_t out_miscvalueBufferBytes; + size_t out_moremiscvalueBufferBytes; + size_t out_ownershipBufferBytes; + unique_ptr maskInputs; // Host pointer unique_ptr spatialInputs; // Host pointer unique_ptr globalInputs; // Host pointer @@ -1556,8 +1710,15 @@ struct InputBuffers { unique_ptr scoreValueResults; // Host pointer unique_ptr ownershipResults; // Host pointer + unique_ptr out_policyResults; + unique_ptr out_valueResults; + unique_ptr out_miscvalueResults; + unique_ptr out_moremiscvalueResults; + unique_ptr out_ownershipResults; + InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen) { const ModelDesc& m = loadedModel->modelDesc; + isOnnx = loadedModel->isOnnx; if(nnXLen > NNPos::MAX_BOARD_LEN) throw StringError( @@ -1586,6 +1747,38 @@ struct InputBuffers { singleOwnershipResultElts = m.numOwnershipChannels * nnXLen * nnYLen; singleOwnershipResultBytes = singleOwnershipResultElts * sizeof(float); + if (isOnnx) { + int policyNum = (m.modelVersion >= 12 && m.modelVersion <= 99) ? 
6 : 4; + if(m.modelVersion != 11 && m.modelVersion != 12 && m.modelVersion != 13 && m.modelVersion != 14 && m.modelVersion != 15 + && m.modelVersion != 102) + { + std::cout << "version: " << m.modelVersion << " is not supported in ONNX" << std::endl; + assert(false); + } + singleout_policyElts = 1 * policyNum * (nnXLen * nnYLen + 1); + singleout_policyBytes = singleout_policyElts * sizeof(float); + singleout_valueElts = 3; + singleout_valueBytes = singleout_valueElts * sizeof(float); + singleout_miscvalueElts = 10; + singleout_miscvalueBytes = singleout_miscvalueElts * sizeof(float); + singleout_moremiscvalueElts = 8; + singleout_moremiscvalueBytes = singleout_moremiscvalueElts * sizeof(float); + singleout_ownershipElts = 1 * nnXLen * nnYLen; + singleout_ownershipBytes = singleout_ownershipElts * sizeof(float); + + out_policyBufferBytes = maxBatchSize * singleout_policyBytes; + out_valueBufferBytes = maxBatchSize * singleout_valueBytes; + out_miscvalueBufferBytes = maxBatchSize * singleout_miscvalueBytes; + out_moremiscvalueBufferBytes = maxBatchSize * singleout_moremiscvalueBytes; + out_ownershipBufferBytes = maxBatchSize * singleout_ownershipBytes; + + out_policyResults = std::make_unique(maxBatchSize * singleout_policyElts); + out_valueResults = std::make_unique(maxBatchSize * singleout_valueElts); + out_miscvalueResults = std::make_unique(maxBatchSize * singleout_miscvalueElts); + out_moremiscvalueResults = std::make_unique(maxBatchSize * singleout_moremiscvalueElts); + out_ownershipResults = std::make_unique(maxBatchSize * singleout_ownershipElts); + } + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); if(m.numInputMetaChannels > 0) { @@ -1639,6 +1832,7 @@ void NeuralNet::getOutput( const int nnXLen = gpuHandle->ctx->nnXLen; const int nnYLen = gpuHandle->ctx->nnYLen; const int modelVersion = gpuHandle->modelVersion; + bool isOnnx = gpuHandle->ctx->isOnnx; const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); @@ -1671,113 +1865,62 @@ void NeuralNet::getOutput( copy(rowSpatialInput, rowSpatialInput + inputBuffers->singleMaskElts, rowMaskInput); } - assert(inputBuffers->singleMaskElts == gpuHandle->getBufferRowElts("InputMask")); - assert(inputBuffers->singleInputElts == gpuHandle->getBufferRowElts("InputSpatial")); - assert(inputBuffers->singleInputGlobalElts == gpuHandle->getBufferRowElts("InputGlobal")); - if(numMetaFeatures > 0) - assert(inputBuffers->singleInputMetaElts == gpuHandle->getBufferRowElts("InputMeta")); - assert(inputBuffers->singlePolicyPassResultElts == gpuHandle->getBufferRowElts("OutputPolicyPass")); - assert(inputBuffers->singlePolicyResultElts == gpuHandle->getBufferRowElts("OutputPolicy")); - assert(inputBuffers->singleValueResultElts == gpuHandle->getBufferRowElts("OutputValue")); - assert(inputBuffers->singleScoreValueResultElts == gpuHandle->getBufferRowElts("OutputScoreValue")); - assert(inputBuffers->singleOwnershipResultElts == gpuHandle->getBufferRowElts("OutputOwnership")); - - assert(inputBuffers->inputMaskBufferBytes == gpuHandle->getBufferBytes("InputMask")); - assert(inputBuffers->inputSpatialBufferBytes == gpuHandle->getBufferBytes("InputSpatial")); - assert(inputBuffers->inputGlobalBufferBytes == gpuHandle->getBufferBytes("InputGlobal")); - if(numMetaFeatures > 0) - assert(inputBuffers->inputMetaBufferBytes == 
gpuHandle->getBufferBytes("InputMeta")); - assert(inputBuffers->policyPassResultBufferBytes == gpuHandle->getBufferBytes("OutputPolicyPass")); - assert(inputBuffers->policyResultBufferBytes == gpuHandle->getBufferBytes("OutputPolicy")); - assert(inputBuffers->valueResultBufferBytes == gpuHandle->getBufferBytes("OutputValue")); - assert(inputBuffers->scoreValueResultBufferBytes == gpuHandle->getBufferBytes("OutputScoreValue")); - assert(inputBuffers->ownershipResultBufferBytes == gpuHandle->getBufferBytes("OutputOwnership")); - - const int numPolicyChannels = inputBuffers->singlePolicyPassResultElts; - assert(inputBuffers->singlePolicyResultElts == numPolicyChannels * nnXLen * nnYLen); - - // Transfers from host memory to device memory are asynchronous with respect to the host - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputMask"), - inputBuffers->maskInputs.get(), - inputBuffers->singleMaskBytes * batchSize, - cudaMemcpyHostToDevice)); - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputSpatial"), - inputBuffers->spatialInputs.get(), - inputBuffers->singleInputBytes * batchSize, - cudaMemcpyHostToDevice)); - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputGlobal"), - inputBuffers->globalInputs.get(), - inputBuffers->singleInputGlobalBytes * batchSize, - cudaMemcpyHostToDevice)); - if(numMetaFeatures > 0) { - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputMeta"), - inputBuffers->metaInputs.get(), - inputBuffers->singleInputMetaBytes * batchSize, - cudaMemcpyHostToDevice)); - } + // Set inputs + if (isOnnx) { + assert(inputBuffers->singleInputElts == gpuHandle->getBufferRowElts("input_spatial")); + assert(inputBuffers->singleInputGlobalElts == gpuHandle->getBufferRowElts("input_global")); + + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("input_spatial"), inputBuffers->spatialInputs.get(), inputBuffers->singleInputBytes * batchSize, cudaMemcpyHostToDevice)); + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("input_global"), inputBuffers->globalInputs.get(), inputBuffers->singleInputGlobalBytes * batchSize, cudaMemcpyHostToDevice)); + + auto spatialInputDims = gpuHandle->getBufferDynamicShape("input_spatial", batchSize); + auto globalInputDims = gpuHandle->getBufferDynamicShape("input_global", batchSize); + gpuHandle->exec->setInputShape("input_spatial", spatialInputDims); + gpuHandle->exec->setInputShape("input_global", globalInputDims); + } else { + assert(inputBuffers->singleMaskElts == gpuHandle->getBufferRowElts("InputMask")); + assert(inputBuffers->singleInputElts == gpuHandle->getBufferRowElts("InputSpatial")); + assert(inputBuffers->singleInputGlobalElts == gpuHandle->getBufferRowElts("InputGlobal")); + if(numMetaFeatures > 0) + assert(inputBuffers->singleInputMetaElts == gpuHandle->getBufferRowElts("InputMeta")); + + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("InputMask"), inputBuffers->maskInputs.get(), inputBuffers->singleMaskBytes * batchSize, cudaMemcpyHostToDevice)); + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("InputSpatial"), inputBuffers->spatialInputs.get(), inputBuffers->singleInputBytes * batchSize, cudaMemcpyHostToDevice)); + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("InputGlobal"), inputBuffers->globalInputs.get(), inputBuffers->singleInputGlobalBytes * batchSize, cudaMemcpyHostToDevice)); + if(numMetaFeatures > 0) { + CUDA_ERR("getOutput", cudaMemcpyAsync(gpuHandle->getBuffer("InputMeta"), 
inputBuffers->metaInputs.get(), inputBuffers->singleInputMetaBytes * batchSize, cudaMemcpyHostToDevice)); + } - auto maskInputDims = gpuHandle->getBufferDynamicShape("InputMask", batchSize); - auto spatialInputDims = gpuHandle->getBufferDynamicShape("InputSpatial", batchSize); - auto globalInputDims = gpuHandle->getBufferDynamicShape("InputGlobal", batchSize); + auto maskInputDims = gpuHandle->getBufferDynamicShape("InputMask", batchSize); + auto spatialInputDims = gpuHandle->getBufferDynamicShape("InputSpatial", batchSize); + auto globalInputDims = gpuHandle->getBufferDynamicShape("InputGlobal", batchSize); - gpuHandle->exec->setInputShape("InputMask", maskInputDims); - gpuHandle->exec->setInputShape("InputSpatial", spatialInputDims); - gpuHandle->exec->setInputShape("InputGlobal", globalInputDims); + gpuHandle->exec->setInputShape("InputMask", maskInputDims); + gpuHandle->exec->setInputShape("InputSpatial", spatialInputDims); + gpuHandle->exec->setInputShape("InputGlobal", globalInputDims); - if(numMetaFeatures > 0) { - auto metaInputDims = gpuHandle->getBufferDynamicShape("InputMeta", batchSize); - gpuHandle->exec->setInputShape("InputMeta", metaInputDims); + if(numMetaFeatures > 0) { + auto metaInputDims = gpuHandle->getBufferDynamicShape("InputMeta", batchSize); + gpuHandle->exec->setInputShape("InputMeta", metaInputDims); + } } gpuHandle->exec->enqueueV3(cudaStreamPerThread); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->policyPassResults.get(), - gpuHandle->getBuffer("OutputPolicyPass"), - inputBuffers->singlePolicyPassResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->policyResults.get(), - gpuHandle->getBuffer("OutputPolicy"), - inputBuffers->singlePolicyResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->valueResults.get(), - gpuHandle->getBuffer("OutputValue"), - inputBuffers->singleValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->scoreValueResults.get(), - gpuHandle->getBuffer("OutputScoreValue"), - inputBuffers->singleScoreValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->ownershipResults.get(), - gpuHandle->getBuffer("OutputOwnership"), - inputBuffers->singleOwnershipResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + // Get outputs + if (isOnnx) { + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->out_policyResults.get(), gpuHandle->getBuffer("out_policy"), inputBuffers->singleout_policyBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->out_valueResults.get(), gpuHandle->getBuffer("out_value"), inputBuffers->singleout_valueBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->out_miscvalueResults.get(), gpuHandle->getBuffer("out_miscvalue"), inputBuffers->singleout_miscvalueBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->out_moremiscvalueResults.get(), gpuHandle->getBuffer("out_moremiscvalue"), inputBuffers->singleout_moremiscvalueBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->out_ownershipResults.get(), gpuHandle->getBuffer("out_ownership"), inputBuffers->singleout_ownershipBytes * batchSize, cudaMemcpyDeviceToHost)); + } else { + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->policyPassResults.get(), gpuHandle->getBuffer("OutputPolicyPass"), 
inputBuffers->singlePolicyPassResultBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->policyResults.get(), gpuHandle->getBuffer("OutputPolicy"), inputBuffers->singlePolicyResultBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->valueResults.get(), gpuHandle->getBuffer("OutputValue"), inputBuffers->singleValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->scoreValueResults.get(), gpuHandle->getBuffer("OutputScoreValue"), inputBuffers->singleScoreValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); + CUDA_ERR("getOutput", cudaMemcpy(inputBuffers->ownershipResults.get(), gpuHandle->getBuffer("OutputOwnership"), inputBuffers->singleOwnershipResultBytes * batchSize, cudaMemcpyDeviceToHost)); + } gpuHandle->printDebugOutput(batchSize); gpuHandle->trtErrorRecorder.clear(); @@ -1785,6 +1928,7 @@ void NeuralNet::getOutput( assert(outputs.size() == batchSize); float policyProbsTmp[NNPos::MAX_NN_POLICY_SIZE]; + const int numPolicyChannels = inputBuffers->singlePolicyPassResultElts; for(int row = 0; row < batchSize; row++) { NNOutput* output = outputs[row]; @@ -1793,82 +1937,150 @@ void NeuralNet::getOutput( assert(output->nnYLen == nnYLen); float policyOptimism = (float)inputBufs[row]->policyOptimism; - const float* policyPassSrcBuf = &inputBuffers->policyPassResults[row * inputBuffers->singlePolicyPassResultElts]; - const float* policySrcBuf = &inputBuffers->policyResults[row * inputBuffers->singlePolicyResultElts]; - float* policyProbs = output->policyProbs; + if(isOnnx) { + const float* policySrcBuf = &inputBuffers->out_policyResults[row * inputBuffers->singleout_policyElts]; + float* policyProbs = output->policyProbs; + + int numPolicyChannels = (modelVersion >= 12 && modelVersion <= 99) ? 2 : 1; + + if(numPolicyChannels == 2) { + assert(inputBuffers->singleout_policyElts == 6 * (nnXLen * nnYLen + 1)); + // TRT is all NCHW + for(int i = 0; i < nnXLen * nnYLen; i++) { + float p = policySrcBuf[i]; + // float pOpt = policySrcBuf[i + nnXLen * nnYLen]; + float pOpt = policySrcBuf[i + 5 * (nnXLen * nnYLen + 1)]; + policyProbsTmp[i] = p + (pOpt - p) * policyOptimism; + } + SymmetryHelpers::copyOutputsWithSymmetry( + policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + policyProbs[nnXLen * nnYLen] = + policySrcBuf[nnXLen * nnYLen] + + (policySrcBuf[5 * (nnXLen * nnYLen + 1) + nnXLen * nnYLen] - policySrcBuf[nnXLen * nnYLen]) * policyOptimism; + } else { + // Fallback or default for 1 channel (or just take first channel if > 1 and not handled) + // trtbackend1.cpp logic for 6/4 channels likely just takes the first one? + // "singleout_policyElts = 1 * policyNum * ..." + // If policyNum is 6, what do we do? + // In trtbackend1.cpp getOutput, it asserts numPolicyChannels == 1. + // If policyNum was 6, assert would fail if numPolicyChannels was 6. + // But numPolicyChannels variable in trtbackend1.cpp seemed to come from somewhere else or was 1. + // If we assume channel 0 is the main policy. 
+ assert(numPolicyChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry( + policySrcBuf, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + policyProbs[nnXLen * nnYLen] = policySrcBuf[nnXLen * nnYLen]; + } + + int numValueChannels = inputBuffers->singleout_valueElts; + assert(numValueChannels == 3); + output->whiteWinProb = inputBuffers->out_valueResults[row * numValueChannels]; + output->whiteLossProb = inputBuffers->out_valueResults[row * numValueChannels + 1]; + output->whiteNoResultProb = inputBuffers->out_valueResults[row * numValueChannels + 2]; + + // As above, these are NOT actually from white's perspective, but rather the player to move. + // As usual the client does the postprocessing. + if(output->whiteOwnerMap != NULL) { + // const float* ownershipSrcBuf = &inputBuffers->ownershipResults[row * nnXLen * nnYLen]; + const float* ownershipSrcBuf = &inputBuffers->out_ownershipResults[row * nnXLen * nnYLen]; + assert(inputBuffers->singleout_ownershipElts == nnXLen * nnYLen); + SymmetryHelpers::copyOutputsWithSymmetry( + ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + } + + + int numScoreValueChannels = inputBuffers->singleout_miscvalueElts; + int numMoreValueChannels = inputBuffers->singleout_moremiscvalueElts; + if(modelVersion >= 9) { + output->whiteScoreMean = inputBuffers->out_miscvalueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->out_miscvalueResults[row * numScoreValueChannels + 1]; + output->whiteLead = inputBuffers->out_miscvalueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = inputBuffers->out_miscvalueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = inputBuffers->out_moremiscvalueResults[row * numMoreValueChannels]; + output->shorttermScoreError = inputBuffers->out_moremiscvalueResults[row * numMoreValueChannels + 1]; + } + else + { + std::cout << "version: " << modelVersion << " is not supported in ONNX" << std::endl; + assert(false); + } + + } else { + const float* policyPassSrcBuf = &inputBuffers->policyPassResults[row * inputBuffers->singlePolicyPassResultElts]; + const float* policySrcBuf = &inputBuffers->policyResults[row * inputBuffers->singlePolicyResultElts]; + float* policyProbs = output->policyProbs; // These are in logits, the client does the postprocessing to turn them into // policy probabilities and white game outcome probabilities // Also we don't fill in the nnHash here either // Handle version >= 12 policy optimism - if(numPolicyChannels == 2 || (numPolicyChannels == 4 && modelVersion >= 16)) { + if(numPolicyChannels == 2 || (numPolicyChannels == 4 && modelVersion >= 16)) { // TRT is all NCHW - for(int i = 0; i < nnXLen * nnYLen; i++) { - float p = policySrcBuf[i]; - float pOpt = policySrcBuf[i + nnXLen * nnYLen]; - policyProbsTmp[i] = p + (pOpt - p) * policyOptimism; - } - SymmetryHelpers::copyOutputsWithSymmetry( - policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + for(int i = 0; i < nnXLen * nnYLen; i++) { + float p = policySrcBuf[i]; + float pOpt = policySrcBuf[i + nnXLen * nnYLen]; + policyProbsTmp[i] = p + (pOpt - p) * policyOptimism; + } + SymmetryHelpers::copyOutputsWithSymmetry( + policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0] + (policyPassSrcBuf[1] - policyPassSrcBuf[0]) * policyOptimism; - } else { - assert(numPolicyChannels == 1); + } else { + assert(numPolicyChannels == 1); 
SymmetryHelpers::copyOutputsWithSymmetry(policySrcBuf, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); - policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0]; - } + policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0]; + } - int numValueChannels = inputBuffers->singleValueResultElts; - assert(numValueChannels == 3); - output->whiteWinProb = inputBuffers->valueResults[row * numValueChannels]; - output->whiteLossProb = inputBuffers->valueResults[row * numValueChannels + 1]; - output->whiteNoResultProb = inputBuffers->valueResults[row * numValueChannels + 2]; + int numValueChannels = inputBuffers->singleValueResultElts; + assert(numValueChannels == 3); + output->whiteWinProb = inputBuffers->valueResults[row * numValueChannels]; + output->whiteLossProb = inputBuffers->valueResults[row * numValueChannels + 1]; + output->whiteNoResultProb = inputBuffers->valueResults[row * numValueChannels + 2]; // As above, these are NOT actually from white's perspective, but rather the player to move. // As usual the client does the postprocessing. - if(output->whiteOwnerMap != NULL) { - const float* ownershipSrcBuf = &inputBuffers->ownershipResults[row * nnXLen * nnYLen]; - assert(inputBuffers->singleOwnershipResultElts == nnXLen * nnYLen); - SymmetryHelpers::copyOutputsWithSymmetry( - ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); - } + if(output->whiteOwnerMap != NULL) { + const float* ownershipSrcBuf = &inputBuffers->ownershipResults[row * nnXLen * nnYLen]; + assert(inputBuffers->singleOwnershipResultElts == nnXLen * nnYLen); + SymmetryHelpers::copyOutputsWithSymmetry( + ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + } - int numScoreValueChannels = inputBuffers->singleScoreValueResultElts; - if(modelVersion >= 9) { - assert(numScoreValueChannels == 6); - output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; - output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; - output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; - output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; - output->shorttermWinlossError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 4]; - output->shorttermScoreError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 5]; - } else if(modelVersion >= 8) { - assert(numScoreValueChannels == 4); - output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; - output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; - output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; - output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; - output->shorttermWinlossError = 0; - output->shorttermScoreError = 0; - } else if(modelVersion >= 4) { - assert(numScoreValueChannels == 2); - output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; - output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; - output->whiteLead = output->whiteScoreMean; - output->varTimeLeft = 0; - output->shorttermWinlossError = 0; - output->shorttermScoreError = 0; - } else if(modelVersion >= 3) { - assert(numScoreValueChannels == 1); - output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; - // Version 3 neural nets don't have any second moment output, implicitly 
already folding it in, so we just use the - // mean squared - output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean; - output->whiteLead = output->whiteScoreMean; - output->varTimeLeft = 0; - output->shorttermWinlossError = 0; - output->shorttermScoreError = 0; - } else { - ASSERT_UNREACHABLE; + int numScoreValueChannels = inputBuffers->singleScoreValueResultElts; + if(modelVersion >= 9) { + assert(numScoreValueChannels == 6); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 4]; + output->shorttermScoreError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 5]; + } else if(modelVersion >= 8) { + assert(numScoreValueChannels == 4); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } else if(modelVersion >= 4) { + assert(numScoreValueChannels == 2); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } else if(modelVersion >= 3) { + assert(numScoreValueChannels == 1); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } else { + ASSERT_UNREACHABLE; + } } } }
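For reference, the optimism handling in the ONNX branch above reduces to a linear interpolation between two policy channels. Below is a self-contained sketch of that blend, under the layout assumed by the code in getOutput (6 channels of nnXLen*nnYLen+1 logits per row, with the plain policy in channel 0 and the optimistic policy in channel 5); the function name and return type are illustrative, not part of the patch.

```cpp
#include <cstddef>
#include <vector>

// Sketch of the optimism blend applied to the ONNX "out_policy" output.
// Assumed layout: 6 channels, each of length nnXLen*nnYLen+1 (board cells plus
// the pass move); channel 0 holds the plain policy logits, channel 5 the
// optimistic ones.
std::vector<float> blendPolicyOptimism(const float* policySrc, int nnXLen, int nnYLen, float policyOptimism) {
  const size_t chanLen = (size_t)nnXLen * nnYLen + 1;
  const float* plain = policySrc;                      // channel 0
  const float* optimistic = policySrc + 5 * chanLen;   // channel 5
  std::vector<float> blended(chanLen);
  for(size_t i = 0; i < chanLen; i++) {
    float p = plain[i];
    float pOpt = optimistic[i];
    blended[i] = p + (pOpt - p) * policyOptimism;      // interpolate in logit space
  }
  return blended;
}
```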