[go: nahoru, domu]

Skip to content

Commit

Permalink
Augments the problem formulation to limit the # of "departures" from …
Browse files Browse the repository at this point in the history
…the trivial sharding strategy.

PiperOrigin-RevId: 562823026
  • Loading branch information
tensorflower-gardener committed Sep 5, 2023
1 parent e0dae96 commit 8ebd0f8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2348,16 +2348,22 @@ AutoShardingSolverResult CallSolver(
const StrategyVector* strategies = leaf_strategies[i];
request.instruction_names.push_back(absl::StrCat(
instructions.at(strategies->instruction_id)->name(), " (id: ", i, ")"));
std::vector<double> ci, di, mi;
std::vector<double> ci, di, mi, pi;
for (NodeStrategyIdx j = 0; j < strategies->leaf_vector.size(); ++j) {
ci.push_back(strategies->leaf_vector[j].compute_cost);
di.push_back(strategies->leaf_vector[j].communication_cost +
const ShardingStrategy& strategy = strategies->leaf_vector[j];
const HloSharding& sharding = strategy.output_sharding;
ci.push_back(strategy.compute_cost);
di.push_back(strategy.communication_cost +
cost_graph.extra_node_costs_[i][j]);
mi.push_back(strategies->leaf_vector[j].memory_cost);
mi.push_back(strategy.memory_cost);
// TODO(moffitt): Revisit the default strategy below, which is currently
// defined as the "trivial sharding" in hlo_sharding.h
pi.push_back(sharding.IsReplicated() && !sharding.IsManual() ? 0.0 : 1.0);
}
request.c.push_back(ci);
request.d.push_back(di);
request.m.push_back(mi);
request.p.push_back(pi);
}

// Serialize special edges that force an alias pair to have the same sharding
Expand Down Expand Up @@ -2466,6 +2472,7 @@ AutoShardingSolverResult CallSolver(
<< ")";
LOG(INFO) << "Total Cost: " << evaluation.total.cost()
<< " (lower bound: " << evaluation.lower_bound.cost() << ")";
LOG(INFO) << "Total Departures: " << evaluation.total_departures;
LOG(INFO) << "Total Violations: " << evaluation.violation_codes.size();
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,18 @@ AutoShardingSolverResult CallORToolsSolver(
}
}
}
// When the request carries a departure cap, add one row constraint whose
// value is bounded to the range [0, *request.max_departures]. Each strategy
// variable s[i][j] contributes its departure weight request.p[i][j].
if (request.max_departures) {
MPConstraint* constraint = solver->MakeRowConstraint(
0, *request.max_departures,
absl::StrCat("departures <= ", *request.max_departures));
for (NodeIdx i = 0; i < request.num_nodes; ++i) {
for (NodeStrategyIdx j = 0; j < s[i].size(); ++j) {
// Add onto the coefficient the variable already has instead of overwriting
// it. NOTE(review): presumably the same variable can be reached through
// more than one (i, j) pair -- confirm against how s is populated.
double accumulated_coefficient = constraint->GetCoefficient(s[i][j]);
constraint->SetCoefficient(s[i][j],
accumulated_coefficient + request.p[i][j]);
}
}
}

