diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 08791380a5e..42318252eb5 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -86,11 +86,112 @@ struct SimplifyLocals
   };
 
   // a list of sinkables in a linear execution trace
-  using Sinkables = std::map<Index, SinkableInfo>;
+  using Sinkables = std::unordered_map<Index, SinkableInfo>;
 
   // locals in current linear execution trace, which we try to sink
   Sinkables sinkables;
 
+  // Reverse index: for each local L, tracks which sinkable keys have effects
+  // that read L. Used to find read-write conflicts when the current expression
+  // writes L.
+  std::unordered_map<Index, std::unordered_set<Index>> localReadBySinkable;
+
+  // Reverse index: for each local L, tracks which sinkable keys have effects
+  // that write L. A sinkable at key K always writes K, but may also write
+  // other locals if its value contains nested local.sets.
+  std::unordered_map<Index, std::unordered_set<Index>> localWrittenBySinkable;
+
+  // Sinkable keys whose effects include transfersControlFlow(). These must
+  // be invalidated whenever the current expression has side effects (including
+  // local writes), due to the asymmetric check in orderedBefore:
+  //   sinkable.transfersControlFlow() && current.hasSideEffects()
+  // This set is usually empty since sinkables rarely transfer control flow.
+  std::unordered_set<Index> controlFlowSinkables;
+
+  // Records |key| in the reverse indices according to the effects stored for
+  // it. The sinkable must already be present in |sinkables|.
+  void registerSinkable(Index key) {
+    auto& effects = sinkables.at(key).effects;
+    for (auto L : effects.localsRead) {
+      localReadBySinkable[L].insert(key);
+    }
+    for (auto L : effects.localsWritten) {
+      localWrittenBySinkable[L].insert(key);
+    }
+    if (effects.transfersControlFlow()) {
+      controlFlowSinkables.insert(key);
+    }
+  }
+
+  // Removes |key| from the reverse indices. This reads the effects stored in
+  // |sinkables|, so it must run before the entry itself is erased; a key with
+  // no entry is ignored.
+  void unregisterSinkable(Index key) {
+    auto it = sinkables.find(key);
+    if (it == sinkables.end()) {
+      return;
+    }
+    // Drops |key| from the set kept for |local|, and prunes that set from the
+    // index once it becomes empty.
+    auto forget = [key](auto& index, Index local) {
+      auto found = index.find(local);
+      if (found == index.end()) {
+        return;
+      }
+      found->second.erase(key);
+      if (found->second.empty()) {
+        index.erase(found);
+      }
+    };
+    auto& effects = it->second.effects;
+    for (auto L : effects.localsRead) {
+      forget(localReadBySinkable, L);
+    }
+    for (auto L : effects.localsWritten) {
+      forget(localWrittenBySinkable, L);
+    }
+    controlFlowSinkables.erase(key);
+  }
+
+  // Forgets every sinkable together with all index entries derived from them.
+  void clearSinkables() {
+    sinkables.clear();
+    localReadBySinkable.clear();
+    localWrittenBySinkable.clear();
+    controlFlowSinkables.clear();
+  }
+
+  // Transfers the current sinkables to the caller and resets all tracking.
+  Sinkables takeSinkables() {
+    // Swap into a fresh map rather than returning std::move(sinkables): a
+    // moved-from container is only guaranteed to be in a valid but unspecified
+    // state, while the indices cleared below require |sinkables| to be empty.
+    Sinkables taken;
+    taken.swap(sinkables);
+    clearSinkables();
+    return taken;
+  }
+
+  // Erases the sinkable at |it|, unregistering it from the indices first
+  // (unregistering needs the effects stored in the entry).
+  void eraseSinkable(typename Sinkables::iterator it) {
+    unregisterSinkable(it->first);
+    sinkables.erase(it);
+  }
+
+  // Erases the sinkable for |key|, if any, keeping the indices in sync.
+  void eraseSinkable(Index key) {
+    unregisterSinkable(key);
+    sinkables.erase(key);
+  }
+
+  // Tracks the expression at |currp| as the sinkable for |key| and indexes it.
+  void addSinkable(Index key, Expression** currp) {
+    sinkables.emplace(std::pair{
+      key, SinkableInfo(currp, this->getPassOptions(), *this->getModule())});
+    registerSinkable(key);
+  }
+
   // Information about an exit from a block: the break, and the
   // sinkables. For the final exit from a block (falling off)
   // exitter is null.
