Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,17 @@ add_executable(gpt2
example/gpt2/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/gpt2/checkpoint_loader.cc
example/common/tokenizer.cc
example/gpt2/checkpoint_loader.cc
)
link_infini_train_exe(gpt2)

add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/llama3/checkpoint_loader.cc
example/common/tokenizer.cc
example/llama3/checkpoint_loader.cc
)
link_infini_train_exe(llama3)

Expand Down
5 changes: 3 additions & 2 deletions example/gpt2/checkpoint_loader.cc
Comment thread
kilinchange marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
Expand All @@ -24,6 +22,9 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

Expand Down
74 changes: 73 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <format>
#include <memory>
#include <optional>
Expand All @@ -10,6 +11,7 @@
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/checkpoint/checkpoint.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
Expand All @@ -29,6 +31,7 @@
#ifdef PROFILE_MODE
#include "infini_train/include/profiler.h"
#endif
#include "infini_train/include/checkpoint/checkpoint_manager.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/global_module_hook_registry.h"
#include "infini_train/include/utils/precision_check_config.h"
Expand All @@ -39,6 +42,7 @@
#include "example/gpt2/checkpoint_loader.h"
#include "example/gpt2/config.h"

// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability.
// I/O
DEFINE_string(input_bin, "", "input .bin to train on");
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
Expand Down Expand Up @@ -77,6 +81,11 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving");
DEFINE_string(load, "", "checkpoint directory to resume from");
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
// precision check
DEFINE_string(
precision_check, "",
Expand Down Expand Up @@ -315,9 +324,55 @@ void Train(const nn::parallel::Rank &rank) {

auto impl = core::GetDeviceGuardImpl(device.type());

int start_step = 0;
TrainerState state;
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
.rank = rank,
.model = model,
.optimizer = optimizer,
.model_config = model_config,
.state = state,
.load_optimizer_state = false});
start_step = resume_result.global_step;
size_t consumed_batches = resume_result.consumed_batches;

// TODO(jym): Replace with Sampler abstraction when available.
// Skip dataloader to resume from the correct batch position.
if (consumed_batches > 0) {
size_t start = train_iter.BatchIndex();
// Each rank processes every ddp_world_size-th batch starting from its own rank.
// num_skips calculates how many ++ iterations to reach the saved batch position.
size_t num_skips = (consumed_batches - start) / ddp_world_size;
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
}

auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) {
SaveCheckpoint({
.save_dir = save_dir,
.global_step = global_step,
.consumed_batches = consumed_batches,
.last_lr = FLAGS_learning_rate,
.n_layer = model_config.n_layer,
.n_head = model_config.n_head,
.n_kv_head = model_config.n_kv_head,
.n_embd = model_config.n_embd,
.vocab_size = model_config.vocab_size,
.ddp_size = ddp_world_size,
.tp_size = tp_world_size,
.sp_size = sp_world_size,
.pp_size = pp_world_size,
.save_optimizer_state = FLAGS_save_optimizer_state,
.checkpoint_root_dir = FLAGS_save,
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
.rank = rank,
.model = *model,
.optimizer = *optimizer,
});
};

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

Expand Down Expand Up @@ -367,6 +422,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -397,6 +453,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -431,6 +488,15 @@ void Train(const nn::parallel::Rank &rank) {
}
}
}

if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) {
std::filesystem::path step_dir
= std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1);
if (rank.IsParallel()) {
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(step_dir, step + 1);
}
}

// Save LoRA weights if enabled and path specified
Expand All @@ -439,6 +505,12 @@ void Train(const nn::parallel::Rank &rank) {
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
}

std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final";
if (rank.IsParallel()) {
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(final_dir, FLAGS_num_iteration);

#ifdef PROFILE_MODE
Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("gpt2.records.log");
Expand Down
5 changes: 3 additions & 2 deletions example/llama3/checkpoint_loader.cc
Comment thread
kilinchange marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/llama3/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
#include "infini_train/include/nn/modules/transformer/mlp.h"
Expand All @@ -22,6 +20,9 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/llama3/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

Expand Down
76 changes: 75 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cstdlib>
#include <filesystem>
#include <format>
#include <memory>
#include <optional>
Expand All @@ -8,6 +9,8 @@
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/checkpoint/checkpoint.h"
#include "infini_train/include/checkpoint/checkpoint_manager.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
Expand Down Expand Up @@ -38,6 +41,7 @@
#include "example/llama3/checkpoint_loader.h"
#include "example/llama3/config.h"

// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability.
// I/O
DEFINE_string(input_bin, "", "input .bin to train on");
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
Expand Down Expand Up @@ -75,6 +79,12 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving");
DEFINE_string(load, "", "checkpoint directory to resume from");
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");

// precision check
DEFINE_string(
precision_check, "",
Expand Down Expand Up @@ -293,7 +303,54 @@ void Train(const nn::parallel::Rank &rank) {

auto impl = core::GetDeviceGuardImpl(device.type());

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
int start_step = 0;
TrainerState state;
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
.rank = rank,
.model = model,
.optimizer = optimizer,
.model_config = model_config,
.state = state,
.load_optimizer_state = true});

start_step = resume_result.global_step;
size_t consumed_batches = resume_result.consumed_batches;

// TODO(jym): Replace with Sampler abstraction when available.
// Skip dataloader to resume from the correct batch position.
if (consumed_batches > 0) {
size_t start = train_iter.BatchIndex();
// Each rank processes every ddp_world_size-th batch starting from its own rank.
// num_skips calculates how many ++ iterations to reach the saved batch position.
size_t num_skips = (consumed_batches - start) / ddp_world_size;
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
}

auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) {
SaveCheckpoint({
.save_dir = save_dir,
.global_step = global_step,
.consumed_batches = consumed_batches,
.last_lr = FLAGS_learning_rate,
.n_layer = model_config.n_layer,
.n_head = model_config.n_head,
.n_kv_head = model_config.n_kv_head,
.n_embd = model_config.n_embd,
.vocab_size = model_config.vocab_size,
.ddp_size = ddp_world_size,
.tp_size = tp_world_size,
.sp_size = sp_world_size,
.pp_size = pp_world_size,
.save_optimizer_state = FLAGS_save_optimizer_state,
.checkpoint_root_dir = FLAGS_save,
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
.rank = rank,
.model = *model,
.optimizer = *optimizer,
});
};

