diff mbox series

[bug#73094,10/10] gnu: Add python-safetensors.

Message ID 20240907100908.25197-10-ngraves@ngraves.fr
State New
Headers show
Series [bug#73094,01/10] gnu: rust-darling-core-0.20: Update to 0.20.8. | expand

Commit Message

Nicolas Graves Sept. 7, 2024, 10:08 a.m. UTC
* gnu/packages/machine-learning.scm (python-safetensors): New variable.

Change-Id: I90a1684d06756ce87ca0862d745a75be5919f0b2
---
 gnu/packages/machine-learning.scm | 97 +++++++++++++++++++++++++++++++
 1 file changed, 97 insertions(+)
diff mbox series

Patch

diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index a4aeb97be7..12be1d7bf6 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -1120,6 +1120,103 @@  (define-public rust-safetensors
 @code{PyTorch} counterparts.")
     (license license:asl2.0)))
 
+(define-public python-safetensors
+  (package
+    (name "python-safetensors")
+    (version "0.4.3")
+    (source
+     (origin
+       (method url-fetch)
+       (uri (pypi-uri "safetensors" version))
+       (sha256
+        (base32 "1hhiwy67jarm70l0k26fs1cjhzkgzrh79q14bklj2yp0qi8gr19g"))
+       (modules '((guix build utils)
+                  (ice-9 ftw)))
+       (snippet
+        #~(begin  ;; Only keeping bindings.
+            (for-each (lambda (file)
+                        (unless (member file '("." ".." "bindings" "PKG-INFO"))
+                          (delete-file-recursively file)))
+                      (scandir "."))
+            (for-each (lambda (file)
+                        (unless (member file '("." ".."))
+                          (rename-file (string-append "bindings/python/" file) file)))
+                      (scandir "bindings/python"))))))
+    (build-system cargo-build-system)
+    (arguments
+     (list
+      #:imported-modules `(,@%cargo-build-system-modules
+                           ,@%pyproject-build-system-modules)
+      #:modules '((guix build cargo-build-system)
+                  ((guix build pyproject-build-system) #:prefix py:)
+                  (guix build utils)
+                  (ice-9 regex)
+                  (ice-9 textual-ports))
+      #:phases
+      #~(modify-phases %standard-phases
+          (add-after 'unpack-rust-crates 'inject-safetensors
+            (lambda _
+              (substitute* "Cargo.toml"
+                (("\\[dependencies\\]")
+                 (format #f "[dependencies]~%safetensors = ~s"
+                         #$(package-version rust-safetensors))))
+              (let ((file-path "Cargo.toml"))
+                (call-with-input-file file-path
+                  (lambda (port)
+                    (let* ((content (get-string-all port))
+                           (top-match (string-match
+                                       "\\[dependencies.safetensors" content)))
+                      (call-with-output-file file-path
+                        (lambda (out)
+                          (format out "~a" (match:prefix top-match))))))))))
+          (replace 'check
+            (lambda _
+              (copy-file "target/release/libsafetensors_rust.so"
+                         "py_src/safetensors/_safetensors_rust.so")
+              (invoke "python3"
+                      "-c" (format #f
+                                   "import sys; sys.path.append(\"~a/py_src\")"
+                                   (getcwd))
+                      "-m" "pytest"
+                      "-n" "auto"
+                      "--dist=loadfile"
+                      "-s" "-v" "./tests/"
+                      "--ignore=./tests/test_flax_comparison.py")))
+          (add-after 'install 'install-python
+            (lambda _
+              (let* ((pversion #$(version-major+minor (package-version python)))
+                     (lib (string-append #$output "/lib/python" pversion
+                                         "/site-packages/"))
+                     (info (string-append lib "safetensors-"
+                                        #$(package-version this-package)
+                                        ".dist-info")))
+                (mkdir-p info)
+                (copy-file "PKG-INFO" (string-append info "/METADATA"))
+                (copy-recursively
+                 "py_src/safetensors"
+                 (string-append lib "safetensors"))))))
+      #:cargo-inputs
+      `(("rust-pyo3" ,rust-pyo3-0.21)
+        ("rust-memmap2" ,rust-memmap2-0.9)
+        ("rust-safetensors" ,rust-safetensors)
+        ("rust-serde-json" ,rust-serde-json-1))))
+    (inputs
+     (list rust-safetensors))
+    (native-inputs
+     (list python-h5py
+           python-minimal
+           python-numpy
+           python-pytest
+           python-pytest-xdist
+           python-pytorch
+           tensorflow))
+    (home-page "https://huggingface.co/docs/safetensors")
+    (synopsis "Simple and safe way to store and distribute tensors")
+    (description "This package provides a fast (zero-copy) and safe
+(dedicated) format for storing tensors safely.  This package builds upon
+@code{rust-safetensors} and provides Python bindings.")
+    (license license:asl2.0)))
+
 (define-public python-sentencepiece
   (package
     (name "python-sentencepiece")