diff --git a/lib/bumblebee/utils/http.ex b/lib/bumblebee/utils/http.ex index 242c6d90..86129b11 100644 --- a/lib/bumblebee/utils/http.ex +++ b/lib/bumblebee/utils/http.ex @@ -121,7 +121,7 @@ defmodule Bumblebee.Utils.HTTP do last_step_bucket = if step, do: div(last_percent, step), else: last_percent if step_bucket > last_step_bucket or percent == 100 do - ProgressBar.render(state.size, state.total_size, suffix: :bytes) + Bumblebee.Utils.ProgressBar.render(state.size, state.total_size, :bytes) end end diff --git a/lib/bumblebee/utils/progress_bar.ex b/lib/bumblebee/utils/progress_bar.ex new file mode 100644 index 00000000..b0004796 --- /dev/null +++ b/lib/bumblebee/utils/progress_bar.ex @@ -0,0 +1,68 @@ +defmodule Bumblebee.Utils.ProgressBar do + @moduledoc false + + # Reserve 2 chars for the start and end of the bar + @reserved_width 2 + + @start_end_char "|" + @filled_char "=" + @unfilled_char " " + + @doc """ + Renders a simple progress bar to the terminal. + The progress bar sizes to fill the entire width of the terminal. + """ + @spec render(number(), number()) :: :ok + @spec render(number(), number(), :none | :bytes) :: :ok + def render(count, total, unit \\ :none) do + percent = min(max(count / total, 0), 1) + + formatted_percent = String.pad_leading("#{trunc(percent * 100)}%", 4) + formatted_progress = " " <> formatted_progress(count, total, unit) + + reserved_width = + @reserved_width + String.length(formatted_progress) + String.length(formatted_percent) + + bar_width = max(terminal_width() - reserved_width, 0) + + filled_char_count = trunc(percent * bar_width) + unfilled_char_count = max(bar_width - filled_char_count, 0) + + filled_chars = String.duplicate(@filled_char, filled_char_count) + unfilled_chars = String.duplicate(@unfilled_char, unfilled_char_count) + + output = + @start_end_char <> + filled_chars <> + unfilled_chars <> + @start_end_char <> + formatted_percent <> + formatted_progress + + IO.write(output) + end + + defp formatted_progress(progress, total, :none) do + "#{progress}/#{total}" + end + + defp formatted_progress(progress, total, :bytes) do + {unit, divisor} = bytes_unit_and_divisor(total) + "#{trunc(progress / divisor)}/#{trunc(total / divisor)}#{unit}" + end + + defp bytes_unit_and_divisor(count) do + cond do + count > 1_000_000 -> {"MB", 1_000_000} + count > 1_000 -> {"KB", 1_000} + true -> {"B", 1} + end + end + + defp terminal_width do + case :io.columns() do + {:ok, width} -> width + {:error, _} -> 80 + end + end +end diff --git a/mix.exs b/mix.exs index 7b9b0b4b..8bef84c6 100644 --- a/mix.exs +++ b/mix.exs @@ -45,7 +45,6 @@ defmodule Bumblebee.MixProject do {:safetensors, "~> 0.1.3"}, {:jason, "~> 1.4.0"}, {:unzip, "~> 0.12.0 or ~> 0.13.0"}, - {:progress_bar, "~> 3.0"}, {:stb_image, "~> 0.6.0", only: :test}, {:bypass, "~> 2.1", only: :test}, {:ex_doc, "~> 0.28", only: :dev, runtime: false}, diff --git a/mix.lock b/mix.lock index 1f5ab16c..35739be7 100644 --- a/mix.lock +++ b/mix.lock @@ -7,7 +7,6 @@ "cowboy": {:hex, :cowboy, "2.14.2", "4008be1df6ade45e4f2a4e9e2d22b36d0b5aba4e20b0a0d7049e28d124e34847", [:make, :rebar3], [{:cowlib, ">= 2.16.0 and < 3.0.0", [hex: :cowlib, repo: "hexpm", optional: false]}, {:ranch, ">= 1.8.0 and < 3.0.0", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "569081da046e7b41b5df36aa359be71a0c8874e5b9cff6f747073fc57baf1ab9"}, "cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"}, "cowlib": {:hex, :cowlib, "2.16.0", "54592074ebbbb92ee4746c8a8846e5605052f29309d3a873468d76cdf932076f", [:make, :rebar3], [], "hexpm", "7f478d80d66b747344f0ea7708c187645cfcc08b11aa424632f78e25bf05db51"}, - "decimal": {:hex, :decimal, "2.3.0", "3ad6255aa77b4a3c4f818171b12d237500e63525c2fd056699967a3e7ea20f62", [:mix], [], "hexpm", "a4d66355cb29cb47c3cf30e71329e58361cfcb37c34235ef3bf1d7bf3773aeac"}, "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, "ex_doc": {:hex, :ex_doc, "0.39.1", "e19d356a1ba1e8f8cfc79ce1c3f83884b6abfcb79329d435d4bbb3e97ccc286e", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "8abf0ed3e3ca87c0847dfc4168ceab5bedfe881692f1b7c45f4a11b232806865"}, @@ -27,7 +26,6 @@ "plug_cowboy": {:hex, :plug_cowboy, "2.7.4", "729c752d17cf364e2b8da5bdb34fb5804f56251e88bb602aff48ae0bd8673d11", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "9b85632bd7012615bae0a5d70084deb1b25d2bcbb32cab82d1e9a1e023168aa3"}, "plug_crypto": {:hex, :plug_crypto, "2.1.1", "19bda8184399cb24afa10be734f84a16ea0a2bc65054e23a62bb10f06bc89491", [:mix], [], "hexpm", "6470bce6ffe41c8bd497612ffde1a7e4af67f36a15eea5f921af71cf3e11247c"}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, - "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, "ranch": {:hex, :ranch, "1.8.1", "208169e65292ac5d333d6cdbad49388c1ae198136e4697ae2f474697140f201c", [:make, :rebar3], [], "hexpm", "aed58910f4e21deea992a67bf51632b6d60114895eb03bb392bb733064594dd0"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, diff --git a/test/bumblebee/utils/progress_bar_test.exs b/test/bumblebee/utils/progress_bar_test.exs new file mode 100644 index 00000000..4f1ea041 --- /dev/null +++ b/test/bumblebee/utils/progress_bar_test.exs @@ -0,0 +1,45 @@ +defmodule Bumblebee.Utils.ProgressBarTest do + use ExUnit.Case, async: true + + alias Bumblebee.Utils.ProgressBar + alias ExUnit.CaptureIO + + test "renders various widths of progress bars" do + assert "|================================= | 50% 0.2/0.4" == + CaptureIO.capture_io(fn -> ProgressBar.render(0.2, 0.4) end) + + assert "| | 0% 0/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(0, 10) end) + + assert "|================================== | 50% 5/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(5, 10) end) + + assert "|=============================================================== | 95% 9.5/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(9.5, 10) end) + + assert "|================================================================ | 99% 9.999/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(9.999, 10) end) + + assert "|====================================================================|100% 10/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(10, 10) end) + end + + test "renders bars when unexpected inputs given" do + assert "| | 0% -10/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(-10, 10) end) + + assert "|====================================================================|100% 20/10" == + CaptureIO.capture_io(fn -> ProgressBar.render(20, 10) end) + end + + test "formats byte counts" do + assert "|====== | 10% 10/100B" == + CaptureIO.capture_io(fn -> ProgressBar.render(10, 100, :bytes) end) + + assert "|================================== | 50% 1/2KB" == + CaptureIO.capture_io(fn -> ProgressBar.render(1_000, 2_000, :bytes) end) + + assert "|================================== | 50% 1/2MB" == + CaptureIO.capture_io(fn -> ProgressBar.render(1_000_000, 2_000_000, :bytes) end) + end +end