Skip to content

Instantly share code, notes, and snippets.

@reyoung
Last active April 16, 2025 09:25
Show Gist options
  • Select an option

  • Save reyoung/ff81135c10ffa2b98850413205231ea5 to your computer and use it in GitHub Desktop.

Select an option

Save reyoung/ff81135c10ffa2b98850413205231ea5 to your computer and use it in GitHub Desktop.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include <variant>
#include <vector>
// Shape of a parameter: `n_` slices, each `rows_` rows tall.
// The total footprint used by the allocator is `n_ * rows_` rows.
struct Shape {
    // Number of slices; 1 means the parameter is placed as a single piece.
    int n_{1};
    // Rows per slice. Defaulted so that a default-constructed Shape never
    // carries an indeterminate value into the size arithmetic.
    int rows_{0};
};
// A named parameter together with the shape the allocator has to place.
struct Parameter {
// Name printed in diagnostics when the parameter cannot be placed.
std::string name_;
// Slice count and rows per slice of this parameter.
Shape shape_;
};
// A contiguous run of slices of one parameter stored inside a bucket.
// The run covers slice indices [begin_, end_) of `*param_`.
struct ParameterSlice {
    // Non-owning pointer to the parameter; the parameter must outlive the slice.
    // Defaults to nullptr so a default-constructed slice is recognisably empty.
    const Parameter *param_{nullptr};
    // Index of the first slice in the run.
    int begin_{0};
    // One past the index of the last slice in the run.
    int end_{0};
};
// A run of unused (free) rows inside a bucket.
struct Pad {
    // Number of free rows. Defaulted so `Pad p;` is an empty pad rather than
    // an indeterminate value.
    int rows_{0};
};
// One region inside a bucket: either a run of parameter slices or free padding rows.
using Entry = std::variant<ParameterSlice, Pad>;
static int free_space(const Entry &e) {
return std::visit([](auto &r) -> int {
using T = std::decay_t<decltype(r)>;
if constexpr (std::is_same_v<T, ParameterSlice>) {
return 0;
} else {
return r.rows_;
}
}, e);
}
struct Bucket {
std::list<Entry> entries_;
explicit Bucket(int size) {
entries_.emplace_back(Pad{.rows_ = size});
}
[[nodiscard]] int left_free_space() const {
return free_space(entries_.front());
}
[[nodiscard]] int right_free_space() const {
return free_space(entries_.back());
}
[[nodiscard]] int max_free_space() const {
return std::accumulate(entries_.begin(), entries_.end(), 0, [](int acc, const Entry &e) {
return std::max(acc, free_space(e));
});
}
[[nodiscard]] bool all_free() const {
auto it = entries_.begin();
++it;
return it == this->entries_.end() && std::holds_alternative<Pad>(entries_.front());
}
void allocate_right_slice(const Parameter &param, int n) {
auto &b = entries_.back();
if (!std::holds_alternative<Pad>(b)) {
std::cerr << "last entry is not Pad" << std::endl;
exit(1);
}
auto &p = std::get<Pad>(b);
auto pad_it = entries_.end();
--pad_it;
p.rows_ -= n * param.shape_.rows_;
entries_.emplace_back(ParameterSlice{
.param_ = &param,
.begin_ = 0,
.end_ = n
});
if (p.rows_ == 0) {
entries_.erase(pad_it);
}
}
void allocate_left_slice(const Parameter &param, int n) {
auto &b = entries_.front();
if (!std::holds_alternative<Pad>(b)) {
std::cerr << "first entry is not Pad" << std::endl;
exit(1);
}
auto &p = std::get<Pad>(b);
auto pad_it = entries_.begin();
p.rows_ -= n * param.shape_.rows_;
entries_.emplace_front(ParameterSlice{
.param_ = &param,
.begin_ = param.shape_.n_ - n,
.end_ = param.shape_.n_
});
if (p.rows_ == 0) {
entries_.erase(pad_it);
}
}
void allocate_full(const Parameter &param, int begin, int end) {
if (!all_free()) {
std::cerr << "bucket is not all free" << std::endl;
exit(1);
}
auto &b = entries_.front();
b = ParameterSlice{
.param_ = &param,
.begin_ = begin,
.end_ = end
};
}
void allocate(const Parameter &param) {
auto r = param.shape_.rows_ * param.shape_.n_;
auto candidate_pad_it = entries_.end();
int free_space = std::numeric_limits<int>::max();
for (auto it = entries_.begin(); it != entries_.end(); ++it) {
const Entry &e = (*it);
if (!std::holds_alternative<Pad>(e)) {
continue;
}
auto &p = std::get<Pad>(e);
if (p.rows_ < free_space && p.rows_ >= r) {
candidate_pad_it = it;
free_space = p.rows_;
}
}
if (candidate_pad_it == entries_.end()) {
std::cerr << "No free space for parameter: " << param.name_ << std::endl;
exit(1);
}
entries_.insert(candidate_pad_it, ParameterSlice{
.param_ = &param,
.begin_ = 0,
.end_ = 1
});
if (free_space == param.shape_.rows_) {
entries_.erase(candidate_pad_it);
} else {
auto &p = std::get<Pad>(*candidate_pad_it);
p.rows_ -= param.shape_.rows_;
}
}
};
struct Buckets {
Buckets(int n, int size) : max_possible_size_(size) {
buckets_.reserve(n);
for (int i = 0; i < n; ++i) {
buckets_.emplace_back(size);
}
}
int allocate_parameters(const Parameter &param) {
if (param.shape_.n_ == 1) {
// best fit buckets_
auto [b_id, free_space] = first_free_space_bucket(param.shape_.rows_);
if (free_space < param.shape_.rows_) {
return param.shape_.rows_ - free_space;
}
auto &bucket = buckets_[b_id];
bucket.allocate(param);
return 0;
}
if (max_possible_size_ % param.shape_.rows_ != 0) {
std::cerr << "max_possible_size_ % n != 0" << std::endl;
exit(1);
}
int n_per_bucket = max_possible_size_ / param.shape_.rows_;
// 如果一个bucket可以放下所有的slice
if (n_per_bucket >= param.shape_.n_) {
auto [b_id, free_space] = first_free_space_bucket(param.shape_.rows_ * param.shape_.n_);
if (free_space >= param.shape_.rows_) {
auto &bucket = buckets_[b_id];
bucket.allocate(param);
return 0;
}
}
int n_free_blocks = param.shape_.n_ / n_per_bucket;
int n_reminder = param.shape_.n_ % n_per_bucket;
int begin = 0;
while (true) {
int free_begin = find_n_contiguous_free_blocks(begin, n_free_blocks);
if (free_begin < 0) {
return param.shape_.rows_;
}
int left_n = 0;
int right_n = 0;
// can n reminder fit prev/next bucket?
if (!can_reminder_fit(n_reminder, param.shape_.rows_, free_begin - 1, free_begin + n_free_blocks,
left_n, right_n)) {
begin = free_begin + 1;
// search next n contiguous free blocks
continue;
}
// insert left n slices
int slice_begin = 0;
if (left_n) {
auto &b = buckets_[free_begin - 1];
b.allocate_right_slice(param, left_n);
slice_begin += left_n;
}
for (int i = 0; i < n_free_blocks; ++i) {
auto &b = buckets_.at(free_begin + i);
b.allocate_full(param, slice_begin, slice_begin + n_per_bucket);
slice_begin += n_per_bucket;
}
if (right_n) {
auto &b = buckets_.at(free_begin + n_free_blocks);
b.allocate_left_slice(param, right_n);
}
return 0;
}
}
[[nodiscard]] bool can_reminder_fit(int reminder, int rows, int left, int right, int &left_n, int &right_n) {
if (reminder == 0) {
left_n = 0;
right_n = 0;
return true;
}
if (left < 0) {
left_n = 0;
// must fit all in right
if (right >= static_cast<int>(buckets_.size())) {
return false;
}
auto &bucket = buckets_[right];
if (bucket.left_free_space() >= reminder * rows) {
right_n = reminder;
return true;
} else {
return false;
}
}
if (right >= static_cast<int>(buckets_.size())) {
right_n = 0;
// must fit all in left
auto &bucket = buckets_[left];
if (bucket.right_free_space() >= reminder * rows) {
left_n = reminder;
return true;
} else {
return false;
}
}
int left_free_space = buckets_[left].right_free_space();
int right_free_space = buckets_[right].left_free_space();
for (int i = 0; i < reminder; ++i) {
left_n = reminder - i;
right_n = i;
if (left_n * rows >= left_free_space) {
continue;
}
if (right_n * rows >= right_free_space) {
continue;
}
return true;
}
return false;
}
// find n contiguous free blocks
// return the begin index of the first block
// if not found, return -1
[[nodiscard]]
int find_n_contiguous_free_blocks(int from, int n) const {
int last_non_free = from - 1;
for (int i = from; i < static_cast<int>(buckets_.size()); ++i) {
if (!buckets_[i].all_free()) {
last_non_free = i;
continue;
}
if (i - last_non_free == n) {
return ++last_non_free;
}
}
return -1;
}
[[nodiscard]] std::tuple<size_t, int> first_free_space_bucket(int limit) const {
for (size_t i = 0; i < buckets_.size(); ++i) {
int cur_free_space = buckets_[i].max_free_space();
if (cur_free_space >= limit) {
return {i, cur_free_space};
}
}
return {0, 0};
}
std::vector<Bucket> buckets_;
int max_possible_size_;
};
// Returns the least common multiple of `a` and `b`.
// Delegates to std::lcm, which divides before multiplying (so the
// intermediate product cannot overflow when the result itself fits in int)
// and returns 0 when either argument is 0 instead of dividing by zero.
// The result is always non-negative.
inline int least_common_multiply(int a, int b) {
    return std::lcm(a, b);
}
std::ostream &operator<<(std::ostream &os, const Buckets &b) {
int total = b.max_possible_size_ * b.buckets_.size();
int total_pad = 0;
for (auto &buck : b.buckets_) {
total_pad = std::accumulate(buck.entries_.begin(), buck.entries_.end(), total_pad, [](int acc, const Entry &e) {
return acc + free_space(e);
});
}
os << "pad " << total_pad << ", rate " << float(total_pad) / total;
return os;
}
int main() {
std::vector<Parameter> parameters;
for (int layer_id = 0; layer_id < 32; ++layer_id) {
parameters.emplace_back(Parameter{
.name_ = "router_" + std::to_string(layer_id),
.shape_ = Shape{
.rows_ = 64
}
});
parameters.emplace_back(Parameter{
.name_ = "kv_" + std::to_string(layer_id),
.shape_ = Shape{
.n_ = 2,
.rows_ = 512,
}
});
parameters.emplace_back(Parameter{
.name_ = "experts.fc2_0_" + std::to_string(layer_id),
.shape_ = Shape{
.n_ = 64,
.rows_ = 1536,
}
});
parameters.emplace_back(Parameter{
.name_ = "q_" + std::to_string(layer_id),
.shape_ = Shape{
.rows_ = 2560,
}
});
parameters.emplace_back(Parameter{
.name_ = "out_proj_" + std::to_string(layer_id),
.shape_ = Shape{
.rows_ = 2560,
}
});
parameters.emplace_back(Parameter{
.name_ = "experts.fc2_0_" + std::to_string(layer_id),
.shape_ = Shape{
.n_ = 64,
.rows_ = 6144,
}
});
}
constexpr static int dp_size = 256;
int lcm = std::accumulate(parameters.begin(),
parameters.end(),
1,
[](int acc, const Parameter &p) {
if (p.shape_.n_ == 1) {
return acc;
}
return least_common_multiply(acc, p.shape_.rows_);
});
int min_rows_per_bucket = lcm;
std::cerr << "min_rows_per_bucket: " << min_rows_per_bucket << std::endl;
while (true) {
Buckets b(dp_size, min_rows_per_bucket);
bool error = false;
for (auto &p : parameters) {
int stride = b.allocate_parameters(p);
if (stride == 0) {
continue;
}
std::cerr << "cannot allocate " << stride << " " << p.name_ << std::endl;
min_rows_per_bucket += least_common_multiply(lcm, stride);
std::cerr << "increase min_rows_per_bucket to " << min_rows_per_bucket << std::endl;
error = true;
break;
}
if (!error) {
std::cerr << b << std::endl;
break;
}
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment