[XLA] Enable HloEvaluator for constant folding, also merged a few operations from hlo_constant_folding to hlo_evaluator.

Additionally:
- In ShapeUtil::ForEachIndex:
    * Fixed a bug where the visitor was invoked even when the shape has zero elements (e.g., F32{1,0}); see the first sketch after this list.
    * Added a test case for ForEachIndex.

- In HloEvaluator:
    * Instead of copying and caching a Constant instruction, return the literal directly if the instruction is a constant.
    * Fix an issue where the TUPLE and OPAQUE primitives are not keyed in the templated typed_visitor.
    * Use the (fixed) LiteralUtil::Populate to populate the resulting literal, fixing the preexisting bug in the evaluator where R0 and shapes with zero-size dimensions were not handled.
    * Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the operand's type.
    * Refactor IsFinite into a top-level function, since it is only applicable to floats and its return type is always boolean.
    * Change from std::remainder to std::fmod for kRemainder, to be compliant with existing XLA behavior.
    * Change from std::max and std::min to std::fmax and std::fmin to handle NaNs; see the second sketch after this list.
    * Minor comment fixes.

- Disables constant_folding and reshape-motion for ClientLibraryTestBase so that constant folding does not affect the code paths intended to be exercised by the tests. In the longer term we plan to change all Constants to Parameters and re-enable constant_folding in tests.
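First sketch: an illustrative, standalone model of the ForEachIndex fix described above, not the actual ShapeUtil implementation (the name ForEachIndexSketch is hypothetical). A shape with any zero-sized dimension contains no elements, so the visitor must never fire.

#include <cstdint>
#include <functional>
#include <vector>

// Calls `visit` once per element index of a dense shape given by `dims`.
// The early return is the fix being modeled: F32{1,0} has zero elements,
// so the visitor is never invoked.
void ForEachIndexSketch(
    const std::vector<int64_t>& dims,
    const std::function<void(const std::vector<int64_t>&)>& visit) {
  for (int64_t d : dims) {
    if (d == 0) return;  // Zero-element shape: nothing to visit.
  }
  std::vector<int64_t> index(dims.size(), 0);
  for (;;) {
    visit(index);
    // Advance the index like an odometer, minor dimension first.
    size_t i = 0;
    while (i < dims.size() && ++index[i] == dims[i]) {
      index[i++] = 0;
    }
    if (i == dims.size()) return;  // Wrapped around: done.
  }
}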
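Second sketch: a minimal standalone demo of the two numeric swaps listed above. std::fmod truncates the quotient so the result keeps the dividend's sign, while std::remainder rounds the quotient to nearest; std::fmax/std::fmin return the non-NaN operand, while std::max/std::min propagate a NaN depending on argument order.

#include <algorithm>
#include <cmath>
#include <iostream>

int main() {
  // kRemainder: std::fmod truncates the quotient, so the result keeps
  // the sign of the dividend; std::remainder rounds the quotient to
  // nearest, which can flip the sign.
  std::cout << std::fmod(5.5, 2.0) << "\n";       // 1.5
  std::cout << std::remainder(5.5, 2.0) << "\n";  // -0.5

  // std::max(a, b) is (a < b) ? b : a, and every comparison with NaN is
  // false, so the result depends on argument order; std::fmax returns
  // the non-NaN operand regardless.
  const double nan = std::nan("");
  std::cout << std::max(nan, 1.0) << "\n";   // nan
  std::cout << std::max(1.0, nan) << "\n";   // 1
  std::cout << std::fmax(nan, 1.0) << "\n";  // 1
  return 0;
}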

PiperOrigin-RevId: 157174708
kayzhu authored and tensorflower-gardener committed May 26, 2017
1 parent 2b546d8 commit 405f70c
Showing 12 changed files with 568 additions and 368 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/literal_util.h
@@ -1071,6 +1071,7 @@ template <typename NativeT>
    ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions,
                            stride_config.step, init_function);
  } else {
    // For scalars.
    data.at(0) = generator({});
  }
  return Status::OK();
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/literal_util_test.cc
@@ -807,7 +807,9 @@ TEST_F(LiteralUtilTest, Populate) {
    std::vector<int64> layout;
  } populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
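The two new rows, {0} and {2, 0}, cover shapes with no elements. A hypothetical standalone model (not XLA's API) of the invariant they pin down:

#include <cstdint>
#include <numeric>
#include <vector>

// Element count of a dense shape: the product of its dimensions. An
// empty dims list is a scalar (one element); any zero dimension makes
// the shape empty.
int64_t NumElements(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         [](int64_t a, int64_t b) { return a * b; });
}

// NumElements({}) == 1, NumElements({0}) == 0, NumElements({2, 0}) == 0:
// a populate loop driven by these counts writes exactly once for a
// scalar and never for the zero-element rows added above.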
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -98,10 +98,12 @@ cc_test(
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
@@ -1447,7 +1449,9 @@ cc_library(
    hdrs = ["hlo_constant_folding.h"],
    deps = [
        ":hlo",
        ":hlo_evaluator",
        ":hlo_pass",
        ":hlo_query",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
237 changes: 32 additions & 205 deletions tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -24,230 +24,57 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<Literal> ConvertIfTypesMatch(
    const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  return LiteralUtil::Convert<
      typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
      typename primitive_util::PrimitiveTypeToNative<
          primitive_dest_type>::type>(src_literal);
}

template <PrimitiveType primitive_src_type>
static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
  case (type):                       \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}

static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
  case (type):                             \
    return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}

}  // namespace

// ConstantFolderVisitor traverses the HLO computation and reduces certain
// constant graph sections, to literals.
class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleConcatenate(
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;

  Status HandleConvert(HloInstruction* convert,
                       HloInstruction* operand) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  // Returns whether a constant folding operation has occurred.
  const bool changed() const { return changed_; }

  // Runs the visitor on a computation and returns whether any changes were
  // performed.
  static StatusOr<bool> Run(HloComputation* computation);

 private:
  ConstantFolderVisitor() = default;

  // Replaces the existing HLO instruction old_instruction, with a literal,
  // and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithConstant(HloInstruction* old_instruction,
                             std::unique_ptr<Literal> literal) {
    TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
        old_instruction, HloInstruction::CreateConstant(std::move(literal))));
    changed_ = true;
    return Status::OK();
  }

  // Whether any constant folding operations have occurred.
  bool changed_ = false;
};

StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
  ConstantFolderVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(2,
                 "HloConstantFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto& comp : module->computations()) {
    TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
    changed = changed || result;
  }
  XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
  return changed;
}

Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
  if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(
        auto reshaped_literal,
        LiteralUtil::Reshape(reshape->operand(0)->literal(),
                             AsInt64Slice(reshape->shape().dimensions())));
    return ReplaceWithConstant(reshape, std::move(reshaped_literal));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
  if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
    auto transposed_literal = LiteralUtil::Transpose(
        transpose->operand(0)->literal(), transpose->dimensions());
    return ReplaceWithConstant(transpose, std::move(transposed_literal));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleConcatenate(
    HloInstruction* concatenate,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  if (operands[0]->opcode() == HloOpcode::kConstant) {
    // If all the operands of a concatenate are constant, fold them into a
    // single constant tensor.
    // The result concatenate dimension is going to be the sum of all the
    // concatenate dimensions of the arrays taking part of the operation.
    int64 concat_dim = concatenate->dimensions()[0];
    const Shape& reference_shape = operands[0]->shape();
    CHECK(!ShapeUtil::IsTuple(reference_shape));
    int64 rank = ShapeUtil::Rank(reference_shape);
    std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
                                         reference_shape.dimensions().end());
    if (concat_dim < 0) {
      concat_dim += rank;
    }
    for (int64 i = 1; i < operands.size(); ++i) {
      const Shape& operand_shape = operands[i]->shape();
      CHECK(!ShapeUtil::IsTuple(operand_shape));
      if (operands[i]->opcode() != HloOpcode::kConstant) {
        return Status::OK();
      }
      // Accumulate the concat dimension from all tensors taking part to the
      // operation.
      concat_dimensions[concat_dim] +=
          ShapeUtil::GetDimension(operand_shape, concat_dim);
    }

    auto literal = LiteralUtil::CreateFromDimensions(
        reference_shape.element_type(), concat_dimensions);
    std::vector<int64> source_indices(rank, 0);
    std::vector<int64> dest_indices(concat_dimensions.size(), 0);
    for (auto operand : operands) {
      const Shape& operand_shape = operand->shape();
      TF_RETURN_IF_ERROR(LiteralUtil::Copy(
          operand->literal(), source_indices, literal.get(), dest_indices,
          AsInt64Slice(operand_shape.dimensions())));
      dest_indices[concat_dim] +=
          ShapeUtil::GetDimension(operand_shape, concat_dim);
    }
    return ReplaceWithConstant(concatenate, std::move(literal));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
                                          HloInstruction* operand) {
  if (operand->opcode() == HloOpcode::kConstant) {
    const Shape& shape = slice->shape();
    auto literal = LiteralUtil::CreateFromDimensions(
        shape.element_type(), AsInt64Slice(shape.dimensions()));
    std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
    TF_RETURN_IF_ERROR(LiteralUtil::Copy(
        operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
        AsInt64Slice(shape.dimensions())));
    TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
                                            HloInstruction* operand) {
  if (operand->opcode() == HloOpcode::kConstant) {
    const Literal& src_literal = operand->literal();
    std::unique_ptr<Literal> new_constant =
        ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
    return ReplaceWithConstant(convert, std::move(new_constant));
  }
  return Status::OK();
}

StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
  auto evaluator = MakeUnique<HloEvaluator>();

  XLA_VLOG_LINES(2,
                 "HloConstantFolding::Run(), before:\n" + module->ToString());
  bool changed = false;

  for (auto& computation : module->computations()) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      // Skip dead code.
      if (instruction->user_count() == 0 &&
          computation->root_instruction() != instruction) {
        continue;
      }
      // Skip Constant and Parameter operation.
      if (instruction->opcode() == HloOpcode::kParameter ||
          instruction->opcode() == HloOpcode::kConstant) {
        continue;
      }
      // Skip instructions with non-constant operands.
      if (!hlo_query::AllOperandsAreConstants(*instruction)) {
        continue;
      }

      std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
      // Currently we skip unimplemented operations.
      // TODO(b/35975797): Fold constant computations for more operations.
      if (result == nullptr) {
        VLOG(2) << "Constant folding failed for instruction: "
                << instruction->ToString();
        continue;
      }

      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(std::move(result))));
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
  return changed;
}

} // namespace xla
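A usage sketch of the rewritten pass, assuming only the HloConstantFolding::Run entry point shown in the diff above; the FoldConstants wrapper is hypothetical.

#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical wrapper: fold what the HloEvaluator can evaluate and
// report whether the module changed.
StatusOr<bool> FoldConstants(HloModule* module) {
  HloConstantFolding folding;
  // Run() walks each computation post-order, skipping dead code,
  // Parameters/Constants, and instructions with non-constant operands,
  // and replaces the rest with kConstant literals when evaluation
  // succeeds (see the pass body above).
  TF_ASSIGN_OR_RETURN(bool changed, folding.Run(module));
  return changed;
}

}  // namespace xla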