diff mbox series

[bug#69591,v2,25/31] gnu: Add qnnpack-pytorch.

Message ID 20240312225211.16427-25-david.elsing@posteo.net
State New
Headers show
Series Unbundle and update python-pytorch | expand

Commit Message

David Elsing March 12, 2024, 10:51 p.m. UTC
This is an internal fork of QNNPACK in the PyTorch source tree.

* gnu/packages/machine-learning.scm (%python-pytorch-version): New variable.
(%python-pytorch-src): New variable.
(qnnpack-pytorch): New variable.
---
 gnu/packages/machine-learning.scm | 127 ++++++++++++++++++++++++++++++
 1 file changed, 127 insertions(+)
diff mbox series

Patch

diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 24422e6965..6b62d7016b 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -4311,6 +4311,133 @@  (define-public ideep-pytorch
 PyTorch.")
     (license license:expat)))
 
+(define %python-pytorch-version "2.2.1")
+
+(define %python-pytorch-src
+  (origin
+    (method git-fetch)
+    (uri (git-reference
+          (url "https://github.com/pytorch/pytorch")
+          (commit (string-append "v" %python-pytorch-version))))
+    (file-name (git-file-name "python-pytorch" %python-pytorch-version))
+    (sha256
+     (base32
+      "03mm0pwwb5lxdsmmiw3cch9fijgjw81kmmc4ln9rlyazkm7l1r48"))
+    (modules '((guix build utils)))
+    (snippet
+     '(begin
+        ;; Bundled or unused code
+        (for-each
+         (lambda (dir)
+           (when (file-exists? dir)
+             (delete-file-recursively dir)))
+         '("android"
+           "aten/src/ATen/native/cuda/cutlass_extensions"
+           "aten/src/ATen/native/quantized/cpu/qnnpack"
+           "caffe2/mobile/contrib/libopencl-stub"
+           "caffe2/mobile/contrib/libvulkan-stub"
+           "third_party"))
+
+        ;; Autogenerated files
+        (for-each
+         delete-file
+         '("aten/src/ATen/nnapi/nnapi_wrapper.cpp"
+           "aten/src/ATen/nnapi/nnapi_wrapper.h"
+           "caffe2/mobile/contrib/ios/mpscnn/mpscnn_kernels.h"
+           "caffe2/proto/caffe2_legacy_pb2.pyi"
+           "caffe2/proto/caffe2_pb2.pyi"
+           "caffe2/proto/hsm_pb2.pyi"
+           "caffe2/proto/metanet_pb2.pyi"
+           "caffe2/proto/predictor_consts_pb2.pyi"
+           "caffe2/proto/prof_dag_pb2.pyi"
+           "caffe2/proto/torch_pb2.pyi"
+           ;; These files contain just lists of floating point values and
+           ;; might be as well hand-written.
+           ;; "test/cpp/api/init_baseline.h"
+           ;; "test/cpp/api/optim_baseline.h"
+           "test/mobile/test_upgrader_bytecode_table_example.cpp"
+           "torch/csrc/jit/mobile/upgrader_mobile.cpp"
+           "torch/csrc/jit/runtime/decomposition_registry_util.cpp"
+           "torch/csrc/jit/runtime/serialized_shape_function_registry.cpp"
+           "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp"
+           "torch/csrc/jit/serialization/mobile_bytecode_generated.h"))
+        (delete-file-recursively ".github")
+        (for-each
+         (lambda (dir)
+           (for-each
+            delete-file
+            (find-files dir "\\.cu$")))
+         '("aten/src/ATen/native/transformers/cuda/flash_attn/kernels"
+           "aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels"))))))
+
+(define-public qnnpack-pytorch
+  (package
+    (inherit qnnpack)
+    (name "qnnpack-pytorch")
+    (version (string-append "pytorch-" %python-pytorch-version))
+    (source
+     (origin
+       (inherit %python-pytorch-src)
+       (patches '())
+       (modules '((guix build utils)
+                  (srfi srfi-26)
+                  (ice-9 ftw)))
+       (snippet
+        '(begin
+           (rename-file "aten/src/ATen/native/quantized/cpu/qnnpack"
+                        "../qnnpack")
+           (let ((outdir (getcwd)))
+             (chdir "..")
+             (rename-file outdir "dummy")
+             (rename-file "qnnpack" outdir)
+             (chdir outdir)
+             (delete-file-recursively "deps"))))))
+    (arguments
+     (substitute-keyword-arguments (package-arguments qnnpack)
+       ((#:phases phases #~%standard-phases)
+        #~(modify-phases %standard-phases
+            (add-after 'unpack 'patch-cmake
+              (lambda _
+                (substitute* "CMakeLists.txt"
+                  (("project\\(.*" orig)
+                   (apply
+                    string-append
+                    orig "\n"
+                    (map (lambda (name)
+                           (string-append
+                            "option(" name " \"\" ON)\n"))
+                         '("USE_SYSTEM_CPUINFO" "USE_SYSTEM_FP16" "USE_SYSTEM_FXDIV"
+                           "USE_SYSTEM_PSIMD" "USE_SYSTEM_PTHREADPOOL"))))
+                  (("if.*SOURCE_DIR.*")
+                   "if(FALSE)\n")
+                  (("if\\(NOT TARGET (clog|gtest|benchmark).*")
+                   "if(FALSE)\n")
+                  (("target_link_libraries.*(fxdiv|psimd|fp16)\\).*")
+                   "")
+                  (("(target_link_libraries.*) fp16 (.*)" _ before after)
+                   (string-append before " " after)))))
+            (add-after 'unpack 'fix-cstring-include
+              (lambda _
+                (substitute* "include/pack_block_sparse.h"
+                  (("#include.*<vector>.*" orig)
+                   (string-append orig "\n#include <cstring>\n")))))
+            (add-after 'install 'install-missing-headers
+              (lambda _
+                (for-each
+                 (lambda (name)
+                   (install-file (string-append "../source/include/" name)
+                                 (string-append #$output "/include")))
+                 '("pack_block_sparse.h"
+                   "pytorch_qnnpack.h"
+                   "qnnpack_func.h"))
+                (copy-recursively
+                 "../source/src/qnnpack"
+                 (string-append #$output "/include/qnnpack"))))))
+       ;; Some tests occasionally fail on i686 due to floating point rounding.
+       ((#:tests? _ #t)
+        (not (string-prefix? "i686" (or (%current-target-system)
+                                        (%current-system)))))))))
+
 ;; Please also update python-torchvision when updating this package.
 (define-public python-pytorch
   (package