Add hard swish implementation. (#5927)
Summary:
Previously, the HardSwish nonlinearity in Glow was supported only for floating-point inputs and was implemented by lowering. This change adds support for both floating-point and quantized inputs: the floating-point version is computed element-wise at run time, while the quantized version is precomputed at compile time using Glow's lookup-table (IntLookupTable) mechanism.
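For reference, the function being implemented is hard_swish(x) = x * relu6(x + 3) / 6. A minimal standalone C++ sketch of these semantics (ours, not the Glow kernel itself):

#include <algorithm>

// Reference semantics: hard_swish(x) = x * relu6(x + 3) / 6,
// where relu6(y) clamps y to the range [0, 6].
float hardSwishRef(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 / 6.0f;
}
// Spot checks: hardSwishRef(-4) == 0 (zero below -3), hardSwishRef(0) == 0,
// hardSwishRef(1.5) == 1.125, and hardSwishRef(4) == 4 (identity above +3).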

Documentation:

Pull Request resolved: #5927

Test Plan: Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md.

Reviewed By: bertmaher

Differential Revision: D35241262

Pulled By: khabinov

fbshipit-source-id: 007ec81eb9f64db07cd263148ecdc15322620047
EpureRares authored and facebook-github-bot committed Apr 1, 2022
1 parent 75e36ab commit b5efe5e
Showing 10 changed files with 95 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/glow/Optimizer/GraphOptimizer/FunctionPasses.def
@@ -73,6 +73,7 @@ FUN_PASS(QuantizeSwish)
FUN_PASS(ConvertFullyConnectedToConvolution)
FUN_PASS(FoldMinMaxToClip)
FUN_PASS(ReplaceZeroScaleFP16QuantNodes)
FUN_PASS(ReplaceQuantizedHardSwishWithLookupTable)
FUN_PASS(FoldExpSumDivIntoSoftmax)
FUN_PASS(RemoveIdentityRelu)
FUN_PASS(RemoveIdentityClip)
9 changes: 9 additions & 0 deletions lib/Backends/CPU/CPUBackend.cpp
@@ -73,6 +73,7 @@ bool CPUBackend::shouldLower(const Node *N) const {
case Kinded::Kind::ReluNodeKind:
case Kinded::Kind::ClipNodeKind:
case Kinded::Kind::LeakyReluNodeKind:
case Kinded::Kind::HardSwishNodeKind:
case Kinded::Kind::FullyConnectedNodeKind:
case Kinded::Kind::ConvolutionNodeKind:
case Kinded::Kind::SparseLengthsSumNodeKind:
@@ -157,6 +158,14 @@ bool CPUBackend::canDoIndexTypeDemotion(
return fromTy == ElemKind::Int64ITy && toTy == ElemKind::Int32ITy;
}

std::unique_ptr<FunctionPassPipeline>
CPUBackend::getOptimizationPipeline() const {
auto pipeline = Backend::getOptimizationPipeline();
pipeline->pushFront(
{FunctionPassID::ReplaceQuantizedHardSwishWithLookupTable});
return pipeline;
}

#if FACEBOOK_INTERNAL
llvm::ArrayRef<llvm::MemoryBufferRef> CPUBackend::getObjectRegistry() const {
return llvm::ArrayRef<llvm::MemoryBufferRef>();
3 changes: 3 additions & 0 deletions lib/Backends/CPU/CPUBackend.h
@@ -82,6 +82,9 @@ class CPUBackend : public LLVMBackend {
createIRGen(const IRFunction *IR,
AllocationsInfo &allocationsInfo) const override;

virtual std::unique_ptr<FunctionPassPipeline>
getOptimizationPipeline() const override;

protected:
virtual std::unique_ptr<CompiledFunction>
createCompiledFunction(std::unique_ptr<GlowJIT> JIT,
4 changes: 4 additions & 0 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -2165,6 +2165,10 @@ void BoundInterpreterFunction::fwdLeakyReluInst(const LeakyReluInst *) {
DCHECK(!"Found LeakyReluInst but LeakyRelu is lowered on Interpreter");
}

void BoundInterpreterFunction::fwdHardSwishInst(const HardSwishInst *) {
DCHECK(!"Found HardSwishInst but HardSwish is lowered on Interpreter");
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) {
staticAssertFloatingPointType(ElemTy);
9 changes: 8 additions & 1 deletion lib/Graph/Nodes.cpp
@@ -647,7 +647,14 @@ static bool verifyEmbeddingBag(NodeValue dest, NodeValue data,
}

bool HardSwishNode::verify() const {
return checkSameType(getInput(), getResult(), this);
const NodeValue input = getInput();
const NodeValue result = getResult();
const Node *parent = result.getNode();
if (input.getType()->isQuantizedType()) {
return checkSameIsQuantized(input.getType(), result.getType(), parent) &&
checkSameShape(result, input, parent);
}
return checkSameType(result, input, parent);
}

bool PadNode::verify() const {
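The relaxed verifier above matters because a quantized hard-swish output legitimately carries different quantization parameters than its input, so only quantized-ness and shape are required to match. A hypothetical illustration (assuming Module::uniqueType's usual (elemKind, dims, scale, offset) signature; the values are ours):

// Input and result may differ in scale/offset; checkSameType would wrongly
// reject such a quantized HardSwishNode, but the relaxed check accepts it.
auto *inTy = mod.uniqueType(ElemKind::Int8QTy, {1, 8}, /*scale*/ 0.05f, /*offset*/ 0);
auto *outTy = mod.uniqueType(ElemKind::Int8QTy, {1, 8}, /*scale*/ 0.02f, /*offset*/ -3);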
1 change: 1 addition & 0 deletions lib/LLVMIRCodeGen/LLVMBackend.cpp
@@ -79,6 +79,7 @@ bool LLVMBackend::isOpSupported(const NodeInfo &NI) const {
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy});

case Kinded::Kind::HardSwishNodeKind:
case Kinded::Kind::AdaptiveAvgPoolNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

21 changes: 21 additions & 0 deletions lib/LLVMIRCodeGen/LLVMIRGen.cpp
@@ -1435,6 +1435,27 @@ void LLVMIRGen::generateLLVMIRForDataParallelInstr(
break;
}

case Kinded::Kind::HardSwishInstKind: {
auto *HI = cast<HardSwishInst>(I);
auto *src = HI->getSrc();
auto *dest = HI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);

auto *F = getFunction("element_hard_swish", dest->getElementType());
llvm::CallInst *stackedOpCall = nullptr;
if (dest->getElementType() == ElemKind::FloatTy) {
stackedOpCall = createCall(builder, F, {loopCount, srcPtr});
} else {
LOG(FATAL) << "Type is not supported";
}
auto *elementTy = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}

case Kinded::Kind::ElementIsNaNInstKind: {
auto *AN = cast<ElementIsNaNInst>(I);
auto *src = AN->getSrc();
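For data-parallel instructions like HardSwishInst, this switch emits one kernel call per element: the generated loop passes loopCount as the element index and stores the returned value through a GEP into the destination buffer. Conceptually, the emitted IR computes the equivalent of the following sketch (ours, not code from the PR; the real loop is LLVM IR produced by LLVMIRGen):

#include <cstddef>

// Conceptual equivalent of the IR emitted for the float HardSwishInst case:
// call element_hard_swish once per index, store the result at dest[index].
void hardSwishKernelLoop(float *dest, const float *src, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    float t = src[i] + 3.0f;
    float relu6 = t > 6.0f ? 6.0f : (t < 0.0f ? 0.0f : t);
    dest[i] = src[i] * relu6 / 6.0f;
  }
}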
6 changes: 6 additions & 0 deletions lib/LLVMIRCodeGen/libjit/libjit.cpp
@@ -1892,6 +1892,12 @@ int8_t libjit_element_leaky_relu_i8(dim_t idx, const int8_t *src,
return libjit_clip_i8(scaledVal);
}

float libjit_element_hard_swish_f(dim_t idx, const float *src) {
float x = src[idx];
float relu6 = (x + 3) > 6 ? 6 : ((x + 3) < 0 ? 0 : (x + 3));
return x * relu6 / 6;
}

// When the LIBJIT compile option "-ffast-math" is enabled the intermediate
// computation expf(x) for Sigmoid operator is not handled properly for very
// large positive values which results in NaN values for the Sigmoid output.
30 changes: 30 additions & 0 deletions lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
@@ -6454,6 +6454,36 @@ bool ReplaceZeroScaleFP16QuantNodes::run(Function *F,
return changed;
}

bool ReplaceQuantizedHardSwishWithLookupTable::run(
Function *F, const CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), getName());

bool changed = false;
for (auto &N : F->getNodes()) {
auto *HSN = dyn_cast<HardSwishNode>(&N);
CONTINUE_IF_NOT(HSN);

// Verify that both input and output are quantized.
NodeValue input = HSN->getInput();
NodeValue output = HSN->getResult();

CONTINUE_IF_NOT(input.getType()->isQuantizedType())
CONTINUE_IF_NOT(output.getType()->isQuantizedType())

// Replace HardSwish with LUT.
auto hard_swish_lambda = [](float a) {
float x = a + 3;
return a * (x > 6 ? 6 : ((x < 0) ? 0 : x)) / 6;
};
auto lookupTable = F->createIntLookupTable(
HSN->getName(), input, hard_swish_lambda, output.getType());
HSN->getResult().replaceAllUsesOfWith(lookupTable);
changed = true;
}

return changed;
}

/// This function uses TypeAToTypeBFunctionConverter to do a whole graph
/// demotion of Index type from INT64 to INT32.
static void transformIndexTypeDemotion(const Backend &B, Function *F,
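Since createIntLookupTable receives a float lambda plus the quantized input and output types, building the table presumably amounts to sweeping all 256 int8 codes, dequantizing each, applying the lambda, and requantizing the result. A simplified standalone sketch of that idea (an assumption about the mechanism, not Glow's actual implementation; all names are ours):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Build a 256-entry int8 -> int8 table for hard-swish: dequantize each
// possible input code, evaluate the float function, requantize and clamp.
std::vector<int8_t> buildHardSwishTable(float inScale, int32_t inOffset,
                                        float outScale, int32_t outOffset) {
  std::vector<int8_t> table(256);
  for (int i = -128; i <= 127; ++i) {
    float x = inScale * (i - inOffset); // dequantize the input code
    float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
    int q = static_cast<int>(std::round(y / outScale)) + outOffset;
    table[i + 128] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
  return table;
}

At run time the quantized operator is then a single table load per element, which is why the pass pays the precomputation cost once at compile time.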
12 changes: 12 additions & 0 deletions tools/ClassGen/InstrGen.cpp
@@ -1135,6 +1135,18 @@ int main(int argc, char **argv) {
.autoVerify(VerifyKind::SameElementType, {"Dest", "Src"})
.autoIRGen();

BB.newInstr("HardSwish")
.addOperand("Dest", OperandKind::Out)
.addOperand("Src", OperandKind::In)
.inplaceOperand({
"Dest",
"Src",
})
.dataParallel()
.autoVerify(VerifyKind::SameShape, {"Dest", "Src"})
.autoVerify(VerifyKind::SameElementType, {"Dest", "Src"})
.autoIRGen();

BB.newInstr("SoftPlus")
.addOperand("Dest", OperandKind::Out)
.addOperand("Src", OperandKind::In)
