diff mbox series

[bug#59607,2/8] gnu: Add real-esrgan-ncnn.

Message ID 3823ee69ed8ec08192f0c2ebc26fe8b3d399b381.camel@gmail.com
State New
Headers show
Series Upscale your anime pictures, now with 99% less malware | expand

Commit Message

Liliana Marie Prikler Nov. 20, 2022, 1:16 a.m. UTC
* gnu/packages/machine-learning.scm (real-esrgan-ncnn): New variable.
---
 gnu/packages/machine-learning.scm             |  44 ++++
 ...real-resgan-ncnn-simplify-model-path.patch | 195 ++++++++++++++++++
 2 files changed, 239 insertions(+)
 create mode 100644 gnu/packages/patches/real-resgan-ncnn-simplify-model-path.patch
diff mbox series

Patch

diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index e984e3004b..0566f4bd69 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -781,6 +781,50 @@  (define-public ncnn
 C++.  It supports parallel computing as well as GPU acceleration via Vulkan.")
     (license license:bsd-3)))
 
+(define-public real-esrgan-ncnn        ; ncnn-based Real-ESRGAN command-line tool
+  (package
+    (name "real-esrgan-ncnn")
+    (version "0.2.0")
+    (source (origin
+              (method git-fetch)
+              (uri (git-reference
+                    (url "https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan")
+                    (commit (string-append "v" version)))) ; i.e. "v0.2.0"
+              (file-name (git-file-name name version))
+              (patches                  ; applied to the fetched source
+               (search-patches
+                "real-resgan-ncnn-simplify-model-path.patch"))
+              (sha256
+               (base32 "1hlrq8b4848vgj2shcxz68d98p9wd5mf619v5d04pwg40s85zqqp"))))
+    (build-system cmake-build-system)
+    (arguments
+     (list #:tests? #f                  ; No tests
+           #:configure-flags
+           #~(list "-DUSE_SYSTEM_NCNN=TRUE" ; presumably: use the ncnn input
+                   "-DUSE_SYSTEM_WEBP=TRUE" ; presumably: use the libwebp input
+                   (string-append "-DGLSLANG_TARGET_DIR=" ; under the glslang input
+                                  #$(this-package-input "glslang")
+                                  "/lib/cmake"))
+           #:phases #~(modify-phases %standard-phases
+                        (add-after 'unpack 'chdir ; later phases run inside src/
+                          (lambda _
+                            (chdir "src")))
+                        (replace 'install ; installs only the executable
+                          (lambda* (#:key outputs #:allow-other-keys)
+                            (let ((bin (string-append (assoc-ref outputs "out")
+                                                      "/bin")))
+                              (mkdir-p bin) ; NOTE(review): likely redundant, see install-file
+                              (install-file "realesrgan-ncnn-vulkan" bin)))))))
+    (inputs (list glslang libwebp ncnn vulkan-headers vulkan-loader))
+    (home-page "https://github.com/xinntao/Real-ESRGAN")
+    (synopsis "Restore low-resolution images")
+    (description "Real-ESRGAN is a @acronym{GAN, Generative Adversarial Network}
+aiming to restore low-resolution images.  The techniques used are described in
+the paper 'Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure
+Synthetic Data' by Xintao Wang, Liangbin Xie, Chao Dong, and Ying Shan.
+This package provides an implementation built on top of ncnn.")
+    (license license:expat)))
+
 (define-public onnx
   (package
     (name "onnx")
diff --git a/gnu/packages/patches/real-resgan-ncnn-simplify-model-path.patch b/gnu/packages/patches/real-resgan-ncnn-simplify-model-path.patch
new file mode 100644
index 0000000000..9a02269718
--- /dev/null
+++ b/gnu/packages/patches/real-resgan-ncnn-simplify-model-path.patch
@@ -0,0 +1,195 @@ 
+diff --git a/src/main.cpp b/src/main.cpp
+index ebe0e62..ddfb742 100644
+--- a/src/main.cpp
++++ b/src/main.cpp
+@@ -109,8 +109,7 @@ static void print_usage()
+     fprintf(stderr, "  -o output-path       output image path (jpg/png/webp) or directory\n");
+     fprintf(stderr, "  -s scale             upscale ratio (can be 2, 3, 4. default=4)\n");
+     fprintf(stderr, "  -t tile-size         tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu\n");
+-    fprintf(stderr, "  -m model-path        folder path to the pre-trained models. default=models\n");
+-    fprintf(stderr, "  -n model-name        model name (default=realesr-animevideov3, can be realesr-animevideov3 | realesrgan-x4plus | realesrgan-x4plus-anime | realesrnet-x4plus)\n");
++    fprintf(stderr, "  -m model-path        model and parameter file name (sans .bin and .param extension)\n");
+     fprintf(stderr, "  -g gpu-id            gpu device to use (default=auto) can be 0,1,2 for multi-gpu\n");
+     fprintf(stderr, "  -j load:proc:save    thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu\n");
+     fprintf(stderr, "  -x                   enable tta mode\n");
+@@ -438,8 +437,7 @@ int main(int argc, char** argv)
+     path_t outputpath;
+     int scale = 4;
+     std::vector<int> tilesize;
+-    path_t model = PATHSTR("models");
+-    path_t modelname = PATHSTR("realesr-animevideov3");
++    path_t model = PATHSTR("");
+     std::vector<int> gpuid;
+     int jobs_load = 1;
+     std::vector<int> jobs_proc;
+@@ -451,7 +449,7 @@ int main(int argc, char** argv)
+ #if _WIN32
+     setlocale(LC_ALL, "");
+     wchar_t opt;
+-    while ((opt = getopt(argc, argv, L"i:o:s:t:m:n:g:j:f:vxh")) != (wchar_t)-1)
++    while ((opt = getopt(argc, argv, L"i:o:t:m:g:j:f:vxh")) != (wchar_t)-1)
+     {
+         switch (opt)
+         {
+@@ -461,18 +459,12 @@ int main(int argc, char** argv)
+         case L'o':
+             outputpath = optarg;
+             break;
+-        case L's':
+-            scale = _wtoi(optarg);
+-            break;
+         case L't':
+             tilesize = parse_optarg_int_array(optarg);
+             break;
+         case L'm':
+             model = optarg;
+             break;
+-        case L'n':
+-            modelname = optarg;
+-            break;
+         case L'g':
+             gpuid = parse_optarg_int_array(optarg);
+             break;
+@@ -497,7 +489,7 @@ int main(int argc, char** argv)
+     }
+ #else // _WIN32
+     int opt;
+-    while ((opt = getopt(argc, argv, "i:o:s:t:m:n:g:j:f:vxh")) != -1)
++    while ((opt = getopt(argc, argv, "i:o:t:m:g:j:f:vxh")) != -1)
+     {
+         switch (opt)
+         {
+@@ -507,18 +499,12 @@ int main(int argc, char** argv)
+         case 'o':
+             outputpath = optarg;
+             break;
+-        case 's':
+-            scale = atoi(optarg);
+-            break;
+         case 't':
+             tilesize = parse_optarg_int_array(optarg);
+             break;
+         case 'm':
+             model = optarg;
+             break;
+-        case 'n':
+-            modelname = optarg;
+-            break;
+         case 'g':
+             gpuid = parse_optarg_int_array(optarg);
+             break;
+@@ -549,6 +535,12 @@ int main(int argc, char** argv)
+         return -1;
+     }
+ 
++    if (model.empty())
++    {
++        fprintf(stderr, "no model given\n");
++        return -1;
++    }
++
+     if (tilesize.size() != (gpuid.empty() ? 1 : gpuid.size()) && !tilesize.empty())
+     {
+         fprintf(stderr, "invalid tilesize argument\n");
+@@ -671,61 +663,17 @@ int main(int argc, char** argv)
+         }
+     }
+ 
+-    int prepadding = 0;
+-
+-    if (model.find(PATHSTR("models")) != path_t::npos
+-        || model.find(PATHSTR("models2")) != path_t::npos)
+-    {
+-        prepadding = 10;
+-    }
+-    else
+-    {
+-        fprintf(stderr, "unknown model dir type\n");
+-        return -1;
+-    }
++    int prepadding = 10;
+ 
+-    // if (modelname.find(PATHSTR("realesrgan-x4plus")) != path_t::npos
+-    //     || modelname.find(PATHSTR("realesrnet-x4plus")) != path_t::npos
+-    //     || modelname.find(PATHSTR("esrgan-x4")) != path_t::npos)
+-    // {}
+-    // else
+-    // {
+-    //     fprintf(stderr, "unknown model name\n");
+-    //     return -1;
+-    // }
+ 
+ #if _WIN32
+-    wchar_t parampath[256];
+-    wchar_t modelpath[256];
+-
+-    if (modelname == PATHSTR("realesr-animevideov3"))
+-    {
+-        swprintf(parampath, 256, L"%s/%s-x%s.param", model.c_str(), modelname.c_str(), std::to_string(scale));
+-        swprintf(modelpath, 256, L"%s/%s-x%s.bin", model.c_str(), modelname.c_str(), std::to_string(scale));
+-    }
+-    else{
+-        swprintf(parampath, 256, L"%s/%s.param", model.c_str(), modelname.c_str());
+-        swprintf(modelpath, 256, L"%s/%s.bin", model.c_str(), modelname.c_str());
+-    }
+-
++    path_t parampath = model + L".param";
++    path_t modelpath = model + L".bin";
+ #else
+-    char parampath[256];
+-    char modelpath[256];
+-
+-    if (modelname == PATHSTR("realesr-animevideov3"))
+-    {
+-        sprintf(parampath, "%s/%s-x%s.param", model.c_str(), modelname.c_str(), std::to_string(scale).c_str());
+-        sprintf(modelpath, "%s/%s-x%s.bin", model.c_str(), modelname.c_str(), std::to_string(scale).c_str());
+-    }
+-    else{
+-        sprintf(parampath, "%s/%s.param", model.c_str(), modelname.c_str());
+-        sprintf(modelpath, "%s/%s.bin", model.c_str(), modelname.c_str());
+-    }
++    path_t parampath = model + ".param";
++    path_t modelpath = model + ".bin";
+ #endif
+ 
+-    path_t paramfullpath = sanitize_filepath(parampath);
+-    path_t modelfullpath = sanitize_filepath(modelpath);
+-
+ #if _WIN32
+     CoInitializeEx(NULL, COINIT_MULTITHREADED);
+ #endif
+@@ -781,17 +729,14 @@ int main(int argc, char** argv)
+         uint32_t heap_budget = ncnn::get_gpu_device(gpuid[i])->get_heap_budget();
+ 
+         // more fine-grained tilesize policy here
+-        if (model.find(PATHSTR("models")) != path_t::npos)
+-        {
+-            if (heap_budget > 1900)
+-                tilesize[i] = 200;
+-            else if (heap_budget > 550)
+-                tilesize[i] = 100;
+-            else if (heap_budget > 190)
+-                tilesize[i] = 64;
+-            else
+-                tilesize[i] = 32;
+-        }
++        if (heap_budget > 1900)
++          tilesize[i] = 200;
++        else if (heap_budget > 550)
++          tilesize[i] = 100;
++        else if (heap_budget > 190)
++          tilesize[i] = 64;
++        else
++          tilesize[i] = 32;
+     }
+ 
+     {
+@@ -801,7 +746,7 @@ int main(int argc, char** argv)
+         {
+             realesrgan[i] = new RealESRGAN(gpuid[i], tta_mode);
+ 
+-            realesrgan[i]->load(paramfullpath, modelfullpath);
++            realesrgan[i]->load(parampath, modelpath);
+ 
+             realesrgan[i]->scale = scale;
+             realesrgan[i]->tilesize = tilesize[i];