Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ add_executable(katago
neuralnet/modelversion.cpp
neuralnet/nneval.cpp
neuralnet/desc.cpp
neuralnet/onnxprotoreader.cpp
${NEURALNET_BACKEND_SOURCES}
book/book.cpp
book/bookcssjs.cpp
Expand Down Expand Up @@ -397,6 +398,8 @@ elseif(USE_BACKEND STREQUAL "TENSORRT")
endif()
include_directories(SYSTEM ${CUDAToolkit_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR}) #SYSTEM is for suppressing some compiler warnings in thrust libraries
target_link_libraries(katago CUDA::cudart_static ${TENSORRT_LIBRARY})
find_library(TENSORRT_ONNX_LIBRARY nvonnxparser HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib)
target_link_libraries(katago ${TENSORRT_ONNX_LIBRARY})
elseif(USE_BACKEND STREQUAL "METAL")
target_compile_definitions(katago PRIVATE USE_METAL_BACKEND)
target_link_libraries(katago KataGoSwift)
Expand Down
19 changes: 18 additions & 1 deletion cpp/neuralnet/desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1851,7 +1851,24 @@ void ModelDesc::loadFromFileMaybeGZipped(const string& fileName, ModelDesc& desc
throw StringError("Error loading or parsing model file " + fileName + ": " + e.what());
}
}

void ModelDesc::loadFromONNX(const string& onnxFile, ModelDesc& descBuf) {
descBuf.onnxHeader.load(onnxFile);

descBuf.modelVersion = descBuf.onnxHeader.modelVersion;
descBuf.name = descBuf.onnxHeader.modelName;
descBuf.numInputChannels = descBuf.onnxHeader.num_spatial_inputs;
descBuf.numInputGlobalChannels = descBuf.onnxHeader.num_global_inputs;
descBuf.numInputMetaChannels = 0; // not supported
if(descBuf.numInputChannels != NNModelVersion::getNumSpatialFeatures(descBuf.modelVersion))
throw StringError("ONNX model requires num_spatial_inputs metadata field to match modelVersion");
if(descBuf.numInputGlobalChannels != NNModelVersion::getNumGlobalFeatures(descBuf.modelVersion))
throw StringError("ONNX model requires num_global_inputs metadata field to match modelVersion");

descBuf.numPolicyChannels = 0; // will not be used
descBuf.numValueChannels = 0; // will not be used
descBuf.numScoreValueChannels = 0;
descBuf.numOwnershipChannels = 0;
}