@@ -135,8 +222,7 @@ struct SimplifyLocals // value means the block already has a return value self->unoptimizableBlocks.insert(br->name); } else { - self->blockBreaks[br->name].push_back( - {currp, std::move(self->sinkables)}); + self->blockBreaks[br->name].push_back({currp, self->takeSinkables()}); } } else if (curr->is()) { return; // handled in visitBlock @@ -153,7 +239,7 @@ struct SimplifyLocals } // TODO: we could use this info to stop gathering data on these blocks } - self->sinkables.clear(); + self->clearSinkables(); } static void doNoteIfCondition( @@ -161,7 +247,7 @@ struct SimplifyLocals Expression** currp) { // we processed the condition of this if-else, and now control flow branches // into either the true or the false sides - self->sinkables.clear(); + self->clearSinkables(); } static void @@ -170,13 +256,13 @@ struct SimplifyLocals auto* iff = (*currp)->cast(); if (iff->ifFalse) { // We processed the ifTrue side of this if-else, save it on the stack. - self->ifStack.push_back(std::move(self->sinkables)); + self->ifStack.push_back(self->takeSinkables()); } else { // This is an if without an else. 
if (allowStructure) { self->optimizeIfReturn(iff, currp); } - self->sinkables.clear(); + self->clearSinkables(); } } @@ -191,7 +277,7 @@ struct SimplifyLocals self->optimizeIfElseReturn(iff, currp, self->ifStack.back()); } self->ifStack.pop_back(); - self->sinkables.clear(); + self->clearSinkables(); } void visitBlock(Block* curr) { @@ -204,13 +290,13 @@ struct SimplifyLocals // post-block cleanups if (curr->name.is()) { if (unoptimizableBlocks.contains(curr->name)) { - sinkables.clear(); + clearSinkables(); unoptimizableBlocks.erase(curr->name); } if (hasBreaks) { // more than one path to here, so nonlinear - sinkables.clear(); + clearSinkables(); blockBreaks.erase(curr->name); } } @@ -284,7 +370,7 @@ struct SimplifyLocals // reuse the local.get that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); - sinkables.erase(found); + eraseSinkable(found); anotherCycle = true; } } @@ -300,7 +386,65 @@ struct SimplifyLocals } void checkInvalidations(EffectAnalyzer& effects) { - // TODO: this is O(bad) + // Fast path: if the current expression only accesses locals (no memory, + // calls, globals, traps, control flow, etc.), we can use reverse indices + // to gather only the sinkables that touch the same locals, and run the + // full orderedAfter check on those candidates instead of on all sinkables. + // + // Each condition below corresponds to a non-local conflict category in + // EffectAnalyzer::orderedBefore. When all are false, the only remaining + // conflict paths are through local variable read/write pairs, PLUS the + // asymmetric check: sinkable.transfersControlFlow() && + // current.hasSideEffects(). The latter is handled via + // controlFlowSinkables below. 
+ if (!effects.transfersControlFlow() && !effects.writesGlobalState() && + !effects.readsMutableGlobalState() && !effects.danglingPop && + !effects.trap && !effects.hasSynchronization() && + !effects.mayNotReturn) { + std::unordered_set candidates; + // When the current expression reads local L, any sinkable that writes L + // has a write-read conflict. + for (auto L : effects.localsRead) { + auto it = localWrittenBySinkable.find(L); + if (it != localWrittenBySinkable.end()) { + candidates.insert(it->second.begin(), it->second.end()); + } + } + // When the current expression writes local L, any sinkable that reads L + // (read-write conflict) or writes L (write-write conflict) is a + // candidate. + for (auto L : effects.localsWritten) { + auto it = localReadBySinkable.find(L); + if (it != localReadBySinkable.end()) { + candidates.insert(it->second.begin(), it->second.end()); + } + auto it2 = localWrittenBySinkable.find(L); + if (it2 != localWrittenBySinkable.end()) { + candidates.insert(it2->second.begin(), it2->second.end()); + } + } + // Handle the asymmetric orderedBefore check: a sinkable that transfers + // control flow conflicts with any expression that has side effects + // (which includes local writes). This set is usually empty. + if (effects.hasSideEffects() && !controlFlowSinkables.empty()) { + candidates.insert(controlFlowSinkables.begin(), + controlFlowSinkables.end()); + } + std::vector invalidated; + for (auto key : candidates) { + auto it = sinkables.find(key); + if (it != sinkables.end() && effects.orderedAfter(it->second.effects)) { + invalidated.push_back(key); + } + } + for (auto key : invalidated) { + eraseSinkable(key); + } + return; + } + + // Slow path: the expression has non-local effects, so we must check all + // sinkables. 
std::vector invalidated; for (auto& [index, info] : sinkables) { if (effects.orderedAfter(info.effects)) { @@ -308,7 +452,7 @@ struct SimplifyLocals } } for (auto index : invalidated) { - sinkables.erase(index); + eraseSinkable(index); } } @@ -334,7 +478,7 @@ struct SimplifyLocals } } for (auto index : invalidated) { - self->sinkables.erase(index); + self->eraseSinkable(index); } } @@ -419,7 +563,7 @@ struct SimplifyLocals Drop* drop = ExpressionManipulator::convert(previous); drop->value = previousValue; drop->finalize(); - self->sinkables.erase(found); + self->eraseSinkable(found); self->anotherCycle = true; } } @@ -432,9 +576,7 @@ struct SimplifyLocals if (set && self->canSink(set)) { Index index = set->index; assert(!self->sinkables.contains(index)); - self->sinkables.emplace(std::pair{ - index, - SinkableInfo(currp, self->getPassOptions(), *self->getModule())}); + self->addSinkable(index, currp); } if (!allowNesting) { @@ -476,7 +618,13 @@ struct SimplifyLocals if (sinkables.empty()) { return; } - Index goodIndex = sinkables.begin()->first; + // Pick the lowest-index sinkable for deterministic output. + Index goodIndex = std::min_element(sinkables.begin(), + sinkables.end(), + [](const auto& a, const auto& b) { + return a.first < b.first; + }) + ->first; // Ensure we have a place to write the return values for, if not, we // need another cycle. auto* block = loop->body->dynCast(); @@ -498,7 +646,7 @@ struct SimplifyLocals this->replaceCurrent(set); // We moved things around, clear all tracking; we'll do another cycle // anyhow. - sinkables.clear(); + clearSinkables(); anotherCycle = true; } @@ -515,7 +663,8 @@ struct SimplifyLocals // block does not already have a return value (if one break has one, they // all do) assert(!(*breaks[0].brp)->template cast()->value); - // look for a local.set that is present in them all + // look for a local.set that is present in them all. + // Keep the lowest one, since unordered_map iteration order is unspecified. 
bool found = false; Index sharedIndex = -1; for (auto& [index, _] : sinkables) { @@ -526,10 +675,9 @@ struct SimplifyLocals break; } } - if (inAll) { + if (inAll && (!found || index < sharedIndex)) { sharedIndex = index; found = true; - break; } } if (!found) { @@ -624,7 +772,7 @@ struct SimplifyLocals auto* newLocalSet = Builder(*this->getModule()).makeLocalSet(sharedIndex, block); this->replaceCurrent(newLocalSet); - sinkables.clear(); + clearSinkables(); anotherCycle = true; block->finalize(); } @@ -656,27 +804,35 @@ struct SimplifyLocals Sinkables& ifFalse = sinkables; Index goodIndex = -1; bool found = false; + auto pickLowest = [](Sinkables& s) { + return std::min_element( + s.begin(), + s.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }) + ->first; + }; if (iff->ifTrue->type == Type::unreachable) { // since the if type is none assert(iff->ifFalse->type != Type::unreachable); if (!ifFalse.empty()) { - goodIndex = ifFalse.begin()->first; + goodIndex = pickLowest(ifFalse); found = true; } } else if (iff->ifFalse->type == Type::unreachable) { // since the if type is none assert(iff->ifTrue->type != Type::unreachable); if (!ifTrue.empty()) { - goodIndex = ifTrue.begin()->first; + goodIndex = pickLowest(ifTrue); found = true; } } else { - // Look for a shared index. + // Look for a shared index (pick the lowest for determinism). for (auto& [index, _] : ifTrue) { if (ifFalse.contains(index)) { - goodIndex = index; - found = true; - break; + if (!found || index < goodIndex) { + goodIndex = index; + found = true; + } } } } @@ -799,7 +955,13 @@ struct SimplifyLocals // element). // // TODO investigate more - Index goodIndex = sinkables.begin()->first; + // Pick the lowest-index sinkable for deterministic output. 
+ Index goodIndex = std::min_element(sinkables.begin(), + sinkables.end(), + [](const auto& a, const auto& b) { + return a.first < b.first; + }) + ->first; auto localType = this->getFunction()->getLocalType(goodIndex); if (!localType.isDefaultable()) { return; @@ -973,7 +1135,7 @@ struct SimplifyLocals anotherCycle = true; } // clean up - sinkables.clear(); + clearSinkables(); blockBreaks.clear(); unoptimizableBlocks.clear(); return anotherCycle;