Skip to content

Commit 8a22718

Browse files
committed
update mnnsr: detect tilesize & scale from model
1 parent 56a3b84 commit 8a22718

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

RealSR-NCNN-Android-CLI/MNN-SR/src/main/jni/main.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <clocale>
77
#include <thread>
88
#include <filesystem>
9+
#include <regex>
910

1011
//#undef min
1112
//#undef max
@@ -797,6 +798,37 @@ int main(int argc, char **argv)
797798
return -1;
798799
}
799800

801+
if (scale == 0) {
802+
#if _WIN32
803+
using string_t = std::wstring;
804+
using regex_t = std::wregex;
805+
using smatch_t = std::wsmatch;
806+
#define STR(x) L##x
807+
#else
808+
using string_t = std::string;
809+
using regex_t = std::regex;
810+
using smatch_t = std::smatch;
811+
#define STR(x) x
812+
#endif
813+
// 获取文件名
814+
string_t filename = std::filesystem::path(model).filename().native();
815+
816+
//regex_t re1(STR("(.+-|^)[xX]([0-9]+(\\.[0-9]+)?).*"));
817+
//regex_t re2(STR("(.+-|^)([0-9]+(\\.[0-9]+)?)[xX].*"));
818+
regex_t re1(STR("(.+-|^)[xX]([0-9]+).*"));
819+
regex_t re2(STR("(.+-|^)([0-9]+)[xX].*"));
820+
821+
smatch_t match;
822+
if (std::regex_search(filename, match, re1) && match.size() > 1) {
823+
//scale = std::stod(match[2]);
824+
scale = std::stoi(match[2]);
825+
}
826+
else if (std::regex_search(filename, match, re2) && match.size() > 1) {
827+
//scale = std::stod(match[2]);
828+
scale = std::stoi(match[2]);
829+
}
830+
}
831+
800832

801833
#include <cstdio>
802834

RealSR-NCNN-Android-CLI/MNN-SR/src/main/jni/mnnsr.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@ int MNNSR::load(const std::string &modelpath, bool cachemodel,const bool nchw)
148148

149149

150150
interpreter_input = interpreter->getSessionInput(session, nullptr);
151+
auto dims = interpreter_input->shape();
152+
if (dims.size() != 4) {
153+
fprintf(stderr, "model input tensor shape error, expect 4 dims, but got %zu\n", dims.size());
154+
return -1;
155+
}
156+
else if (dims[2] > 0 && dims[3] > 0 && dims[2] == dims[3]) {
157+
if (dims[2] != tilesize) {
158+
fprintf(stderr, "fix tilesize %d -> %d\n", tilesize, dims[2]);
159+
tilesize = dims[2];
160+
}
161+
}
162+
151163
// fprintf(stderr, "model input tensor(b/c/h/w): %d/%d/%d/%d -> 1/%d/%d/%d\n"
152164
// , input_tensor->batch(), input_tensor->channel(), input_tensor->height(), input_tensor->width()
153165
// ,model_channel, tilesize, tilesize

0 commit comments

Comments
 (0)