Back out "Fix to_glow onnx dumping issue"
Summary: Original commit changeset: 16619c087875

Reviewed By: zrphercule

Differential Revision: D24909150

fbshipit-source-id: f5f64aa35f5c8e314918b52293efda91aabd449a
jackm321 authored and facebook-github-bot committed Nov 12, 2020
Parent: d320f33 · Commit: 3ebd873
Showing 3 changed files with 60 additions and 77 deletions.
22 changes: 4 additions & 18 deletions torch_glow/src/CachingGraphRunner.cpp
@@ -351,8 +351,7 @@ int64_t CachingGraphRunner::runOnJit(torch::jit::Stack &stack) {
std::lock_guard<std::mutex> guard(runJitLock);
bool temp = getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled;
getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled = false;
- int64_t startTime;
- startTime = TraceEvent::now();
+ int64_t startTime = TraceEvent::now();
ptGraphExecutor_.run(stack);
int64_t runTime = TraceEvent::now() - startTime;
getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled = temp;
@@ -415,12 +414,6 @@ Error CachingGraphRunner::runImpl(const PerGlowGraphInfo &info,
// Run the subgraph using JIT for comparison with Glow.
torch::jit::Stack copyStack;
if (settings.writeToOnnx || settings.jitVsGlowCompare) {
-
- // We will use original graph for runOnJit, which means the first input
- // should be module.
- if (origGraph_ != nullptr) {
-   copyStack.push_back(module_);
- }
for (auto &ival : stack) {
if (ival.isTensor()) {
copyStack.push_back(ival.deepcopy());
@@ -886,18 +879,11 @@ Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta,
CachingGraphRunner::CachingGraphRunner(
std::shared_ptr<torch::jit::Graph> graph,
std::shared_ptr<runtime::HostManager> hostManager,
- PyTorchLoaderSettings defaultSettings, bool useRunOnly,
- std::shared_ptr<torch::jit::Graph> origGraph, c10::IValue module)
- : graph_(graph), origGraph_(origGraph), ptGraphExecutor_(graph, "forward"),
-   module_(module), hostManager_(hostManager),
+ PyTorchLoaderSettings defaultSettings, bool useRunOnly)
+ : graph_(graph), ptGraphExecutor_(graph, "forward"),
+   hostManager_(hostManager),
backend_(*EXIT_ON_ERR(hostManager->getBackend())),
defaultSettings_(std::move(defaultSettings)), useRunOnly_(useRunOnly) {
-
- if (origGraph_ != nullptr) {
-   ptGraphExecutor_ = torch::jit::GraphExecutor(origGraph_, "forward");
- } else {
-   ptGraphExecutor_ = torch::jit::GraphExecutor(graph_, "forward");
- }
mergedTraceContext_ = glow::make_unique<TraceContext>(TraceLevel::STANDARD);
}

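For reference, the CachingGraphRunner constructor after this backout, assembled from the context and added lines in the hunk above (a sketch of the net result rather than a verbatim copy of the file):

CachingGraphRunner::CachingGraphRunner(
    std::shared_ptr<torch::jit::Graph> graph,
    std::shared_ptr<runtime::HostManager> hostManager,
    PyTorchLoaderSettings defaultSettings, bool useRunOnly)
    : graph_(graph), ptGraphExecutor_(graph, "forward"),
      hostManager_(hostManager),
      backend_(*EXIT_ON_ERR(hostManager->getBackend())),
      defaultSettings_(std::move(defaultSettings)), useRunOnly_(useRunOnly) {
  // The executor is always built from graph_ now; the origGraph_ fallback
  // removed above is gone.
  mergedTraceContext_ = glow::make_unique<TraceContext>(TraceLevel::STANDARD);
}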
12 changes: 1 addition & 11 deletions torch_glow/src/CachingGraphRunner.h
@@ -60,17 +60,9 @@ class CachingGraphRunner {
/// for.
std::shared_ptr<torch::jit::Graph> graph_;

- /// The PyTorch JIT Graph that this CachingGraphRunner caches for before
- /// any preprocessing is done.
- std::shared_ptr<torch::jit::Graph> origGraph_;
-
/// GraphExecutor used to execute graph_ on PyTorch for debugging purposes.
torch::jit::GraphExecutor ptGraphExecutor_;

- /// The PyTorch module of the graph.
- /// It is used as first input when running origGraph_ on JIT.
- c10::IValue module_;
-
/// The HostManager used to store and run Glow graphs.
std::shared_ptr<runtime::HostManager> hostManager_;

@@ -191,9 +183,7 @@ class CachingGraphRunner {
public:
CachingGraphRunner(std::shared_ptr<torch::jit::Graph> graph,
std::shared_ptr<runtime::HostManager> hostManager,
- PyTorchLoaderSettings settings, bool useRunOnly = false,
- std::shared_ptr<torch::jit::Graph> origGraph = nullptr,
- c10::IValue module = c10::IValue());
+ PyTorchLoaderSettings settings, bool useRunOnly = false);

~CachingGraphRunner();

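With the extra parameters removed, a call site for the restored four-argument constructor might look like the snippet below. It mirrors the CachingGraphRunner construction visible in the TorchGlowBackend.cpp hunks further down; the surrounding names (graph, baseSettings, runner) are illustrative only:

auto runner = std::make_unique<CachingGraphRunner>(
    graph,
    glow::getHostManager(baseSettings.backendName, baseSettings.numDevices),
    baseSettings, /*useRunOnly*/ true);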
103 changes: 55 additions & 48 deletions torch_glow/src/TorchGlowBackend.cpp
@@ -484,11 +484,60 @@ static Error ProcessPackedParams(torch::jit::Graph &graph,
return Error::success();
}

+ /// Implementation of to_backend preprocess method for Glow. \returns the
+ /// preprocessed Module if successful or an Error otherwise which is converted
+ /// to an exception for handling within PyTorch.
+ static Expected<torch::jit::Module>
+ preprocessImpl(torch::jit::Module origModule,
+                c10::impl::GenericDict method_compile_spec) {
+ // Preprocess each method
+ for (const auto &kv : method_compile_spec) {
+   const auto &methodName = kv.key().toStringRef();
+   auto method = origModule.get_method(methodName);
+   auto graph = method.graph();
+
+   GraphOutputType graphOutputType;
+   ASSIGN_VALUE_OR_RETURN_ERR(graphOutputType,
+                              checkGraphInputsAndOutputs(*graph));
+
+   // Output lists no supported yet
+   if (graphOutputType == GraphOutputType::TENSOR_LIST) {
+     return MAKE_ERR("Tensor list output not supported.");
+   }
+
+   detail::fuseConcat(graph);
+   torch::jit::Inline(*graph);
+   RewriteQuantPackedParamOps(graph);
+   RETURN_IF_ERR(ProcessPackedParams(*graph, origModule._ivalue()));
+ }
+
+ // Freeze
+ auto preprocModule = torch::jit::freeze_module(origModule);
+
+ // Cleanup JIT graphs
+ for (const auto &kv : method_compile_spec) {
+   const auto &methodName = kv.key().toStringRef();
+   auto method = preprocModule.get_method(methodName);
+   auto graph = method.graph();
+   EliminateDeadCode(graph);
+   EliminateCommonSubexpression(graph);
+   ConstantPooling(graph);
+ }
+
+ return preprocModule;
+ }
+
c10::IValue
TorchGlowBackend::preprocess(c10::IValue mod,
c10::impl::GenericDict method_compile_spec) {
- // We do nothing in the preprocess, instead we do them in compile()
- return mod;
+
+ torch::jit::Module origModule = mod.toModule();
+ origModule.eval();
+ auto resOrErr = preprocessImpl(origModule, method_compile_spec);
+ if (!resOrErr) {
+   throw std::runtime_error(ERR_TO_STRING(resOrErr.takeError()));
+ }
+ return resOrErr->_ivalue();
}

Error applySettingsOverrideFlagsToPyTorchLoaderSettings(
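Putting the added lines together, the post-backout preprocess() hook reads roughly as follows (a reconstruction from the hunk above, not a verbatim copy): it now runs preprocessImpl eagerly and converts any returned Error into an exception so PyTorch can surface it from to_backend:

c10::IValue
TorchGlowBackend::preprocess(c10::IValue mod,
                             c10::impl::GenericDict method_compile_spec) {
  torch::jit::Module origModule = mod.toModule();
  origModule.eval();
  // preprocessImpl rewrites each method's graph, freezes the module, and runs
  // the JIT cleanup passes shown above.
  auto resOrErr = preprocessImpl(origModule, method_compile_spec);
  if (!resOrErr) {
    throw std::runtime_error(ERR_TO_STRING(resOrErr.takeError()));
  }
  return resOrErr->_ivalue();
}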
@@ -555,57 +604,17 @@ Error applyFuserSettingsToPyTorchLoaderSettings(
static Expected<std::unordered_map<
std::string, std::pair<std::unique_ptr<CachingGraphRunner>,
std::unique_ptr<JITGraphRunner>>>>
- compileImpl(const torch::jit::Module &origModule,
+ compileImpl(const torch::jit::Module &module,
const c10::impl::GenericDict &method_compile_spec) {

std::unordered_map<std::string, std::pair<std::unique_ptr<CachingGraphRunner>,
std::unique_ptr<JITGraphRunner>>>
methodToRunnerMap;
- std::unordered_map<std::string, std::shared_ptr<torch::jit::Graph>>
-     nameToOrigGraph;
-
- for (const auto &kv : method_compile_spec) {
-   const auto &methodName = kv.key().toStringRef();
-   auto method = origModule.get_method(methodName);
-   auto graph = method.graph();
-   nameToOrigGraph[methodName] = graph->copy();
-
-   GraphOutputType graphOutputType;
-   ASSIGN_VALUE_OR_RETURN_ERR(graphOutputType,
-                              checkGraphInputsAndOutputs(*graph));
-
-   // Output lists no supported yet
-   if (graphOutputType == GraphOutputType::TENSOR_LIST) {
-     return MAKE_ERR("Tensor list output not supported.");
-   }
-
-   detail::fuseConcat(graph);
-   torch::jit::Inline(*graph);
-   RewriteQuantPackedParamOps(graph);
-   RETURN_IF_ERR(ProcessPackedParams(*graph, origModule._ivalue()));
- }
-
- // Freeze
- auto preprocModule = torch::jit::freeze_module(origModule);
-
- // Cleanup JIT graphs
- for (const auto &kv : method_compile_spec) {
-   const auto &methodName = kv.key().toStringRef();
-   auto method = preprocModule.get_method(methodName);
-   auto graph = method.graph();
-   EliminateDeadCode(graph);
-   EliminateCommonSubexpression(graph);
-   ConstantPooling(graph);
- }

// Compile each method
for (const auto &kv : method_compile_spec) {
const auto methodName = kv.key().toString()->string();
- const auto &method = preprocModule.get_method(methodName);
- auto it = nameToOrigGraph.find(methodName);
- CHECK(it != nameToOrigGraph.end())
-     << "Cannot find corresponding original graph for graph: " << methodName;
- auto origGraph = it->second;
+ const auto &method = module.get_method(methodName);

const CompilationSpec &spec = *kv.value().toCustomClass<CompilationSpec>();
RETURN_IF_ERR(spec.validate());
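With the graph preprocessing moved into preprocessImpl, the per-method loop at the top of compileImpl reduces to roughly the shape below (assembled from the surrounding context and added lines; the runner construction that follows is shown in the later hunks):

for (const auto &kv : method_compile_spec) {
  const auto methodName = kv.key().toString()->string();
  const auto &method = module.get_method(methodName);

  const CompilationSpec &spec = *kv.value().toCustomClass<CompilationSpec>();
  RETURN_IF_ERR(spec.validate());
  // ... settings are resolved and a JITGraphRunner or CachingGraphRunner is
  // built for this method (see the hunks below).
}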
@@ -648,7 +657,7 @@ compileImpl(const torch::jit::Module &origModule,

// Run fusion flow using JIT graph runner
std::unique_ptr<JITGraphRunner> runner = std::make_unique<JITGraphRunner>(
-     preprocModule._ivalue(), graph, baseSettings);
+     module._ivalue(), graph, baseSettings);
methodToRunnerMap.emplace(methodName,
std::make_pair(nullptr, std::move(runner)));
} else {
@@ -664,8 +673,7 @@ compileImpl(const torch::jit::Module &origModule,
graph,
glow::getHostManager(baseSettings.backendName,
baseSettings.numDevices),
-     baseSettings, /*useRunOnly*/ true, origGraph,
-     origModule._ivalue());
+     baseSettings, /*useRunOnly*/ true);

// Compile each compilation group
for (const auto &compilationGroup : spec.compilation_groups) {
@@ -696,7 +704,6 @@ TorchGlowBackend::compile(c10::IValue processed,
c10::impl::GenericDict method_compile_spec) {
auto module = processed.toModule().clone();

- module.eval();
auto runnersOrErr = compileImpl(module, method_compile_spec);

if (!runnersOrErr) {
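Note that compile() no longer calls module.eval(); after this backout that call happens in preprocess() before freezing (first hunk of this file). The portion of compile() visible in this diff, assembled for reference (surrounding lines are elided here):

auto module = processed.toModule().clone();
auto runnersOrErr = compileImpl(module, method_compile_spec);
if (!runnersOrErr) {
  // Error handling is outside the visible part of this hunk.
}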