for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

Expand Down Expand Up @@ -343,6 +400,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -372,6 +430,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -406,6 +465,15 @@ void Train(const nn::parallel::Rank &rank) {
}
}
}

if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) {
std::filesystem::path step_dir
= std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1);
if (rank.IsParallel()) {
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(step_dir, step + 1);
}
}

// Save LoRA weights if enabled and path specified
Expand All @@ -414,6 +482,12 @@ void Train(const nn::parallel::Rank &rank) {
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
}

std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final";
if (rank.IsParallel()) {
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(final_dir, FLAGS_num_iteration);

#ifdef PROFILE_MODE
Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("llama3.records.log");
Expand Down
52 changes: 52 additions & 0 deletions infini_train/include/checkpoint/checkpoint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include <cstdint>
#include <filesystem>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

namespace infini_train {
class Optimizer;
class Tensor;
namespace nn {
class Module;
}

// Plain aggregate describing trainer progress, model dimensions and parallel layout.
// It is passed by const reference to Checkpoint::Save and by mutable reference to Checkpoint::Load.
// Every field has an in-class default, so a default-constructed TrainerState is fully initialized.
struct TrainerState {
// Progress counters; both default to 0.
// NOTE(review): presumably the step index and dataloader batch position to resume from — confirm against the
// trainer that fills this struct.
int64_t global_step = 0;
int64_t consumed_batches = 0;
// FIXME(jym): learning_rate should be restored from scheduler state, move `last_lr` from TrainerState to
// SchedulerState later
double last_lr = 0.0;
Comment thread
kilinchange marked this conversation as resolved.
// Model dimension fields; all default to 0.
// NOTE(review): presumably mirror the model config so a loaded checkpoint can be checked against the current
// model — confirm where/if they are validated.
int64_t n_layer = 0;
int64_t n_head = 0;
int64_t n_kv_head = 0;
int64_t n_embd = 0;
int64_t vocab_size = 0;
// Parallel group sizes; all default to 1.
// NOTE(review): names suggest data / tensor / sequence / pipeline parallel world sizes — confirm.
int ddp_size = 1;
int tp_size = 1;
int sp_size = 1;
int pp_size = 1;
};

// Static-only interface for writing and reading a checkpoint rooted at a directory.
// All members are static; the class carries no state of its own in this header.
class Checkpoint {
public:
// Saves a checkpoint into `checkpoint_dir`.
//   checkpoint_dir       - target directory for this checkpoint.
//   model                - module to save (read-only).
//   optimizer            - optimizer to save, passed as a pointer.
//                          NOTE(review): whether nullptr is accepted is not visible here — confirm in the .cc.
//   state                - trainer state to save (read-only).
//   save_optimizer_state - flag controlling optimizer-state persistence.
//                          NOTE(review): exact on-disk layout and file names are defined in the implementation.
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
const TrainerState &state, bool save_optimizer_state);

// Loads a checkpoint from `checkpoint_dir`.
//   checkpoint_dir       - directory to read the checkpoint from.
//   model                - module taken by mutable reference (presumably updated in place with loaded weights).
//   optimizer            - optimizer taken by mutable pointer.
//                          NOTE(review): nullptr handling is not visible here — confirm in the .cc.
//   state                - taken by mutable reference (presumably filled with the stored trainer state).
//   load_optimizer_state - flag controlling whether optimizer state is restored.
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
TrainerState &state, bool load_optimizer_state);

private:
// Writes a name -> tensor map to `path`.
static void SaveStateDict(const std::filesystem::path &path,
const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);

// Reads `path` and returns a name -> tensor map.
static std::unordered_map<std::string, std::shared_ptr<Tensor>> LoadStateDict(const std::filesystem::path &path);

// Write / read a TrainerState at `path`; LoadTrainerState returns the state by value.
static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state);
static TrainerState LoadTrainerState(const std::filesystem::path &path);
};

} // namespace infini_train
Loading
Loading