if (!request.s_hint.empty()) {
std::vector<std::pair<const MPVariable*, double>> hint;
Expand Down Expand Up @@ -578,7 +590,8 @@ double CostComponents::cost() const {
bool AutoShardingEvaluation::operator==(
const AutoShardingEvaluation& other) const {
return violation_codes == other.violation_codes && total == other.total &&
lower_bound == other.lower_bound;
lower_bound == other.lower_bound &&
total_departures == other.total_departures;
}

AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
Expand Down Expand Up @@ -609,6 +622,13 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
evaluation.violation_codes.insert(kInfiniteCostViolationCode);
}
}
// Sum the departure weight request.p[i][s_val[i]] over all nodes into
// evaluation.total_departures.
for (NodeIdx i = 0; i < request.num_nodes; ++i) {
evaluation.total_departures += request.p[i][s_val[i]];
// The cap is optional; when present it is compared against the running
// total on every iteration, so the violation code is recorded as soon as
// the partial sum exceeds *request.max_departures.
if (request.max_departures &&
evaluation.total_departures > *request.max_departures) {
evaluation.violation_codes.insert(kMaxDeparturesViolationCode);
}
}
if (request.memory_budget > 0) {
double total_overbudget = 0.0;
double lower_bound_overbudget = 0.0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ struct AutoShardingSolverRequest {
std::vector<std::vector<double>> c;
std::vector<std::vector<double>> d;
std::vector<std::vector<double>> m;
std::vector<std::vector<double>> p;
std::vector<std::vector<double>> r;
std::vector<std::pair<NodeIdx, NodeIdx>> a;
std::vector<std::vector<double>> v;
std::vector<std::string> instruction_names;
std::optional<int64_t> solver_timeout_in_seconds;
std::optional<double> overbudget_coeff;
std::optional<double> max_departures;
bool crash_at_infinity_costs_check = false;
bool compute_iis = true;
double saltiplier = 0.0001; // Modifies each objective term by at most 0.01%
Expand All @@ -72,8 +74,9 @@ AutoShardingSolverResult CallORToolsSolver(
enum AutoShardingViolationCode {
kAliasViolationCode, // Some node's strategy does not match its alias
kFollowerViolationCode, // Some node's strategy does not match its follower
kInfiniteCostViolationCode, // Some node or edge incurs infinite cost
kMemoryViolationCode, // The solution eclipses the memory budget
kInfiniteCostViolationCode, // Some node or edge incurs infinite cost
kMemoryViolationCode, // The solution eclipses the memory budget
kMaxDeparturesViolationCode, // The solution has too many sharding departures
};

struct CostComponents {
Expand All @@ -97,6 +100,9 @@ struct AutoShardingEvaluation {
CostComponents total;
CostComponents lower_bound;

// How many instructions departed from the "default" sharding strategy.
double total_departures = 0.0;

bool operator==(const AutoShardingEvaluation& other) const;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() {
{300000, 310000, 320000, 330000},
{400000, 410000, 420000, 430000},
{500000, 510000, 520000}};
request.p = {{1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0}};
request.r = {{1000, 1100, 1200, 1300,
2000, 2100, 2200, 2300,
3000, 3100, 3200, 3300,
Expand Down Expand Up @@ -99,6 +104,21 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) {
EXPECT_EQ(result, expected_result);
}

// Checks that CallORToolsSolver respects a cap on sharding departures: the
// default request is solved with max_departures set to 3.0 and the result is
// compared against a fixed expected assignment and objective value.
TEST(CallORToolsSolverTest, SolvesMaxDepartures) {
  AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
  request.max_departures = 3.0;

  const AutoShardingSolverResult result = CallORToolsSolver(request);

  // Under the default request's p matrix, strategy index 1 is the only one
  // with weight 0.0, so this assignment accumulates exactly 3 departures
  // (nodes 0, 1 and 4).
  const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
  const std::vector<EdgeStrategyIdx> e_val = {1, 1};
  const double objective_value = 7872.0;
  // The expected vectors are const, so they are handed over by copy. Wrapping
  // them in std::move would still copy and only obscure that fact.
  const AutoShardingSolverResult expected_result = {
      std::make_tuple(s_val, e_val, objective_value), false};
  EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.c[0][0] = request.c[0][1] = request.c[0][2] = kInfinityCost;
Expand Down Expand Up @@ -180,6 +200,7 @@ TEST(AutoShardingEvaluatorTest, NoViolations) {
expected_evaluation.lower_bound.computation_cost = 150.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 3.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -205,6 +226,7 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) {
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.lower_bound.overbudget_cost = 9000000.0;
expected_evaluation.total_departures = 3.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -227,6 +249,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) {
expected_evaluation.lower_bound.computation_cost = 150.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 2.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -249,6 +272,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) {
expected_evaluation.lower_bound.computation_cost = 150.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 4.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -271,6 +295,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) {
expected_evaluation.lower_bound.computation_cost = 150.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 3.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -294,6 +319,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) {
expected_evaluation.lower_bound.computation_cost = 153.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 3.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

Expand All @@ -317,6 +343,31 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) {
expected_evaluation.lower_bound.computation_cost = 150.0;
expected_evaluation.lower_bound.communication_cost = 1500.0;
expected_evaluation.lower_bound.resharding_cost = 6000.0;
expected_evaluation.total_departures = 3.0;
EXPECT_EQ(evaluation, expected_evaluation);
}

// Checks that Evaluate reports kMaxDeparturesViolationCode when the supplied
// solution departs from the default strategy more often than the request's
// max_departures allows.
TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) {
  AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
  request.max_departures = 2.0;
  // Under the default request's p matrix this assignment has weight 1.0 at
  // nodes 0, 2 and 3, i.e. 3 departures, which exceeds the cap of 2.
  const std::vector<NodeStrategyIdx> s_val = {3, 1, 2, 2, 1};
  const std::vector<EdgeStrategyIdx> e_val = {14, 6};
  const double objective_value = 12149.0;
  // The vectors are const, so they are handed over by copy. Wrapping them in
  // std::move would still copy and only obscure that fact.
  const AutoShardingSolverResult result = {
      std::make_tuple(s_val, e_val, objective_value), false};

  const AutoShardingEvaluation evaluation = Evaluate(request, result);

  AutoShardingEvaluation expected_evaluation;
  expected_evaluation.violation_codes = {kMaxDeparturesViolationCode};
  expected_evaluation.total.computation_cost = 159.0;  // 13+21+32+42+51
  expected_evaluation.total.communication_cost = 1590.0;  // 130+210+320+420+510
  expected_evaluation.total.resharding_cost = 10400.0;  // 4200+6200
  expected_evaluation.lower_bound.computation_cost = 150.0;
  expected_evaluation.lower_bound.communication_cost = 1500.0;
  expected_evaluation.lower_bound.resharding_cost = 6000.0;
  expected_evaluation.total_departures = 3.0;
  EXPECT_EQ(evaluation, expected_evaluation);
}

Expand Down

0 comments on commit 8ebd0f8

Please sign in to comment.