Last active
April 16, 2025 09:25
-
-
Save reyoung/ff81135c10ffa2b98850413205231ea5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include <variant>
#include <vector>
// Shape of a parameter: n_ equally sized slices of rows_ rows each.
// Both members have initializers so a default-constructed Shape is never indeterminate.
struct Shape {
    int n_{1};     // number of slices; 1 means the parameter is placed whole in one bucket
    int rows_{0};  // rows per slice
};
| struct Parameter { | |
| std::string name_; | |
| Shape shape_; | |
| }; | |
| struct ParameterSlice { | |
| const Parameter *param_; | |
| int begin_; | |
| int end_; | |
| }; | |
// A free (unallocated) gap of rows_ rows inside a bucket.
struct Pad {
    int rows_{0};
};
| using Entry = std::variant<ParameterSlice, Pad>; | |
| static int free_space(const Entry &e) { | |
| return std::visit([](auto &r) -> int { | |
| using T = std::decay_t<decltype(r)>; | |
| if constexpr (std::is_same_v<T, ParameterSlice>) { | |
| return 0; | |
| } else { | |
| return r.rows_; | |
| } | |
| }, e); | |
| } | |
| struct Bucket { | |
| std::list<Entry> entries_; | |
| explicit Bucket(int size) { | |
| entries_.emplace_back(Pad{.rows_ = size}); | |
| } | |
| [[nodiscard]] int left_free_space() const { | |
| return free_space(entries_.front()); | |
| } | |
| [[nodiscard]] int right_free_space() const { | |
| return free_space(entries_.back()); | |
| } | |
| [[nodiscard]] int max_free_space() const { | |
| return std::accumulate(entries_.begin(), entries_.end(), 0, [](int acc, const Entry &e) { | |
| return std::max(acc, free_space(e)); | |
| }); | |
| } | |
| [[nodiscard]] bool all_free() const { | |
| auto it = entries_.begin(); | |
| ++it; | |
| return it == this->entries_.end() && std::holds_alternative<Pad>(entries_.front()); | |
| } | |
| void allocate_right_slice(const Parameter ¶m, int n) { | |
| auto &b = entries_.back(); | |
| if (!std::holds_alternative<Pad>(b)) { | |
| std::cerr << "last entry is not Pad" << std::endl; | |
| exit(1); | |
| } | |
| auto &p = std::get<Pad>(b); | |
| auto pad_it = entries_.end(); | |
| --pad_it; | |
| p.rows_ -= n * param.shape_.rows_; | |
| entries_.emplace_back(ParameterSlice{ | |
| .param_ = ¶m, | |
| .begin_ = 0, | |
| .end_ = n | |
| }); | |
| if (p.rows_ == 0) { | |
| entries_.erase(pad_it); | |
| } | |
| } | |
| void allocate_left_slice(const Parameter ¶m, int n) { | |
| auto &b = entries_.front(); | |
| if (!std::holds_alternative<Pad>(b)) { | |
| std::cerr << "first entry is not Pad" << std::endl; | |
| exit(1); | |
| } | |
| auto &p = std::get<Pad>(b); | |
| auto pad_it = entries_.begin(); | |
| p.rows_ -= n * param.shape_.rows_; | |
| entries_.emplace_front(ParameterSlice{ | |
| .param_ = ¶m, | |
| .begin_ = param.shape_.n_ - n, | |
| .end_ = param.shape_.n_ | |
| }); | |
| if (p.rows_ == 0) { | |
| entries_.erase(pad_it); | |
| } | |
| } | |
| void allocate_full(const Parameter ¶m, int begin, int end) { | |
| if (!all_free()) { | |
| std::cerr << "bucket is not all free" << std::endl; | |
| exit(1); | |
| } | |
| auto &b = entries_.front(); | |
| b = ParameterSlice{ | |
| .param_ = ¶m, | |
| .begin_ = begin, | |
| .end_ = end | |
| }; | |
| } | |
| void allocate(const Parameter ¶m) { | |
| auto r = param.shape_.rows_ * param.shape_.n_; | |
| auto candidate_pad_it = entries_.end(); | |
| int free_space = std::numeric_limits<int>::max(); | |
| for (auto it = entries_.begin(); it != entries_.end(); ++it) { | |
| const Entry &e = (*it); | |
| if (!std::holds_alternative<Pad>(e)) { | |
| continue; | |
| } | |
| auto &p = std::get<Pad>(e); | |
| if (p.rows_ < free_space && p.rows_ >= r) { | |
| candidate_pad_it = it; | |
| free_space = p.rows_; | |
| } | |
| } | |
| if (candidate_pad_it == entries_.end()) { | |
| std::cerr << "No free space for parameter: " << param.name_ << std::endl; | |
| exit(1); | |
| } | |
| entries_.insert(candidate_pad_it, ParameterSlice{ | |
| .param_ = ¶m, | |
| .begin_ = 0, | |
| .end_ = 1 | |
| }); | |
| if (free_space == param.shape_.rows_) { | |
| entries_.erase(candidate_pad_it); | |
| } else { | |
| auto &p = std::get<Pad>(*candidate_pad_it); | |
| p.rows_ -= param.shape_.rows_; | |
| } | |
| } | |
| }; | |
| struct Buckets { | |
| Buckets(int n, int size) : max_possible_size_(size) { | |
| buckets_.reserve(n); | |
| for (int i = 0; i < n; ++i) { | |
| buckets_.emplace_back(size); | |
| } | |
| } | |
| int allocate_parameters(const Parameter ¶m) { | |
| if (param.shape_.n_ == 1) { | |
| // best fit buckets_ | |
| auto [b_id, free_space] = first_free_space_bucket(param.shape_.rows_); | |
| if (free_space < param.shape_.rows_) { | |
| return param.shape_.rows_ - free_space; | |
| } | |
| auto &bucket = buckets_[b_id]; | |
| bucket.allocate(param); | |
| return 0; | |
| } | |
| if (max_possible_size_ % param.shape_.rows_ != 0) { | |
| std::cerr << "max_possible_size_ % n != 0" << std::endl; | |
| exit(1); | |
| } | |
| int n_per_bucket = max_possible_size_ / param.shape_.rows_; | |
| // 如果一个bucket可以放下所有的slice | |
| if (n_per_bucket >= param.shape_.n_) { | |
| auto [b_id, free_space] = first_free_space_bucket(param.shape_.rows_ * param.shape_.n_); | |
| if (free_space >= param.shape_.rows_) { | |
| auto &bucket = buckets_[b_id]; | |
| bucket.allocate(param); | |
| return 0; | |
| } | |
| } | |
| int n_free_blocks = param.shape_.n_ / n_per_bucket; | |
| int n_reminder = param.shape_.n_ % n_per_bucket; | |
| int begin = 0; | |
| while (true) { | |
| int free_begin = find_n_contiguous_free_blocks(begin, n_free_blocks); | |
| if (free_begin < 0) { | |
| return param.shape_.rows_; | |
| } | |
| int left_n = 0; | |
| int right_n = 0; | |
| // can n reminder fit prev/next bucket? | |
| if (!can_reminder_fit(n_reminder, param.shape_.rows_, free_begin - 1, free_begin + n_free_blocks, | |
| left_n, right_n)) { | |
| begin = free_begin + 1; | |
| // search next n contiguous free blocks | |
| continue; | |
| } | |
| // insert left n slices | |
| int slice_begin = 0; | |
| if (left_n) { | |
| auto &b = buckets_[free_begin - 1]; | |
| b.allocate_right_slice(param, left_n); | |
| slice_begin += left_n; | |
| } | |
| for (int i = 0; i < n_free_blocks; ++i) { | |
| auto &b = buckets_.at(free_begin + i); | |
| b.allocate_full(param, slice_begin, slice_begin + n_per_bucket); | |
| slice_begin += n_per_bucket; | |
| } | |
| if (right_n) { | |
| auto &b = buckets_.at(free_begin + n_free_blocks); | |
| b.allocate_left_slice(param, right_n); | |
| } | |
| return 0; | |
| } | |
| } | |
| [[nodiscard]] bool can_reminder_fit(int reminder, int rows, int left, int right, int &left_n, int &right_n) { | |
| if (reminder == 0) { | |
| left_n = 0; | |
| right_n = 0; | |
| return true; | |
| } | |
| if (left < 0) { | |
| left_n = 0; | |
| // must fit all in right | |
| if (right >= static_cast<int>(buckets_.size())) { | |
| return false; | |
| } | |
| auto &bucket = buckets_[right]; | |
| if (bucket.left_free_space() >= reminder * rows) { | |
| right_n = reminder; | |
| return true; | |
| } else { | |
| return false; | |
| } | |
| } | |
| if (right >= static_cast<int>(buckets_.size())) { | |
| right_n = 0; | |
| // must fit all in left | |
| auto &bucket = buckets_[left]; | |
| if (bucket.right_free_space() >= reminder * rows) { | |
| left_n = reminder; | |
| return true; | |
| } else { | |
| return false; | |
| } | |
| } | |
| int left_free_space = buckets_[left].right_free_space(); | |
| int right_free_space = buckets_[right].left_free_space(); | |
| for (int i = 0; i < reminder; ++i) { | |
| left_n = reminder - i; | |
| right_n = i; | |
| if (left_n * rows >= left_free_space) { | |
| continue; | |
| } | |
| if (right_n * rows >= right_free_space) { | |
| continue; | |
| } | |
| return true; | |
| } | |
| return false; | |
| } | |
| // find n contiguous free blocks | |
| // return the begin index of the first block | |
| // if not found, return -1 | |
| [[nodiscard]] | |
| int find_n_contiguous_free_blocks(int from, int n) const { | |
| int last_non_free = from - 1; | |
| for (int i = from; i < static_cast<int>(buckets_.size()); ++i) { | |
| if (!buckets_[i].all_free()) { | |
| last_non_free = i; | |
| continue; | |
| } | |
| if (i - last_non_free == n) { | |
| return ++last_non_free; | |
| } | |
| } | |
| return -1; | |
| } | |
| [[nodiscard]] std::tuple<size_t, int> first_free_space_bucket(int limit) const { | |
| for (size_t i = 0; i < buckets_.size(); ++i) { | |
| int cur_free_space = buckets_[i].max_free_space(); | |
| if (cur_free_space >= limit) { | |
| return {i, cur_free_space}; | |
| } | |
| } | |
| return {0, 0}; | |
| } | |
| std::vector<Bucket> buckets_; | |
| int max_possible_size_; | |
| }; | |
// Least common multiple of a and b (0 if either is 0).
// std::lcm divides by the gcd before multiplying, so unlike (a * b) / gcd(a, b) it does
// not overflow when a * b exceeds INT_MAX but the result does not, and it does not
// divide by zero for (0, 0).
inline int least_common_multiply(int a, int b) {
    return std::lcm(a, b);
}
| std::ostream &operator<<(std::ostream &os, const Buckets &b) { | |
| int total = b.max_possible_size_ * b.buckets_.size(); | |
| int total_pad = 0; | |
| for (auto &buck : b.buckets_) { | |
| total_pad = std::accumulate(buck.entries_.begin(), buck.entries_.end(), total_pad, [](int acc, const Entry &e) { | |
| return acc + free_space(e); | |
| }); | |
| } | |
| os << "pad " << total_pad << ", rate " << float(total_pad) / total; | |
| return os; | |
| } | |
| int main() { | |
| std::vector<Parameter> parameters; | |
| for (int layer_id = 0; layer_id < 32; ++layer_id) { | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "router_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .rows_ = 64 | |
| } | |
| }); | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "kv_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .n_ = 2, | |
| .rows_ = 512, | |
| } | |
| }); | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "experts.fc2_0_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .n_ = 64, | |
| .rows_ = 1536, | |
| } | |
| }); | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "q_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .rows_ = 2560, | |
| } | |
| }); | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "out_proj_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .rows_ = 2560, | |
| } | |
| }); | |
| parameters.emplace_back(Parameter{ | |
| .name_ = "experts.fc2_0_" + std::to_string(layer_id), | |
| .shape_ = Shape{ | |
| .n_ = 64, | |
| .rows_ = 6144, | |
| } | |
| }); | |
| } | |
| constexpr static int dp_size = 256; | |
| int lcm = std::accumulate(parameters.begin(), | |
| parameters.end(), | |
| 1, | |
| [](int acc, const Parameter &p) { | |
| if (p.shape_.n_ == 1) { | |
| return acc; | |
| } | |
| return least_common_multiply(acc, p.shape_.rows_); | |
| }); | |
| int min_rows_per_bucket = lcm; | |
| std::cerr << "min_rows_per_bucket: " << min_rows_per_bucket << std::endl; | |
| while (true) { | |
| Buckets b(dp_size, min_rows_per_bucket); | |
| bool error = false; | |
| for (auto &p : parameters) { | |
| int stride = b.allocate_parameters(p); | |
| if (stride == 0) { | |
| continue; | |
| } | |
| std::cerr << "cannot allocate " << stride << " " << p.name_ << std::endl; | |
| min_rows_per_bucket += least_common_multiply(lcm, stride); | |
| std::cerr << "increase min_rows_per_bucket to " << min_rows_per_bucket << std::endl; | |
| error = true; | |
| break; | |
| } | |
| if (!error) { | |
| std::cerr << b << std::endl; | |
| break; | |
| } | |
| } | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment