diff mbox series

[bug#73094,v3,10/10] gnu: Add python-safetensors.

Message ID 20240909100804.13552-11-herman@rimm.ee
State New
Headers show
Series Add python-safetensors. | expand

Commit Message

Herman Rimm Sept. 9, 2024, 10:08 a.m. UTC
From: Nicolas Graves <ngraves@ngraves.fr>

* gnu/packages/machine-learning.scm (python-safetensors): New variable.

Change-Id: I90a1684d06756ce87ca0862d745a75be5919f0b2
---
 gnu/packages/machine-learning.scm | 100 ++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)
diff mbox series

Patch

diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 410b71b061..8b9dd2f7e4 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -1085,6 +1085,106 @@  (define-public rust-safetensors
 @code{PyTorch} counterparts.")
     (license license:asl2.0)))
 
+(define-public python-safetensors
+  (package
+    (name "python-safetensors")
+    (version "0.4.3")
+    (source
+     (origin
+       (method url-fetch)
+       (uri (pypi-uri "safetensors" version))
+       (sha256
+        (base32 "1hhiwy67jarm70l0k26fs1cjhzkgzrh79q14bklj2yp0qi8gr19g"))
+       (modules '((guix build utils)
+                  (ice-9 ftw)))
+       (snippet
+        #~(begin  ;; Only keeping bindings.
+            (for-each
+              (lambda (file)
+                (unless (member file '("." ".." "bindings" "PKG-INFO"))
+                  (delete-file-recursively file)))
+              (scandir "."))
+            (for-each
+              (lambda (file)
+                (unless (member file '("." ".."))
+                  (rename-file (string-append "bindings/python/" file)
+                               file)))
+              (scandir "bindings/python"))))))
+    (build-system cargo-build-system)
+    (arguments
+     (list
+      #:modules '((guix build cargo-build-system)
+                  (guix build utils)
+                  (ice-9 regex)
+                  (ice-9 textual-ports)
+                  (srfi srfi-26))
+      #:phases
+      #~(modify-phases %standard-phases
+          (add-after 'unpack-rust-crates 'inject-safetensors
+            (lambda _
+              (substitute* "Cargo.toml"
+                (("\\[dependencies\\]")
+                 (format #f "[dependencies]~%safetensors = ~s"
+                         #$(package-version rust-safetensors))))
+              (call-with-input-file "Cargo.toml"
+                (lambda (port)
+                  (let* ((content (get-string-all port))
+                         (top-match (string-match
+                                      "\\[dependencies.safetensors"
+                                      content)))
+                    (call-with-output-file "Cargo.toml"
+                      (cut display (match:prefix top-match) <>)))))))
+          (add-before 'check 'install-rust-library
+            (lambda _
+              (copy-file "target/release/libsafetensors_rust.so"
+                         "py_src/safetensors/_safetensors_rust.so")))
+          (replace 'check
+            (lambda _
+              (invoke "python3"
+                      "-c" (string-append "import sys; sys.path.append"
+                                          "(\"" (getcwd) "/py_src\")")
+                      "-m" "pytest"
+                      "-n" "auto"
+                      "--dist=loadfile"
+                      "-s" "-v" "./tests/"
+                      ;; Missing jax dependency
+                      "--ignore=./tests/test_flax_comparison.py")))
+          (add-after 'install 'install-python
+            (lambda _
+              (let* ((pversion #$(version-major+minor
+                                   (package-version python)))
+                     (lib (string-append #$output "/lib/python" pversion
+                                         "/site-packages/"))
+                     (info (string-append lib "safetensors-"
+                                        #$(package-version this-package)
+                                        ".dist-info")))
+                (mkdir-p info)
+                (copy-file "PKG-INFO" (string-append info "/METADATA"))
+                (copy-recursively
+                 "py_src/safetensors"
+                 (string-append lib "safetensors"))))))
+      #:cargo-inputs
+      `(("rust-pyo3" ,rust-pyo3-0.21)
+        ("rust-memmap2" ,rust-memmap2-0.9)
+        ("rust-safetensors" ,rust-safetensors)
+        ("rust-serde-json" ,rust-serde-json-1))))
+    (inputs
+     (list rust-safetensors))
+    (native-inputs
+     (list python-h5py
+           python-minimal
+           python-numpy
+           python-pytest
+           python-pytest-xdist
+           python-pytorch
+           tensorflow))
+    (home-page "https://huggingface.co/docs/safetensors")
+    (synopsis "Simple and safe way to store and distribute tensors")
+    (description "This package provides a fast (zero-copy) and safe
+(dedicated) format for storing tensors safely.  This package builds upon
+@code{rust-safetensors} and provides Python bindings.")
+    (license license:asl2.0)))
+
 (define-public python-sentencepiece
   (package
     (name "python-sentencepiece")