@@ -4311,6 +4311,133 @@ (define-public ideep-pytorch
PyTorch.")
(license license:expat)))
+(define %python-pytorch-version "2.2.1")
+
+(define %python-pytorch-src
+ (origin
+ (method git-fetch)
+ (uri (git-reference
+ (url "https://github.com/pytorch/pytorch")
+ (commit (string-append "v" %python-pytorch-version))))
+ (file-name (git-file-name "python-pytorch" %python-pytorch-version))
+ (sha256
+ (base32
+ "03mm0pwwb5lxdsmmiw3cch9fijgjw81kmmc4ln9rlyazkm7l1r48"))
+ (modules '((guix build utils)))
+ (snippet
+ '(begin
+ ;; Bundled or unused code
+ (for-each
+ (lambda (dir)
+ (when (file-exists? dir)
+ (delete-file-recursively dir)))
+ '("android"
+ "aten/src/ATen/native/cuda/cutlass_extensions"
+ "aten/src/ATen/native/quantized/cpu/qnnpack"
+ "caffe2/mobile/contrib/libopencl-stub"
+ "caffe2/mobile/contrib/libvulkan-stub"
+ "third_party"))
+
+ ;; Autogenerated files
+ (for-each
+ delete-file
+ '("aten/src/ATen/nnapi/nnapi_wrapper.cpp"
+ "aten/src/ATen/nnapi/nnapi_wrapper.h"
+ "caffe2/mobile/contrib/ios/mpscnn/mpscnn_kernels.h"
+ "caffe2/proto/caffe2_legacy_pb2.pyi"
+ "caffe2/proto/caffe2_pb2.pyi"
+ "caffe2/proto/hsm_pb2.pyi"
+ "caffe2/proto/metanet_pb2.pyi"
+ "caffe2/proto/predictor_consts_pb2.pyi"
+ "caffe2/proto/prof_dag_pb2.pyi"
+ "caffe2/proto/torch_pb2.pyi"
+ ;; These files contain just lists of floating point values and
+ ;; might be as well hand-written.
+ ;; "test/cpp/api/init_baseline.h"
+ ;; "test/cpp/api/optim_baseline.h"
+ "test/mobile/test_upgrader_bytecode_table_example.cpp"
+ "torch/csrc/jit/mobile/upgrader_mobile.cpp"
+ "torch/csrc/jit/runtime/decomposition_registry_util.cpp"
+ "torch/csrc/jit/runtime/serialized_shape_function_registry.cpp"
+ "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp"
+ "torch/csrc/jit/serialization/mobile_bytecode_generated.h"))
+ (delete-file-recursively ".github")
+ (for-each
+ (lambda (dir)
+ (for-each
+ delete-file
+ (find-files dir "\\.cu$")))
+ '("aten/src/ATen/native/transformers/cuda/flash_attn/kernels"
+ "aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels"))))))
+
+(define-public qnnpack-pytorch
+ (package
+ (inherit qnnpack)
+ (name "qnnpack-pytorch")
+ (version (string-append "pytorch-" %python-pytorch-version))
+ (source
+ (origin
+ (inherit %python-pytorch-src)
+ (patches '())
+ (modules '((guix build utils)
+ (srfi srfi-26)
+ (ice-9 ftw)))
+ (snippet
+ '(begin
+ (rename-file "aten/src/ATen/native/quantized/cpu/qnnpack"
+ "../qnnpack")
+ (let ((outdir (getcwd)))
+ (chdir "..")
+ (rename-file outdir "dummy")
+ (rename-file "qnnpack" outdir)
+ (chdir outdir)
+ (delete-file-recursively "deps"))))))
+ (arguments
+ (substitute-keyword-arguments (package-arguments qnnpack)
+ ((#:phases phases #~%standard-phases)
+ #~(modify-phases %standard-phases
+ (add-after 'unpack 'patch-cmake
+ (lambda _
+ (substitute* "CMakeLists.txt"
+ (("project\\(.*" orig)
+ (apply
+ string-append
+ orig "\n"
+ (map (lambda (name)
+ (string-append
+ "option(" name " \"\" ON)\n"))
+ '("USE_SYSTEM_CPUINFO" "USE_SYSTEM_FP16" "USE_SYSTEM_FXDIV"
+ "USE_SYSTEM_PSIMD" "USE_SYSTEM_PTHREADPOOL"))))
+ (("if.*SOURCE_DIR.*")
+ "if(FALSE)\n")
+ (("if\\(NOT TARGET (clog|gtest|benchmark).*")
+ "if(FALSE)\n")
+ (("target_link_libraries.*(fxdiv|psimd|fp16)\\).*")
+ "")
+ (("(target_link_libraries.*) fp16 (.*)" _ before after)
+ (string-append before " " after)))))
+ (add-after 'unpack 'fix-cstring-include
+ (lambda _
+ (substitute* "include/pack_block_sparse.h"
+ (("#include.*<vector>.*" orig)
+ (string-append orig "\n#include <cstring>\n")))))
+ (add-after 'install 'install-missing-headers
+ (lambda _
+ (for-each
+ (lambda (name)
+ (install-file (string-append "../source/include/" name)
+ (string-append #$output "/include")))
+ '("pack_block_sparse.h"
+ "pytorch_qnnpack.h"
+ "qnnpack_func.h"))
+ (copy-recursively
+ "../source/src/qnnpack"
+ (string-append #$output "/include/qnnpack"))))))
+ ;; Some tests occasionally fail on i686 due to floating point rounding.
+ ((#:tests? _ #t)
+ (not (string-prefix? "i686" (or (%current-target-system)
+ (%current-system)))))))))
+
;; Please also update python-torchvision when updating this package.
(define-public python-pytorch
(package