Skip to content

Commit 61bccb0

Browse files
committed
feat: Update FP8 type format for nx/safetensors upstream changes
Update dependencies to use git versions of nx and safetensors which have the new FP8 type representation: {:f8_e4m3fn, 8} instead of {:f, 8, :e4m3fn}. Changes: - Update mix.exs to use git deps for nx, exla, torchx, and safetensors - Update FP8 type detection pattern in pytorch_params.ex - Add TODO comments noting deps should be switched back to hex when released Tested with Qwen/Qwen3-4B-Instruct-2507-FP8 model - loads and generates correctly with preserve_source_types: true.
1 parent 0c4cfc3 commit 61bccb0

3 files changed

Lines changed: 16 additions & 14 deletions

File tree

lib/bumblebee/conversion/pytorch_params.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,12 @@ defmodule Bumblebee.Conversion.PyTorchParams do
514514
Utils.Nx.map(expr, &Nx.shape/1)
515515
end
516516

517-
defp ensure_type(param_expr, value, preserve_source_types \\ false) do
517+
defp ensure_type(param_expr, value, preserve_source_types) do
518518
Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
519519
case {Nx.type(expr), Nx.type(tensor), preserve_source_types} do
520520
{type, type, _} -> tensor
521-
# Preserve FP8 types when preserve_source_types is enabled
522-
{_expected, {:f, 8, _format}, true} -> tensor
521+
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
522+
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
523523
{expected, _actual, _} -> Nx.as_type(tensor, expected)
524524
end
525525
end)

mix.exs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,17 @@ defmodule Bumblebee.MixProject do
3434
{:axon, "~> 0.7.0"},
3535
# {:axon, github: "elixir-nx/axon", override: true},
3636
{:tokenizers, "~> 0.4"},
37-
{:nx, "~> 0.9.0 or ~> 0.10.0"},
38-
{:exla, ">= 0.0.0", only: [:dev, :test]},
39-
{:torchx, ">= 0.0.0", only: [:dev, :test]},
40-
# {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
41-
# {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
42-
# {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
37+
# TODO: Replace git deps with hex versions once nx >= X.X.X and safetensors >= 0.2.0 are released
38+
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
39+
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
40+
{:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
41+
# {:nx, "~> 0.9.0 or ~> 0.10.0"},
42+
# {:exla, ">= 0.0.0", only: [:dev, :test]},
43+
# {:torchx, ">= 0.0.0", only: [:dev, :test]},
4344
{:nx_image, "~> 0.1.0"},
4445
{:unpickler, "~> 0.1.0"},
45-
{:safetensors, "~> 0.1.3"},
46+
# TODO: Replace git dep with hex version once safetensors >= 0.2.0 is released
47+
{:safetensors, github: "elixir-nx/safetensors"},
4648
{:jason, "~> 1.4.0"},
4749
{:unzip, "~> 0.12.0 or ~> 0.13.0"},
4850
{:progress_bar, "~> 3.0"},

mix.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"},
1212
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
1313
"ex_doc": {:hex, :ex_doc, "0.39.1", "e19d356a1ba1e8f8cfc79ce1c3f83884b6abfcb79329d435d4bbb3e97ccc286e", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "8abf0ed3e3ca87c0847dfc4168ceab5bedfe881692f1b7c45f4a11b232806865"},
14-
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
14+
"exla": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "exla"]},
1515
"fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
1616
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
1717
"makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"},
@@ -20,7 +20,7 @@
2020
"mime": {:hex, :mime, "2.0.7", "b8d739037be7cd402aee1ba0306edfdef982687ee7e9859bee6198c1e7e2f128", [:mix], [], "hexpm", "6171188e399ee16023ffc5b76ce445eb6d9672e2e241d2df6050f3c771e80ccd"},
2121
"nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"},
2222
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
23-
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
23+
"nx": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "nx"]},
2424
"nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"},
2525
"nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"},
2626
"plug": {:hex, :plug, "1.18.1", "5067f26f7745b7e31bc3368bc1a2b818b9779faa959b49c934c17730efc911cf", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "57a57db70df2b422b564437d2d33cf8d33cd16339c1edb190cd11b1a3a546cc2"},
@@ -30,11 +30,11 @@
3030
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
3131
"ranch": {:hex, :ranch, "1.8.1", "208169e65292ac5d333d6cdbad49388c1ae198136e4697ae2f474697140f201c", [:make, :rebar3], [], "hexpm", "aed58910f4e21deea992a67bf51632b6d60114895eb03bb392bb733064594dd0"},
3232
"rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"},
33-
"safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"},
33+
"safetensors": {:git, "https://github.com/elixir-nx/safetensors.git", "29df313b91aba3ddf9d85aa2f3445db5e8d2622b", []},
3434
"stb_image": {:hex, :stb_image, "0.6.10", "76975279e2a130f53dc670bf6f6b1cdc4fbd7ab6293053e88e7fb6a7eae0e836", [:make, :mix], [{:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.8", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "26125372cfeda209084d3670417fab6819cfccd0e66c657678ecc48314369e8d"},
3535
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
3636
"tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"},
37-
"torchx": {:hex, :torchx, "0.10.2", "4b8529bfc4b0e641232497c99ef6d2508e652198840b212373333361352f0bae", [:mix], [{:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "cad541c64df8ddcbf50d9b0f212961632361a03050c8e01493f0fc8d4fed96d9"},
37+
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "torchx"]},
3838
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
3939
"unzip": {:hex, :unzip, "0.13.0", "bf5ec6ac6063c69e6ec54c8b4a3b8dcd7a2719d28d10d7025776ab107957cde9", [:mix], [], "hexpm", "4bcb9892ecbf2042606b43ab685a1bffe03c14003e6246f5453db2c829237fd9"},
4040
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},

0 commit comments

Comments
 (0)