Rules ModelDesc::getSupportedRules(const Rules& desiredRules, bool& supported) const {
Rules rules = desiredRules;
Expand Down
5 changes: 5 additions & 0 deletions cpp/neuralnet/desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "../game/rules.h"
#include "../neuralnet/activations.h"
#include "../neuralnet/onnxprotoreader.h"

struct ConvLayerDesc {
std::string name;
Expand Down Expand Up @@ -357,6 +358,9 @@ struct ModelDesc {

int metaEncoderVersion;

//std::map<std::string, std::string> onnxMetadata; //only non-empty when loading from ONNX
ONNXModelHeader onnxHeader;

ModelPostProcessParams postProcessParams;

TrunkDesc trunk;
Expand All @@ -383,6 +387,7 @@ struct ModelDesc {
//Loads a model from a file that may or may not be gzipped, storing it in descBuf
//If expectedSha256 is nonempty, will also verify sha256 of the loaded data.
static void loadFromFileMaybeGZipped(const std::string& fileName, ModelDesc& descBuf, const std::string& expectedSha256);
static void loadFromONNX(const std::string& onnxFile, ModelDesc& descBuf);

//Return the "nearest" supported ruleset to desiredRules by this model.
//Fills supported with true if desiredRules itself was exactly supported, false if some modifications had to be made.
Expand Down
8 changes: 8 additions & 0 deletions cpp/neuralnet/nneval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ NNEvaluator::NNEvaluator(
gpuIdxs.erase(last,gpuIdxs.end());
loadedModel = NeuralNet::loadModelFile(modelFileName,expectedSha256);
const ModelDesc& desc = NeuralNet::getModelDesc(loadedModel);
if(desc.onnxHeader.isOnnx)
{
desc.onnxHeader.maybeChangeNNLen(*this);
if(nnXLen > NNPos::MAX_BOARD_LEN)
throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN));
if(nnYLen > NNPos::MAX_BOARD_LEN)
throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN));
}
internalModelName = desc.name;
modelVersion = desc.modelVersion;
inputsVersion = NNModelVersion::getInputsVersion(modelVersion);
Expand Down
12 changes: 6 additions & 6 deletions cpp/neuralnet/nneval.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct NNServerBuf {
NNServerBuf(const NNServerBuf& other) = delete;
NNServerBuf& operator=(const NNServerBuf& other) = delete;
};

class ONNXModelHeader;
class NNEvaluator {
public:
NNEvaluator(
Expand Down Expand Up @@ -210,10 +210,10 @@ class NNEvaluator {
private:
const std::string modelName;
const std::string modelFileName;
const int nnXLen;
const int nnYLen;
const bool requireExactNNLen;
const int policySize;
int nnXLen;
int nnYLen;
bool requireExactNNLen;
int policySize;
const bool inputsUseNHWC;
const enabled_t usingFP16Mode;
const enabled_t usingNHWCMode;
Expand Down Expand Up @@ -268,7 +268,7 @@ class NNEvaluator {

//Queued up requests
ThreadSafeQueue<NNResultBuf*> queryQueue;

friend class ONNXModelHeader;
public:
//Helper, for internal use only
void serve(NNServerBuf& buf, Rand& rand, int gpuIdxForThisThread, int serverThreadIdx);
Expand Down
208 changes: 208 additions & 0 deletions cpp/neuralnet/onnxprotoreader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#include "onnxprotoreader.h"
#include "modelversion.h"
#include "nneval.h"
#include "../core/sha2.h"
using namespace std;
// Minimal protobuf parser helper

// Wraps the byte span [p, p + len) for sequential decoding. Non-owning:
// the caller's buffer must outlive this reader.
ProtoReader::ProtoReader(const uint8_t* p, size_t len) : ptr(p), end(p + len) {}

// True while any unconsumed bytes remain.
bool ProtoReader::hasBytes() const { return ptr < end; }

// Reads a base-128 varint (7 data bits per byte, little-endian groups).
// Bits beyond 32 are discarded, but ALL continuation bytes are still
// consumed — the previous version stopped mid-varint on overflow, which
// left the stream desynchronized for every subsequent read.
// On truncated input (EOF before terminator byte) returns the bits read so far.
uint32_t ProtoReader::readVarint() {
  uint32_t result = 0;
  int shift = 0;
  while(ptr < end) {
    uint8_t byte = *ptr++;
    if(shift < 32) // Overflow protection: drop excess bits but keep consuming
      result |= (uint32_t)(byte & 0x7F) << shift;
    if(!(byte & 0x80))
      return result;
    shift += 7;
  }
  return result;
}

// Reads the next field tag, splitting it into field number and wire type.
// Returns true if a tag was read, false once the buffer is exhausted.
bool ProtoReader::readTag(uint32_t& fieldNum, uint32_t& wireType) {
  if(!(ptr < end))
    return false;
  const uint32_t tag = readVarint();
  wireType = tag & 0x7; // low 3 bits
  fieldNum = tag >> 3;  // remaining bits
  return true;
}

// Advances past a field of the given wire type without interpreting it.
// All bounds checks compare against the remaining byte count (end - ptr)
// instead of computing ptr + len, which could overflow the pointer on a
// corrupt length. On a truncated field, ptr is clamped to end (the old
// code left ptr unadvanced for wire types 1 and 5, desyncing the stream).
void ProtoReader::skipField(uint32_t wireType) {
  if(wireType == 0) { // Varint
    readVarint();
  } else if(wireType == 1) { // 64-bit
    if((size_t)(end - ptr) >= 8)
      ptr += 8;
    else
      ptr = end;
  } else if(wireType == 2) { // Length delimited
    uint32_t len = readVarint();
    if((size_t)(end - ptr) >= len)
      ptr += len;
    else
      ptr = end;
  } else if(wireType == 5) { // 32-bit
    if((size_t)(end - ptr) >= 4)
      ptr += 4;
    else
      ptr = end;
  }
  // Groups (3,4) deprecated/not supported here
}

// Reads a length-delimited string field (length varint, then raw bytes).
// Returns "" if the declared length overruns the buffer; in that case ptr
// is clamped to end so the stream cannot desync. The bounds check uses the
// remaining byte count rather than ptr + len, which could overflow the
// pointer when len is near UINT32_MAX.
string ProtoReader::readString() {
  uint32_t len = readVarint();
  if((size_t)(end - ptr) < len) {
    ptr = end; // truncated: consume the remainder
    return "";
  }
  string s((const char*)ptr, len);
  ptr += len;
  return s;
}

// Constructs a header with every field in its default "not ONNX" state.
ONNXModelHeader::ONNXModelHeader() {
  clear();
}

// Resets every field to its default value — the state used when a model is
// loaded from the regular .bin.gz format rather than from ONNX.
void ONNXModelHeader::clear() {
  isOnnx = false;
  modelVersion = 0;
  modelName.clear();
  allmetadata.clear();
  num_spatial_inputs = 0;
  num_global_inputs = 0;
  has_mask = false;
  pos_len_x = 0;
  pos_len_y = 0;
  model_config.clear();
  model_config_sha256.clear();
}
// Parses the ONNX file's protobuf just deeply enough to extract the
// ModelProto.metadata_props key/value pairs, then validates and stores the
// metadata fields KataGo requires.
// Throws StringError on I/O failure or missing/invalid metadata.
void ONNXModelHeader::load(const std::string& onnxFile) {
  assert(Global::isSuffix(onnxFile, ".onnx"));
  clear();
  isOnnx = true;
  // Read entire file into memory
  ifstream in(onnxFile, ios::binary | ios::ate);
  if(!in)
    throw StringError("Could not open ONNX file: " + onnxFile);
  // tellg() returns -1 on failure; the previous unchecked size_t conversion
  // would have turned that into an enormous allocation.
  std::streampos endPos = in.tellg();
  if(endPos < 0)
    throw StringError("Failed to read ONNX file: " + onnxFile);
  size_t fileSize = (size_t)endPos;
  in.seekg(0, ios::beg);

  vector<uint8_t> buffer(fileSize);
  // Guard the read for empty files: read(ptr, 0) still sets failbit on some
  // implementations when at EOF, and buffer.data() may be null.
  if(fileSize > 0 && !in.read((char*)buffer.data(), fileSize))
    throw StringError("Failed to read ONNX file: " + onnxFile);

  ProtoReader reader(buffer.data(), fileSize);
  allmetadata.clear();

  // ONNX ModelProto: field 14 is metadata_props, a repeated
  // StringStringEntryProto with key = field 1 and value = field 2.
  uint32_t fieldNum, wireType;
  while(reader.readTag(fieldNum, wireType)) {
    if(fieldNum == 14 && wireType == 2) { // metadata_props (repeated)
      // Read nested message length. Compare against the remaining byte count
      // rather than computing reader.ptr + msgLen, which could overflow the
      // pointer on a corrupt length.
      uint32_t msgLen = reader.readVarint();
      if((size_t)(reader.end - reader.ptr) < msgLen)
        break;

      // Parse StringStringEntryProto
      ProtoReader entryReader(reader.ptr, msgLen);
      reader.ptr += msgLen; // Advance main reader

      string key, value;
      uint32_t eField, eWire;
      while(entryReader.readTag(eField, eWire)) {
        if(eField == 1 && eWire == 2)
          key = entryReader.readString();
        else if(eField == 2 && eWire == 2)
          value = entryReader.readString();
        else
          entryReader.skipField(eWire);
      }
      if(!key.empty())
        allmetadata[key] = value;
    } else {
      reader.skipField(wireType);
    }
  }

  if(!allmetadata.count("modelVersion"))
    throw StringError("ONNX model requires a modelVersion metadata field");
  else if(!Global::tryStringToInt(allmetadata["modelVersion"], modelVersion))
    throw StringError(
      "ONNX model requires a valid modelVersion metadata field, but got: " + allmetadata["modelVersion"]);

  if(!allmetadata.count("name"))
    throw StringError("ONNX model requires a name metadata field");
  modelName = allmetadata["name"];

  // Input channel counts must agree with what this modelVersion's feature
  // encoding will actually produce at runtime.
  if(!allmetadata.count("num_spatial_inputs"))
    throw StringError("ONNX model requires a num_spatial_inputs metadata field");
  else if(!Global::tryStringToInt(allmetadata["num_spatial_inputs"], num_spatial_inputs))
    throw StringError(
      "ONNX model requires a valid num_spatial_inputs metadata field, but got: " + allmetadata["num_spatial_inputs"]);
  if(num_spatial_inputs != NNModelVersion::getNumSpatialFeatures(modelVersion))
    throw StringError("ONNX model requires num_spatial_inputs metadata field to match modelVersion");

  if(!allmetadata.count("num_global_inputs"))
    throw StringError("ONNX model requires a num_global_inputs metadata field");
  else if(!Global::tryStringToInt(allmetadata["num_global_inputs"], num_global_inputs))
    throw StringError(
      "ONNX model requires a valid num_global_inputs metadata field, but got: " + allmetadata["num_global_inputs"]);
  if(num_global_inputs != NNModelVersion::getNumGlobalFeatures(modelVersion))
    throw StringError("ONNX model requires num_global_inputs metadata field to match modelVersion");

  if(!allmetadata.count("has_mask"))
    throw StringError("ONNX model requires a has_mask metadata field");
  else if(!Global::tryStringToBool(allmetadata["has_mask"], has_mask))
    throw StringError("ONNX model requires a valid has_mask metadata field, but got: " + allmetadata["has_mask"]);

  if(!allmetadata.count("pos_len_x"))
    throw StringError("ONNX model requires a pos_len_x metadata field");
  else if(!Global::tryStringToInt(allmetadata["pos_len_x"], pos_len_x))
    throw StringError("ONNX model requires a valid pos_len_x metadata field, but got: " + allmetadata["pos_len_x"]);
  if(!allmetadata.count("pos_len_y"))
    throw StringError("ONNX model requires a pos_len_y metadata field");
  else if(!Global::tryStringToInt(allmetadata["pos_len_y"], pos_len_y))
    throw StringError("ONNX model requires a valid pos_len_y metadata field, but got: " + allmetadata["pos_len_y"]);

  // model_config is the raw config text the model was exported with; its
  // sha256 identifies the architecture. (This check previously appeared twice
  // in this function; deduplicated here.)
  if(!allmetadata.count("model_config") || allmetadata["model_config"].empty())
    throw StringError("ONNX model requires a model_config metadata field");
  model_config = allmetadata["model_config"];

  {
    char hashResultBuf[65];
    SHA2::get256((const uint8_t*)model_config.data(), model_config.size(), hashResultBuf);
    model_config_sha256 = string(hashResultBuf);
  }
}






// Adjusts nneval's board-size fields (nnXLen, nnYLen, requireExactNNLen,
// policySize) to match the size this ONNX model was exported for. No-op
// for non-ONNX models. Throws StringError when the configured size is
// incompatible with the model's fixed pos_len metadata.
void ONNXModelHeader::maybeChangeNNLen(NNEvaluator& nneval) const {
  if(!isOnnx)
    return; // not onnx, do nothing

  if(has_mask) {
    // Model can mask down to smaller boards, so only require that the
    // configured size fits inside the exported size.
    nneval.requireExactNNLen = false;
    if(pos_len_x < nneval.nnXLen || pos_len_y < nneval.nnYLen)
      throw StringError(
        "ONNX model requires pos_len_x and pos_len_y metadata fields to be at least as large as nnXLen and nnYLen if "
        "has_mask is true");
  } else {
    // Without masking, the model only works at exactly its exported size,
    // and the evaluator must be configured for an exact size match.
    const bool sizeMatches = (nneval.nnXLen == pos_len_x) && (nneval.nnYLen == pos_len_y);
    if(!sizeMatches || !nneval.requireExactNNLen)
      throw StringError(
        "ONNX model requires pos_len_x and pos_len_y metadata fields to match nnXLen and nnYLen if has_mask is false");
  }
  nneval.nnXLen = pos_len_x;
  nneval.nnYLen = pos_len_y;
  nneval.policySize = NNPos::getPolicySize(nneval.nnXLen, nneval.nnYLen);
}

49 changes: 49 additions & 0 deletions cpp/neuralnet/onnxprotoreader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef ONNX_PROTO_READER_H
#define ONNX_PROTO_READER_H

#include <string>
#include <map>
#include "../core/global.h"
#include "../core/logger.h"

// Minimal protobuf parser helper.
// Non-owning cursor over a byte buffer; decodes just enough of the protobuf
// wire format (varints, tags, length-delimited fields) to read ONNX metadata.
struct ProtoReader {
  const uint8_t* ptr;  // current read position
  const uint8_t* end;  // one past the last readable byte

  ProtoReader(const uint8_t* p, size_t len);

  // True while any unconsumed bytes remain.
  bool hasBytes() const;

  // Reads a base-128 varint, truncated to 32 bits.
  uint32_t readVarint();

  // Returns true if tag read, false if EOF
  bool readTag(uint32_t& fieldNum, uint32_t& wireType);

  // Advances past a field of the given wire type without interpreting it.
  void skipField(uint32_t wireType);

  // Reads a length-delimited string field.
  std::string readString();
};
class ModelDesc;
class NNEvaluator;
//static void loadModelDescFromONNX(const std::string& onnxFile, ModelDesc& desc) ;

// Metadata extracted from an ONNX model's metadata_props, used to configure
// the evaluator without parsing the full network description.
struct ONNXModelHeader {
  bool isOnnx;//True if the model is in ONNX format, false if in .bin.gz format and all fields are default values
  // Raw key->value pairs parsed from the ONNX ModelProto metadata_props.
  std::map<std::string, std::string> allmetadata;
  // KataGo model version, from the "modelVersion" metadata field.
  int modelVersion;
  // Model display name, from the "name" metadata field.
  std::string modelName;
  // Expected spatial/global input channel counts; validated against modelVersion.
  int num_spatial_inputs;
  int num_global_inputs;
  // If true, the model accepts a board mask and can run on boards up to
  // pos_len_x by pos_len_y; if false it requires exactly that size.
  bool has_mask;
  // Board lengths the model was exported for.
  int pos_len_x;
  int pos_len_y;
  // Raw config text the model was exported with, and its sha256 hex digest.
  std::string model_config;
  std::string model_config_sha256;
  ONNXModelHeader();
  // Resets all fields to the non-ONNX default state.
  void clear();
  // Parses metadata from the given .onnx file; throws StringError on failure.
  void load(const std::string& onnxFile);
  // Adjusts nneval's board-size fields to match this model; throws on mismatch.
  void maybeChangeNNLen(NNEvaluator& nneval) const;
};

#endif // ONNX_PROTO_READER_H
Loading