dlvhex
2.5.0
|
00001 /* dlvhex -- Answer-Set Programming with external interfaces. 00002 * Copyright (C) 2005-2007 Roman Schindlauer 00003 * Copyright (C) 2006-2015 Thomas Krennwallner 00004 * Copyright (C) 2009-2016 Peter Schüller 00005 * Copyright (C) 2011-2016 Christoph Redl 00006 * Copyright (C) 2015-2016 Tobias Kaminski 00007 * Copyright (C) 2015-2016 Antonius Weinzierl 00008 * 00009 * This file is part of dlvhex. 00010 * 00011 * dlvhex is free software; you can redistribute it and/or modify it 00012 * under the terms of the GNU Lesser General Public License as 00013 * published by the Free Software Foundation; either version 2.1 of 00014 * the License, or (at your option) any later version. 00015 * 00016 * dlvhex is distributed in the hope that it will be useful, but 00017 * WITHOUT ANY WARRANTY; without even the implied warranty of 00018 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00019 * Lesser General Public License for more details. 00020 * 00021 * You should have received a copy of the GNU Lesser General Public 00022 * License along with dlvhex; if not, write to the Free Software 00023 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 00024 * 02110-1301 USA. 00025 */ 00026 00034 #ifdef HAVE_CONFIG_H 00035 #include "config.h" 00036 #endif // HAVE_CONFIG_H 00037 00038 //#define BOOST_SPIRIT_DEBUG 00039 00040 #include "dlvhex2/AggregatePlugin.h" 00041 #include "dlvhex2/PlatformDefinitions.h" 00042 #include "dlvhex2/ProgramCtx.h" 00043 #include "dlvhex2/Registry.h" 00044 #include "dlvhex2/Printer.h" 00045 #include "dlvhex2/Printhelpers.h" 00046 #include "dlvhex2/PredicateMask.h" 00047 #include "dlvhex2/Logger.h" 00048 #include "dlvhex2/HexParser.h" 00049 #include "dlvhex2/HexParserModule.h" 00050 #include "dlvhex2/HexGrammar.h" 00051 #include "dlvhex2/ExternalLearningHelper.h" 00052 00053 #include <boost/algorithm/string/predicate.hpp> 00054 #include <boost/lexical_cast.hpp> 00055 00056 DLVHEX_NAMESPACE_BEGIN 00057 00058 AggregatePlugin::CtxData::CtxData(): 00059 enabled(false), mode(ExtRewrite) 00060 { 00061 } 00062 00063 00064 AggregatePlugin::AggregatePlugin(): 00065 PluginInterface() 00066 { 00067 setNameVersion("dlvhex-aggregateplugin[internal]", 2, 0, 0); 00068 } 00069 00070 00071 AggregatePlugin::~AggregatePlugin() 00072 { 00073 } 00074 00075 00076 // output help message for this plugin 00077 void AggregatePlugin::printUsage(std::ostream& o) const 00078 { 00079 // 123456789-123456789-123456789-123456789-123456789-123456789-123456789-123456789- 00080 o << " --aggregate-enable[=true,false]" << std::endl 00081 << " Enable aggregate plugin (default is enabled)." << std::endl; 00082 o << " --aggregate-mode=[native,ext,extbl]" << std::endl 00083 << " native (default) : Keep aggregates" << std::endl 00084 << " (but simplify them to some basic types)" << std::endl 00085 << " ext : Rewrite aggregates to external atoms" << std::endl 00086 << " extbl : Rewrite aggregates to boolean external atoms" << std::endl 00087 // << " --aggregate-allowrecaggregates" << std::endl 00088 // << " Allows cycles through aggregates." << std::endl 00089 // << " Depending on the solver backend, this might lead to" << std::endl 00090 // << " different results." << std::endl 00091 // << " With --aggregate-mode=ext, the option is irrelevant" << std::endl 00092 // << " as aggregates are replaced by external atoms." << std::endl 00093 << " --aggregate-allowaggextcycles" << std::endl 00094 << " Allows cycles which involve both aggregates and" << std::endl 00095 << " external atoms. If the option is not specified," << std::endl 00096 << " such cycles lead to abortion; if specified, only" << std::endl 00097 << " a warning is printed but the models might be not minimal." << std::endl 00098 << " With --aggregate-mode=ext, the option is irrelevant" << std::endl 00099 << " as aggregates are replaced by external atoms (models will be minimal in that case)." << std::endl 00100 << " See examples/aggextcycle1.hex."; 00101 } 00102 00103 00104 // accepted options: --higherorder-enable 00105 // 00106 // processes options for this plugin, and removes recognized options from pluginOptions 00107 // (do not free the pointers, the const char* directly come from argv) 00108 void AggregatePlugin::processOptions( 00109 std::list<const char*>& pluginOptions, 00110 ProgramCtx& ctx) 00111 { 00112 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 00113 ctxdata.enabled = true; 00114 ctxdata.mode = CtxData::Simplify; 00115 00116 // we always support it 00117 ctx.config.setOption("AllowAggCycles", 1); 00118 00119 typedef std::list<const char*>::iterator Iterator; 00120 Iterator it; 00121 WARNING("create (or reuse, maybe from potassco?) cmdline option processing facility") 00122 it = pluginOptions.begin(); 00123 while( it != pluginOptions.end() ) { 00124 bool processed = false; 00125 const std::string str(*it); 00126 if( boost::starts_with(str, "--aggregate-enable" ) ) { 00127 std::string m = str.substr(std::string("--aggregate-enable").length()); 00128 if (m == "" || m == "=true") { 00129 ctxdata.enabled = true; 00130 } 00131 else if (m == "=false") { 00132 ctxdata.enabled = false; 00133 } 00134 else { 00135 std::stringstream ss; 00136 ss << "Unknown --aggregate-enable option: " << m; 00137 throw PluginError(ss.str()); 00138 } 00139 processed = true; 00140 } 00141 else if( boost::starts_with(str, "--aggregate-mode=") ) { 00142 std::string m = str.substr(std::string("--aggregate-mode=").length()); 00143 if (m == "ext") { 00144 ctxdata.mode = CtxData::ExtRewrite; 00145 } 00146 else if (m == "extbl") { 00147 ctxdata.mode = CtxData::ExtBlRewrite; 00148 } 00149 // "native" was previously called "simplify" --> keep it for backwards compatibility 00150 else if (m == "native" || m == "simplify") { 00151 ctxdata.mode = CtxData::Simplify; 00152 } 00153 else { 00154 std::stringstream ss; 00155 ss << "Unknown --aggregate-mode option: " << m; 00156 throw PluginError(ss.str()); 00157 } 00158 processed = true; 00159 } 00160 // else if( str == "--aggregate-allowrecaggregates" ) 00161 // { 00162 // ctx.config.setOption("AllowAggCycles", 1); 00163 // processed = true; 00164 // } 00165 else if( str == "--aggregate-allowaggextcycles" ) { 00166 ctx.config.setOption("AllowAggExtCycles", 1); 00167 processed = true; 00168 } 00169 00170 if( processed ) { 00171 // return value of erase: element after it, maybe end() 00172 DBGLOG(DBG,"AggregatePlugin successfully processed option " << str); 00173 it = pluginOptions.erase(it); 00174 } 00175 else { 00176 it++; 00177 } 00178 } 00179 } 00180 00181 00182 namespace 00183 { 00184 00185 typedef AggregatePlugin::CtxData CtxData; 00186 00187 class AggregateRewriter: 00188 public PluginRewriter 00189 { 00190 private: 00191 AggregatePlugin::CtxData& ctxdata; 00192 InterpretationPtr newEdb; 00193 std::vector<ID> newIdb; 00194 int ruleNr; 00195 void rewriteRule(ProgramCtx& ctx, InterpretationPtr edb, std::vector<ID>& idb, const Rule& rule); 00196 00197 std::string aggregateFunctionToExternalAtomName(ID aggFunction); 00198 public: 00199 AggregateRewriter(AggregatePlugin::CtxData& ctxdata) : ctxdata(ctxdata), ruleNr(0) {} 00200 virtual ~AggregateRewriter() {} 00201 00202 virtual void prepareRewrittenProgram(InterpretationPtr newEdb, ProgramCtx& ctx); 00203 virtual void rewrite(ProgramCtx& ctx); 00204 }; 00205 00206 std::string AggregateRewriter::aggregateFunctionToExternalAtomName(ID aggFunction) { 00207 00208 DBGLOG(DBG, "Translating aggregate function " << aggFunction); 00209 switch(aggFunction.address) { 00210 case ID::TERM_BUILTIN_AGGCOUNT: return "count"; 00211 case ID::TERM_BUILTIN_AGGMIN: return "min"; 00212 case ID::TERM_BUILTIN_AGGMAX: return "max"; 00213 case ID::TERM_BUILTIN_AGGSUM: return "sum"; 00214 case ID::TERM_BUILTIN_AGGTIMES: return "times"; 00215 case ID::TERM_BUILTIN_AGGAVG: return "avg"; 00216 // case ID::TERM_BUILTIN_AGGANY: return "any"; 00217 default: assert(false); return ""; 00218 } 00219 } 00220 00221 namespace 00222 { 00223 void warnMaxint(const ProgramCtx& ctx, ID term) { 00224 static bool warned = false; 00225 //LOG(WARNING,"AggregatePlugin term is " << term); 00226 if( !warned && term.isIntegerTerm() && 00227 (ctx.maxint == ID_FAIL || term.address > ctx.maxint) ) { 00228 LOG(WARNING,"AggregatePlugin requires --maxint or -N to be set to a sufficiently high value! (" << term.address << "/" << ctx.maxint << ")"); 00229 warned = true; 00230 } 00231 } 00232 } 00233 00234 void AggregateRewriter::rewriteRule(ProgramCtx& ctx, InterpretationPtr edb, std::vector<ID>& idb, const Rule& rule) { 00235 00236 RegistryPtr reg = ctx.registry(); 00237 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 00238 00239 // take the rule head as it is 00240 Rule newRule = rule; 00241 newRule.body.clear(); 00242 00243 // determine a prefix which does not occur at the beginning of any variable in the rule's body 00244 std::string prefix = "F";// function value 00245 std::set<ID> vars; 00246 BOOST_FOREACH (ID b, rule.body) { 00247 reg->getVariablesInID(b, vars); 00248 } 00249 BOOST_FOREACH (ID v, vars) { 00250 std::string currentVar = reg->terms.getByID(v).getUnquotedString(); 00251 while (prefix.length() <= currentVar.length() && 00252 currentVar.substr(0, prefix.length()) == prefix) { 00253 prefix = prefix + "F"; 00254 } 00255 } 00256 00257 // find all top-level aggregates in the rule body 00258 DBGLOG(DBG, "Rewriting aggregate atoms in rule"); 00259 int aggIndex = 0; 00260 BOOST_FOREACH (ID b, rule.body) { 00261 if (b.isAggregateAtom()) { 00262 int symbolicSetSize = -1; 00263 DBGLOG(DBG, "Rewriting aggregate atom " << printToString<RawPrinter>(b, reg)); 00264 const AggregateAtom& aatom = reg->aatoms.getByID(b); 00265 00266 // Variables of the symbolic set which occur also in the remaining rule body 00267 std::vector<ID> symbolicSetVarsIntersectingRemainingBody; 00268 00269 // in the following we need to unique predicates for this aggregate 00270 ID keyPredID = reg->getAuxiliaryConstantSymbol('g', ID::termFromInteger(ruleNr++)); 00271 ID inputPredID = reg->getAuxiliaryConstantSymbol('g', ID::termFromInteger(ruleNr++)); 00272 00273 // For ;-separated aggregate elements from the ASP-Core 2 standard. 00274 // 00275 // We need to iterate either through aatom.mvariables or, if the former is empty, through aatom.variables (resp. aatom.mliterals or aatom.literals). 00276 // Trick: Iterate through aatom.mvariables plus one additional index for aatom.variables (resp. literals). 00277 // If the additional index is reached and is 0, then we bind the reference to aatom.variables instead of aatom.mvariables; 00278 // if it is greater than 0 we skip it because aatom.mvariables is nonempty. 00279 DBGLOG(DBG, "Found " << aatom.mvariables.size() << " multi-symbolic sets"); 00280 00281 // first of all, analyze the aggregate and build needed sets of variables 00282 std::vector<std::set<ID> > symSetVars; 00283 for (int symbSetIndex = 0; symbSetIndex <= aatom.mvariables.size(); ++symbSetIndex) { 00284 if (symbSetIndex == aatom.mvariables.size() && symbSetIndex > 0) continue; 00285 00286 DBGLOG(DBG, "Processing symbolic set number " << symbSetIndex); 00287 const Tuple& currentSymbolicSetVars = (symbSetIndex == aatom.mvariables.size() ? aatom.variables : aatom.mvariables[symbSetIndex]); 00288 const Tuple& currentSymbolicSetLiterals = (symbSetIndex == aatom.mliterals.size() ? aatom.literals : aatom.mliterals[symbSetIndex]); 00289 00290 // determine size of the tuples in the symbolic set 00291 if (symbolicSetSize != -1 && symbolicSetSize != currentSymbolicSetVars.size()) throw GeneralError("Symbolic set of aggregate \"" + printToString<RawPrinter>(b, reg) + "\" contains tuples of varying sizes"); 00292 symbolicSetSize = currentSymbolicSetVars.size(); 00293 00294 // collect all variables from the conjunction of the symbolic set 00295 DBGLOG(DBG, "Harvesting variables in literals of the symbolic set"); 00296 symSetVars.push_back(std::set<ID>()); 00297 std::set<ID>& currentSymSetVars = symSetVars[symSetVars.size() - 1]; 00298 BOOST_FOREACH (ID c, currentSymbolicSetVars) { currentSymSetVars.insert(c); } 00299 BOOST_FOREACH (ID cs, currentSymbolicSetLiterals) { 00300 DBGLOG(DBG, "Harvesting variables in literal of the symbolic set: " << printToString<RawPrinter>(cs, reg)); 00301 reg->getVariablesInID(cs, currentSymSetVars); 00302 } 00303 DBGLOG(DBG, "Symbolic set uses " << currentSymSetVars.size() << " variables"); 00304 00305 // collect all variables from the remaining body of the rule 00306 DBGLOG(DBG, "Harvesting variables in remaining rule body"); 00307 std::set<ID> bodyVars; 00308 BOOST_FOREACH (ID rb, rule.body) { 00309 if (rb != b) { 00310 // exclude local variables in other aggregates but keep the bound variables thereof 00311 if (rb.isAggregateAtom()) { 00312 const AggregateAtom& ag2 = reg->aatoms.getByID(rb); 00313 if (ag2.tuple[0] != ID_FAIL) reg->getVariablesInID(ag2.tuple[0], bodyVars); 00314 if (ag2.tuple[4] != ID_FAIL) reg->getVariablesInID(ag2.tuple[4], bodyVars); 00315 } 00316 else { 00317 DBGLOG(DBG, "Harvesting variables in " << printToString<RawPrinter>(rb, reg)); 00318 reg->getVariablesInID(rb, bodyVars); 00319 } 00320 } 00321 } 00322 00323 // collect all variables of the symbolic set which occur also in the remaining rule body 00324 DBGLOG(DBG, "Harvesting variables shared between symbolic set and remaining rule body"); 00325 BOOST_FOREACH (ID c, currentSymSetVars) { 00326 if (std::find(bodyVars.begin(), bodyVars.end(), c) != bodyVars.end()) { 00327 DBGLOG(DBG, "Body variable of symbolic set: " << printToString<RawPrinter>(c, reg)); 00328 symbolicSetVarsIntersectingRemainingBody.push_back(c); 00329 } 00330 } 00331 } 00332 00333 // same trick again: now construct key and input rules 00334 for (int symbSetIndex = 0; symbSetIndex <= aatom.mvariables.size(); ++symbSetIndex) { 00335 if (symbSetIndex == aatom.mvariables.size() && symbSetIndex > 0) continue; 00336 00337 DBGLOG(DBG, "Processing symbolic set number " << symbSetIndex); 00338 const Tuple& currentSymbolicSetVars = (symbSetIndex == aatom.mvariables.size() ? aatom.variables : aatom.mvariables[symbSetIndex]); 00339 const Tuple& currentSymbolicSetLiterals = (symbSetIndex == aatom.mliterals.size() ? aatom.literals : aatom.mliterals[symbSetIndex]); 00340 00341 Rule keyRule(ID::MAINKIND_RULE); 00342 Rule inputRule(ID::MAINKIND_RULE); 00343 { 00344 // Construct the external atom key rule of the following type: 00345 // A. Single head atom 00346 // 1. create a unique predicate name p 00347 // 2. for all variables X from the conjunction of the symbolic set: 00348 // if X occurs also in the remaining body of the rule 00349 // add it to the tuple of p 00350 // B. The body consists of all atoms of the original rule except the aggregate being rewritten 00351 00352 // construct the input rule 00353 DBGLOG(DBG, "Constructing key rule"); 00354 OrdinaryAtom oatom(ID::MAINKIND_ATOM | ID::PROPERTY_AUX); 00355 if (symbolicSetVarsIntersectingRemainingBody.size() > 0) { 00356 oatom.kind |= ID::SUBKIND_ATOM_ORDINARYN; 00357 } 00358 else { 00359 oatom.kind |= ID::SUBKIND_ATOM_ORDINARYG; 00360 } 00361 00362 // A. 00363 DBGLOG(DBG, "Constructing key rule head"); 00364 // 1. 00365 oatom.tuple.push_back(keyPredID); 00366 // 2. 00367 BOOST_FOREACH (ID var, symbolicSetVarsIntersectingRemainingBody) { 00368 oatom.tuple.push_back(var); 00369 } 00370 keyRule.head.push_back(reg->storeOrdinaryAtom(oatom)); 00371 // B. 00372 DBGLOG(DBG, "Constructing key rule body"); 00373 BOOST_FOREACH (ID bb, rule.body) { 00374 // remove range comparisons of the aggregate value (this will not destroy safety) 00375 if (bb.isBuiltinAtom()){ 00376 const AggregateAtom& aatom = reg->aatoms.getByID(b); 00377 const BuiltinAtom& batom = reg->batoms.getByID(bb); 00378 if ( // check if the builtin is a range comparison 00379 ( batom.tuple[0].address == ID::TERM_BUILTIN_LT || batom.tuple[0].address == ID::TERM_BUILTIN_LE 00380 || batom.tuple[0].address == ID::TERM_BUILTIN_GT || batom.tuple[0].address == ID::TERM_BUILTIN_GE 00381 || batom.tuple[0].address == ID::TERM_BUILTIN_NE) 00382 // check if it compares the aggregate value 00383 && ( (aatom.tuple[1].address == ID::TERM_BUILTIN_EQ && aatom.tuple[0].isVariableTerm() && (batom.tuple[1] == aatom.tuple[0] || batom.tuple[2] == aatom.tuple[0])) 00384 || (aatom.tuple[3].address == ID::TERM_BUILTIN_EQ && aatom.tuple[4].isVariableTerm() && (batom.tuple[1] == aatom.tuple[4] || batom.tuple[2] == aatom.tuple[4])) ) 00385 ) { 00386 continue; 00387 } 00388 } 00389 00390 if (bb == b) { 00391 // make sure that we do not lose safety: if b _defines_ a variable, then define it as an arbitrary integer insteads 00392 const AggregateAtom& aatom = reg->aatoms.getByID(b); 00393 if (aatom.tuple[1].address == ID::TERM_BUILTIN_EQ && aatom.tuple[0].isVariableTerm()) { 00394 BuiltinAtom bi(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_BUILTIN); 00395 bi.tuple.push_back(ID::termFromBuiltin(ID::TERM_BUILTIN_INT)); 00396 bi.tuple.push_back(aatom.tuple[0]); 00397 keyRule.body.push_back(ID::posLiteralFromAtom(reg->batoms.storeAndGetID(bi))); 00398 } 00399 if (aatom.tuple[3].address == ID::TERM_BUILTIN_EQ && aatom.tuple[4].isVariableTerm()) { 00400 BuiltinAtom bi(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_BUILTIN); 00401 bi.tuple.push_back(ID::termFromBuiltin(ID::TERM_BUILTIN_INT)); 00402 bi.tuple.push_back(aatom.tuple[4]); 00403 keyRule.body.push_back(ID::posLiteralFromAtom(reg->batoms.storeAndGetID(bi))); 00404 } 00405 continue; 00406 } 00407 if (bb.isExternalAtom()) keyRule.kind |= ID::PROPERTY_RULE_EXTATOMS; 00408 keyRule.body.push_back(bb); 00409 } 00410 } 00411 00412 { 00413 // Construct the external atom predicate input by a rule of the following type: 00414 // A. Single head atom 00415 // 1. create a unique predicate name p 00416 // 2. for all variables X from the disjunction of the symbolic set: 00417 // if X occurs also in the remaining body of the rule 00418 // add it to the tuple of p 00419 // 3. add all variables of the symbolic set to the tuple of p 00420 // 4. (optional) for encoding "extbl", add index of the current symbolic set element 00421 // 5. (optional) for encoding "extbl", add all variables in the conjunction of the symbolic set 00422 // B. The body consists of: 00423 // 1. the conjunction of the symbolic set 00424 // 2. key head 00425 00426 // construct the input rule 00427 DBGLOG(DBG, "Constructing input rule"); 00428 OrdinaryAtom oatom(ID::MAINKIND_ATOM | ID::PROPERTY_AUX); 00429 if (symbolicSetVarsIntersectingRemainingBody.size() > 0 || currentSymbolicSetVars.size() > 0) { 00430 oatom.kind |= ID::SUBKIND_ATOM_ORDINARYN; 00431 } 00432 else { 00433 oatom.kind |= ID::SUBKIND_ATOM_ORDINARYG; 00434 } 00435 // A. 00436 DBGLOG(DBG, "Constructing input rule head"); 00437 // 1. 00438 oatom.tuple.push_back(inputPredID); 00439 // 2. 00440 BOOST_FOREACH (ID var, symbolicSetVarsIntersectingRemainingBody) { 00441 oatom.tuple.push_back(var); 00442 } 00443 // 3. 00444 for (int i = 0; i < currentSymbolicSetVars.size(); i++) { 00445 oatom.tuple.push_back(currentSymbolicSetVars[i]); 00446 } 00447 // extbl 00448 if (ctxdata.mode == AggregatePlugin::CtxData::ExtBlRewrite) { 00449 // 4. 00450 DBGLOG(DBG, "Adding symbolic set index " << symbSetIndex); 00451 oatom.tuple.push_back(ID::termFromInteger(symbSetIndex)); 00452 // 5. 00453 DBGLOG(DBG, "Adding " << symSetVars[symbSetIndex].size() << " variables to input rule head"); 00454 for (std::set<ID>::iterator it = symSetVars[symbSetIndex].begin(); it != symSetVars[symbSetIndex].end(); ++it) { 00455 DBGLOG(DBG, "Adding " << printToString<RawPrinter>(*it, reg)); 00456 oatom.tuple.push_back(*it); 00457 } 00458 oatom.tuple.push_back(ID::termFromInteger(symSetVars[symbSetIndex].size())); 00459 } 00460 inputRule.head.push_back(reg->storeOrdinaryAtom(oatom)); 00461 // B. 00462 DBGLOG(DBG, "Constructing input rule body"); 00463 // 1. 00464 inputRule.body = currentSymbolicSetLiterals; 00465 BOOST_FOREACH (ID l, currentSymbolicSetLiterals) { 00466 if (l.isExternalAtom()) inputRule.kind |= ID::PROPERTY_RULE_EXTATOMS; 00467 } 00468 // 2. 00469 inputRule.body.push_back(ID::posLiteralFromAtom(keyRule.head[0])); 00470 } 00471 00472 // recursively handle aggregates in the key and the value rule 00473 DBGLOG(DBG, "Recursive call for " << printToString<RawPrinter>(reg->storeRule(keyRule), reg)); 00474 rewriteRule(ctx, edb, idb, keyRule); 00475 DBGLOG(DBG, "Recursive call for " << printToString<RawPrinter>(reg->storeRule(inputRule), reg)); 00476 rewriteRule(ctx, edb, idb, inputRule); 00477 00478 // add reversed key and value rules 00479 if (ctxdata.mode == AggregatePlugin::CtxData::ExtBlRewrite) { 00480 for (int r = 2; r <= 2; ++r){ 00481 const Rule kvrule = (r == 1 ? keyRule : inputRule); 00482 BOOST_FOREACH (ID b, kvrule.body) { 00483 if (!b.isOrdinaryAtom()) { 00484 DBGLOG(DBG, "Skipping non-ordinary literal " << printToString<RawPrinter>(b, reg) << " in reversed rule"); 00485 continue; 00486 } 00487 if (!b.isNaf()) { 00488 Rule rev = kvrule; 00489 rev.body.clear(); 00490 rev.body.push_back(ID::posLiteralFromAtom(rev.head[0])); 00491 rev.head.clear(); 00492 rev.head.push_back(ID::atomFromLiteral(b)); 00493 ID revID = reg->storeRule(rev); 00494 DBGLOG(DBG, "Adding reversed rule " << printToString<RawPrinter>(revID, reg)); 00495 idb.push_back(revID); 00496 }else{ 00497 Rule rev = kvrule; 00498 rev.body.clear(); 00499 rev.head.clear(); 00500 rev.body.push_back(ID::posLiteralFromAtom(rev.head[0])); 00501 rev.body.push_back(ID::nafLiteralFromAtom(ID::atomFromLiteral(b))); 00502 rev.kind |= ID::SUBKIND_RULE_CONSTRAINT; 00503 ID revID = reg->storeRule(rev); 00504 DBGLOG(DBG, "Adding reversed constraint: " << printToString<RawPrinter>(revID, reg)); 00505 idb.push_back(revID); 00506 } 00507 } 00508 } 00509 } 00510 } 00511 assert(symbolicSetSize != -1 && "found aggregate without symbolic set"); 00512 00513 // actual rewriting 00514 DBGLOG(DBG, "Generating new aggregate or external atom"); 00515 ID valueVariable; 00516 // boolean external atoms can only be used for range queries but not if we need the exact value 00517 bool useBooleanEa = (ctxdata.mode == AggregatePlugin::CtxData::ExtBlRewrite); // && (aatom.tuple[0] == ID_FAIL || aatom.tuple[1].address == ID::TERM_BUILTIN_LE || aatom.tuple[1].address == ID::TERM_BUILTIN_LT || aatom.tuple[1].address == ID::TERM_BUILTIN_GE || aatom.tuple[1].address == ID::TERM_BUILTIN_GT) && (aatom.tuple[4] == ID_FAIL || aatom.tuple[3].address == ID::TERM_BUILTIN_LE || aatom.tuple[3].address == ID::TERM_BUILTIN_LT || aatom.tuple[3].address == ID::TERM_BUILTIN_GE || aatom.tuple[3].address == ID::TERM_BUILTIN_GT) ); 00518 bool negate = false; 00519 switch (ctxdata.mode) { 00520 case AggregatePlugin::CtxData::ExtRewrite: 00521 case AggregatePlugin::CtxData::ExtBlRewrite: 00522 { 00523 // Construct the external atom as follows: 00524 // Input is 00525 // i1. the predicate name of the key rule generated above 00526 // i2. the predicate name of the input rule generated above 00527 // i3. (optional) the bounds if both are <= (or missing) and a boolean EA is used 00528 // Output is 00529 // o1. the list of variables determined above in step A2 00530 // o2. (optional) the function value if no boolean EA is used 00531 DBGLOG(DBG, "Constructing aggregate replacing external atom"); 00532 ExternalAtom eaReplacement(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_EXTERNAL); 00533 std::stringstream eaName; 00534 eaName << aggregateFunctionToExternalAtomName(aatom.tuple[2]) << (useBooleanEa ? "bl" : ""); 00535 Term exPred(ID::MAINKIND_TERM | ID::SUBKIND_TERM_CONSTANT, eaName.str()); 00536 eaReplacement.predicate = reg->storeTerm(exPred); 00537 // i1 00538 eaReplacement.inputs.push_back(keyPredID); 00539 // i2 00540 eaReplacement.inputs.push_back(inputPredID); 00541 // i3 00542 if (useBooleanEa){ 00543 if (aatom.tuple[0] != ID_FAIL && aatom.tuple[1].address == ID::TERM_BUILTIN_EQ && aatom.tuple[4] == ID_FAIL) { 00544 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[0]); 00545 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[0]); 00546 } 00547 else if (aatom.tuple[4] != ID_FAIL && aatom.tuple[3].address == ID::TERM_BUILTIN_EQ && aatom.tuple[0] == ID_FAIL) { 00548 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[4]); 00549 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[4]); 00550 } 00551 else if (aatom.tuple[0] != ID_FAIL && aatom.tuple[1].address == ID::TERM_BUILTIN_NE && aatom.tuple[4] == ID_FAIL) { 00552 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[0]); 00553 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[0]); 00554 negate = true; 00555 } 00556 else if (aatom.tuple[4] != ID_FAIL && aatom.tuple[3].address == ID::TERM_BUILTIN_NE && aatom.tuple[0] == ID_FAIL) { 00557 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[4]); 00558 eaReplacement.inputs.push_back(ID::termFromInteger(ID::TERM_BUILTIN_LE)); eaReplacement.inputs.push_back(aatom.tuple[4]); 00559 negate = true; 00560 } 00561 else{ 00562 if (aatom.tuple[0] != ID_FAIL) { eaReplacement.inputs.push_back(ID::termFromInteger(aatom.tuple[1].address)); eaReplacement.inputs.push_back(aatom.tuple[0]); } else { eaReplacement.inputs.push_back(reg->storeConstantTerm("none")); eaReplacement.inputs.push_back(reg->storeConstantTerm("none")); } 00563 if (aatom.tuple[4] != ID_FAIL) { eaReplacement.inputs.push_back(ID::termFromInteger(aatom.tuple[3].address)); eaReplacement.inputs.push_back(aatom.tuple[4]); } else { eaReplacement.inputs.push_back(reg->storeConstantTerm("none")); eaReplacement.inputs.push_back(reg->storeConstantTerm("none")); } 00564 } 00565 } 00566 // o1 00567 BOOST_FOREACH (ID t, symbolicSetVarsIntersectingRemainingBody) { 00568 eaReplacement.tuple.push_back(t); 00569 } 00570 // in case of = comparison, reuse the existing variable 00571 if (aatom.tuple[1].address == ID::TERM_BUILTIN_EQ) valueVariable = aatom.tuple[0]; 00572 else if (aatom.tuple[3].address == ID::TERM_BUILTIN_EQ) valueVariable = aatom.tuple[4]; 00573 else { 00574 std::stringstream var; 00575 var << prefix << aggIndex++; 00576 valueVariable = reg->storeVariableTerm(var.str()); 00577 } 00578 // o2 00579 if (!useBooleanEa) eaReplacement.tuple.push_back(valueVariable); 00580 00581 // store external atom and add its ID to the rule body 00582 newRule.body.push_back(b.isNaf() ^ negate ? ID::nafLiteralFromAtom(reg->eatoms.storeAndGetID(eaReplacement)) : ID::posLiteralFromAtom(reg->eatoms.storeAndGetID(eaReplacement))); 00583 00584 // make the rule know that it contains an external atom 00585 newRule.kind |= ID::PROPERTY_RULE_EXTATOMS; 00586 } 00587 break; 00588 case AggregatePlugin::CtxData::Simplify: 00589 { 00590 // in case of = comparison, reuse the existing variable 00591 if (aatom.tuple[1].address == ID::TERM_BUILTIN_EQ) valueVariable = aatom.tuple[0]; 00592 else if (aatom.tuple[3].address == ID::TERM_BUILTIN_EQ) valueVariable = aatom.tuple[4]; 00593 else { 00594 std::stringstream var; 00595 var << prefix << aggIndex++; 00596 valueVariable = reg->storeVariableTerm(var.str()); 00597 } 00598 00599 DBGLOG(DBG, "Creating simplified atom"); 00600 AggregateAtom simplifiedaatom(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_AGGREGATE); 00601 simplifiedaatom.tuple[0] = valueVariable; 00602 simplifiedaatom.tuple[1] = ID::termFromBuiltin(ID::TERM_BUILTIN_EQ); 00603 simplifiedaatom.tuple[2] = aatom.tuple[2]; 00604 simplifiedaatom.tuple[3] = ID_FAIL; 00605 simplifiedaatom.tuple[4] = ID_FAIL; 00606 00607 DBGLOG(DBG, "Creating aggregate literal"); 00608 OrdinaryAtom oatom(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_ORDINARYN | ID::PROPERTY_AUX); 00609 oatom.tuple.push_back(inputPredID); 00610 DBGLOG(DBG, "Adding body variables shared with symbolic set"); 00611 BOOST_FOREACH (ID var, symbolicSetVarsIntersectingRemainingBody) { 00612 DBGLOG(DBG, "Adding body variable of symbolic set to simplified aggregate: " << printToString<RawPrinter>(var, reg)); 00613 oatom.tuple.push_back(var); 00614 } 00615 DBGLOG(DBG, "Adding variables of the symboic set"); 00616 for (int i = 0; i < symbolicSetSize; i++) { 00617 std::stringstream var; 00618 var << prefix << aggIndex++; 00619 ID varID = reg->storeVariableTerm(var.str()); 00620 DBGLOG(DBG, "Adding symbolic set variable to simplified aggregate: " << printToString<RawPrinter>(varID, reg)); 00621 simplifiedaatom.variables.push_back(varID); 00622 oatom.tuple.push_back(varID); 00623 } 00624 simplifiedaatom.literals.push_back(ID::posLiteralFromAtom(reg->storeOrdinaryAtom(oatom))); 00625 00626 DBGLOG(DBG, "Adding aggregate to rule"); 00627 newRule.body.push_back(b.isNaf() ? ID::nafLiteralFromAtom(reg->aatoms.storeAndGetID(simplifiedaatom)) : ID::posLiteralFromAtom(reg->aatoms.storeAndGetID(simplifiedaatom))); 00628 } 00629 break; 00630 } 00631 00632 // add (at most) two atoms reflecting the original left and right comparator 00633 if (!useBooleanEa){ 00634 if (aatom.tuple[0] != ID_FAIL && aatom.tuple[1].address != ID::TERM_BUILTIN_EQ) { 00635 BuiltinAtom bi(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_BUILTIN); 00636 bi.tuple.push_back(aatom.tuple[1]); 00637 bi.tuple.push_back(aatom.tuple[0]); 00638 warnMaxint(ctx, aatom.tuple[0]); 00639 bi.tuple.push_back(valueVariable); 00640 newRule.body.push_back(ID::posLiteralFromAtom(reg->batoms.storeAndGetID(bi))); 00641 } 00642 if (aatom.tuple[4] != ID_FAIL && aatom.tuple[3].address != ID::TERM_BUILTIN_EQ) { 00643 BuiltinAtom bi(ID::MAINKIND_ATOM | ID::SUBKIND_ATOM_BUILTIN); 00644 bi.tuple.push_back(aatom.tuple[3]); 00645 bi.tuple.push_back(valueVariable); 00646 bi.tuple.push_back(aatom.tuple[4]); 00647 warnMaxint(ctx, aatom.tuple[4]); 00648 newRule.body.push_back(ID::posLiteralFromAtom(reg->batoms.storeAndGetID(bi))); 00649 } 00650 } 00651 } 00652 else { 00653 // take it as it is 00654 newRule.body.push_back(b); 00655 } 00656 } 00657 00658 // add the new rule to the IDB 00659 if (newRule.head.size() == 1 && newRule.body.size() == 0 && newRule.head[0].isOrdinaryGroundAtom()) { 00660 DBGLOG(DBG, "Adding fact " + printToString<RawPrinter>(newRule.head[0], reg)); 00661 edb->setFact(newRule.head[0].address); 00662 } 00663 else { 00664 ID newRuleID = reg->storeRule(newRule); 00665 idb.push_back(newRuleID); 00666 DBGLOG(DBG, "Adding rule " + printToString<RawPrinter>(newRuleID, reg)); 00667 } 00668 } 00669 00670 void AggregateRewriter::prepareRewrittenProgram(InterpretationPtr newEdb, ProgramCtx& ctx) { 00671 // go through all rules 00672 newIdb.clear(); 00673 BOOST_FOREACH (ID rid, ctx.idb) { 00674 rewriteRule(ctx, newEdb, newIdb, ctx.registry()->rules.getByID(rid)); 00675 } 00676 00677 // eliminate duplicates 00678 int shift = 0; 00679 InterpretationPtr intr(new Interpretation(ctx.registry())); 00680 for (int ridx = 0; ridx < newIdb.size() - shift; ){ 00681 // move next actual rule to the current position (respecting skipped rules) 00682 newIdb[ridx] = newIdb[ridx + shift]; 00683 if (intr->getFact(newIdb[ridx].address)){ 00684 // skip rule 00685 shift++; 00686 }else{ 00687 // no duplicate: remember rule 00688 intr->setFact(newIdb[ridx].address); 00689 ridx++; 00690 } 00691 } 00692 newIdb.resize(newIdb.size() - shift); 00693 00694 #ifndef NDEBUG 00695 std::stringstream programstring; 00696 RawPrinter printer(programstring, ctx.registry()); 00697 BOOST_FOREACH (ID ruleId, newIdb) { 00698 printer.print(ruleId); 00699 programstring << std::endl; 00700 } 00701 DBGLOG(DBG, "Aggregate-free rewritten program:" << std::endl << programstring.str()); 00702 #endif 00703 } 00704 00705 void AggregateRewriter::rewrite(ProgramCtx& ctx) { 00706 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 00707 if (ctxdata.enabled) { 00708 DBGLOG(DBG, "Aggregates are enabled -> rewrite program"); 00709 InterpretationPtr newEdb(new Interpretation(ctx.registry())); 00710 if (!!ctx.edb) newEdb->add(*ctx.edb); 00711 prepareRewrittenProgram(newEdb, ctx); 00712 ctx.edb = newEdb; 00713 ctx.idb = newIdb; 00714 } 00715 else { 00716 // plugin disabled: the program must not contain aggregates in this case 00717 DBGLOG(DBG, "Aggregates are disabled -> checking if program does not contain any"); 00718 BOOST_FOREACH (ID ruleID, ctx.idb) { 00719 const Rule& rule = ctx.registry()->rules.getByID(ruleID); 00720 BOOST_FOREACH (ID b, rule.body) { 00721 if (b.isAggregateAtom()) { 00722 throw GeneralError("Aggregates have been disabled but rule\n \"" + printToString<RawPrinter>(ruleID, ctx.registry()) + "\"\ncontains \"" + printToString<RawPrinter>(b, ctx.registry()) + "\""); 00723 } 00724 } 00725 } 00726 } 00727 } 00728 00729 } // anonymous namespace 00730 00731 00732 // rewrite program 00733 PluginRewriterPtr AggregatePlugin::createRewriter(ProgramCtx& ctx) 00734 { 00735 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 00736 // if( !ctxdata.enabled ) 00737 // return PluginRewriterPtr(); 00738 00739 // Always create the rewriter! It will internall check if the plugin is enabled; if not, then the rewriter checks if the program does not contain aggregates. 00740 return PluginRewriterPtr(new AggregateRewriter(ctxdata)); 00741 } 00742 00743 00744 // register auxiliary printer for strong negation auxiliaries 00745 void AggregatePlugin::setupProgramCtx(ProgramCtx& ctx) 00746 { 00747 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 00748 if( !ctxdata.enabled ) 00749 return; 00750 00751 RegistryPtr reg = ctx.registry(); 00752 } 00753 00754 00755 namespace 00756 { 00757 00758 class AggAtom : public PluginAtom 00759 { 00760 protected: 00761 bool booleanAtom; 00762 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) = 0; 00763 00764 std::string getName(std::string aggFunction, bool booleanAtom) { 00765 std::stringstream ss; 00766 ss << aggFunction << (booleanAtom ? "bl" : ""); 00767 return ss.str(); 00768 } 00769 00770 public: 00771 00772 AggAtom(std::string aggFunction, bool booleanAtom = false) 00773 : PluginAtom(getName(aggFunction, booleanAtom), false), booleanAtom(booleanAtom) { 00774 prop.variableOutputArity = true; 00775 00776 addInputPredicate(); 00777 addInputPredicate(); 00778 if (booleanAtom){ 00779 prop.providesPartialAnswer = true; 00780 addInputConstant(); 00781 addInputConstant(); 00782 addInputConstant(); 00783 addInputConstant(); 00784 } 00785 00786 setOutputArity(1); 00787 } 00788 00789 virtual std::vector<Query> 00790 splitQuery(const Query& query, const ExtSourceProperties& prop) { 00791 std::vector<Query> atomicQueries; 00792 00793 // we can answer the query separately for each key 00794 00795 // go through all input atoms 00796 bm::bvector<>::enumerator en = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().first(); 00797 bm::bvector<>::enumerator en_end = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().end(); 00798 00799 // collect all keys for which we need to evaluate the aggregate function 00800 int arity = -1; 00801 while (en < en_end) { 00802 const OrdinaryAtom& oatom = registry->ogatoms.getByAddress(*en); 00803 00804 // extract the key of this atom 00805 Tuple key; 00806 key.clear(); 00807 // key atom 00808 if (oatom.tuple[0] == query.input[0]) { 00809 // elements >= 1 are the key 00810 for (int i = 1; i < oatom.tuple.size(); ++i) { 00811 key.push_back(oatom.tuple[i]); 00812 } 00813 00814 // now shrink the interpretations to the according value atoms 00815 Query sub = query; 00816 InterpretationPtr subIntr(new Interpretation(query.ctx->registry())); 00817 InterpretationPtr subAssigned; 00818 if (!!query.assigned) subAssigned.reset(new Interpretation(query.ctx->registry())); 00819 sub.interpretation = subIntr; 00820 sub.assigned = subAssigned; 00821 bm::bvector<>::enumerator en2 = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().first(); 00822 bm::bvector<>::enumerator en2_end = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().end(); 00823 while (en < en_end) { 00824 const OrdinaryAtom& oatom = registry->ogatoms.getByAddress(*en2); 00825 00826 if (oatom.tuple[0] == query.input[1]) { 00827 // does it belong to this key? 00828 bool match = true; 00829 for (int i = 1; i < oatom.tuple.size(); ++i) { 00830 if (key[i - 1] != oatom.tuple[i]) { 00831 match = false; 00832 break; 00833 } 00834 } 00835 if (match) { 00836 subIntr->setFact(query.interpretation->getFact(*en)); 00837 if (!!query.assigned) subAssigned->setFact(query.assigned->getFact(*en)); 00838 } 00839 } 00840 00841 en++; 00842 } 00843 00844 atomicQueries.push_back(query); 00845 } 00846 00847 en++; 00848 } 00849 00850 return atomicQueries; 00851 } 00852 00853 virtual void 00854 retrieve(const Query& query, Answer& answer) throw (PluginError) { 00855 Registry ®istry = *getRegistry(); 00856 00857 // extract all keys 00858 bm::bvector<>::enumerator en = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().first(); 00859 bm::bvector<>::enumerator en_end = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().end(); 00860 std::vector<Tuple> keys; 00861 int arity = -1; 00862 while (en < en_end) { 00863 const OrdinaryAtom& oatom = registry.ogatoms.getByAddress(*en); 00864 00865 // for a key atom: 00866 if (oatom.tuple[0] == query.input[0]) { 00867 // take the first "arity" terms 00868 Tuple key; 00869 for (int i = 1; i < oatom.tuple.size(); ++i) { 00870 key.push_back(oatom.tuple[i]); 00871 } 00872 keys.push_back(key); 00873 // all key atoms must have the same arity 00874 assert (arity == -1 || arity == key.size()); 00875 arity = key.size(); 00876 } 00877 00878 en++; 00879 } 00880 00881 // go through all value atoms 00882 en = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().first(); 00883 en_end = query.ctx->registry()->eatoms.getByID(query.eatomID).getPredicateInputMask()->getStorage().end(); 00884 00885 boost::unordered_map<Tuple, std::vector<Tuple> > trueTuples, mightBeTrueTuples; 00886 while (en < en_end) { 00887 const OrdinaryAtom& oatom = registry.ogatoms.getByAddress(*en); 00888 00889 // for a value input atom: 00890 if (oatom.tuple[0] == query.input[1]) { 00891 // if there is a value atom, then there must also be a key atom and the arity must be known 00892 assert (arity != -1); 00893 00894 // take the first "arity" terms 00895 Tuple key; 00896 for (int i = 1; i <= arity; ++i) { 00897 key.push_back(oatom.tuple[i]); 00898 } 00899 00900 // encoding "extbl": value has the form [key,value,substitution_of_all_variables,symbolic_set_element_idx,count_of_all_variables]; the last two elements need to be stripped off 00901 Tuple value; 00902 for (uint32_t j = arity + 1; j < oatom.tuple.size(); ++j) { 00903 if (booleanAtom && (j == oatom.tuple.size() - (2 + oatom.tuple[oatom.tuple.size() - 1].address))) break; 00904 value.push_back(oatom.tuple[j]); 00905 } 00906 if ((!query.assigned || query.assigned->getFact(*en)) && query.interpretation->getFact(*en)){ 00907 // remove from mightBeTrue if present (might happen with encoding "extbl" since multiple value atoms can contain the same actual value) 00908 if (booleanAtom) { 00909 std::vector<Tuple>::iterator it = std::find(mightBeTrueTuples[key].begin(), mightBeTrueTuples[key].end(), value); 00910 if (it != mightBeTrueTuples[key].end()) mightBeTrueTuples[key].erase(it); 00911 } 00912 trueTuples[key].push_back(value); 00913 }else if (!!query.assigned && !query.assigned->getFact(*en)){ 00914 // skip if present in true (might happen with encoding "extbl" since multiple value atoms can contain the same actual value) 00915 if (!booleanAtom || (std::find(trueTuples[key].begin(), trueTuples[key].end(), value) == trueTuples[key].end())) { 00916 mightBeTrueTuples[key].push_back(value); 00917 } 00918 } 00919 } 00920 00921 en++; 00922 } 00923 00924 // compute for each key in tuples the aggregate function 00925 typedef std::pair<Tuple, std::vector<Tuple> > Pair; 00926 BOOST_FOREACH (Tuple key, keys) { 00927 bool def = false; 00928 uint32_t minFunctionValue = 0; 00929 uint32_t maxFunctionValue = 0; 00930 compute(trueTuples[key], mightBeTrueTuples[key], &minFunctionValue, &maxFunctionValue, &def); 00931 00932 // output 00933 if (def) { 00934 Tuple result = key; 00935 if (booleanAtom){ 00936 int lowerBound = 0; 00937 int upperBound = -2; // means infinity 00938 00939 if (query.input[2].isTerm() && query.input[2].isIntegerTerm()){ 00940 DBGLOG(DBG, "Aggregate has a left bound"); 00941 if (query.input[2].address == ID::TERM_BUILTIN_LE) lowerBound = query.input[3].address > lowerBound ? query.input[3].address : lowerBound; 00942 if (query.input[2].address == ID::TERM_BUILTIN_LT) lowerBound = query.input[3].address + 1 > lowerBound ? query.input[3].address + 1 : lowerBound; 00943 if (query.input[2].address == ID::TERM_BUILTIN_GE) upperBound = (query.input[3].address < upperBound || upperBound == -2) ? query.input[3].address : upperBound; 00944 if (query.input[2].address == ID::TERM_BUILTIN_GT) upperBound = (query.input[3].address - 1 < upperBound || upperBound == -2) ? query.input[3].address - 1 : upperBound; 00945 } 00946 if (query.input[4].isTerm() && query.input[4].isIntegerTerm()){ 00947 DBGLOG(DBG, "Aggregate has a right bound"); 00948 if (query.input[4].address == ID::TERM_BUILTIN_GE) lowerBound = query.input[5].address > lowerBound ? query.input[5].address : lowerBound; 00949 if (query.input[4].address == ID::TERM_BUILTIN_GT) lowerBound = query.input[5].address + 1 > lowerBound ? query.input[5].address + 1 : lowerBound; 00950 if (query.input[4].address == ID::TERM_BUILTIN_LE) upperBound = (query.input[5].address < upperBound || upperBound == -2) ? query.input[5].address : upperBound; 00951 if (query.input[4].address == ID::TERM_BUILTIN_LT) upperBound = (query.input[5].address - 1 < upperBound || upperBound == -2) ? query.input[5].address - 1 : upperBound; 00952 } 00953 00954 #ifndef NDEBUG 00955 std::stringstream ss; 00956 ss << "true:"; 00957 BOOST_FOREACH (Tuple t, trueTuples[key]){ ss << " " << t[0].address; } 00958 ss << ", can be true:"; 00959 BOOST_FOREACH (Tuple t, mightBeTrueTuples[key]){ ss << " " << t[0].address; } 00960 if (upperBound == -2){ 00961 DBGLOG(DBG, "Aggregate atom returned possible range [" << minFunctionValue << ", " << maxFunctionValue << "]; range query is [" << lowerBound << ", infinity] with input " << ss.str()); 00962 }else{ 00963 DBGLOG(DBG, "Aggregate atom returned possible range [" << minFunctionValue << ", " << maxFunctionValue << "]; range query is [" << lowerBound << ", " << upperBound << "] with input " << ss.str()); 00964 } 00965 #endif 00966 if (minFunctionValue >= lowerBound && (upperBound == -2 || maxFunctionValue <= upperBound)){ 00967 DBGLOG(DBG, "Aggregate is true"); 00968 answer.get().push_back(result); 00969 }else if (maxFunctionValue < lowerBound || (upperBound != -2 && minFunctionValue > upperBound)){ 00970 DBGLOG(DBG, "Aggregate is false"); 00971 }else{ 00972 DBGLOG(DBG, "Aggregate is unknown"); 00973 answer.getUnknown().push_back(result); 00974 } 00975 }else{ 00976 assert (minFunctionValue == maxFunctionValue && "Non-boolean aggregates must deliver a definite value"); 00977 result.push_back(ID::termFromInteger(minFunctionValue)); 00978 answer.get().push_back(result); 00979 } 00980 } 00981 } 00982 } 00983 }; 00984 00985 class MaxAtom : public AggAtom 00986 { 00987 private: 00988 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 00989 00990 *defined = false; 00991 *minFunctionValue = 0; 00992 *maxFunctionValue = 0; 00993 BOOST_FOREACH (Tuple t, trueInput) { 00994 *minFunctionValue = t[0].address > *minFunctionValue ? t[0].address : *minFunctionValue; 00995 *defined = true; 00996 } 00997 *maxFunctionValue = *minFunctionValue; 00998 BOOST_FOREACH (Tuple t, mightBeTrueInput) { 00999 *maxFunctionValue = t[0].address > *maxFunctionValue ? t[0].address : *maxFunctionValue; 01000 *defined = true; 01001 } 01002 } 01003 01004 public: 01005 MaxAtom(bool booleanAtom = false) : AggAtom("max", booleanAtom) {} 01006 }; 01007 01008 class MinAtom : public AggAtom 01009 { 01010 private: 01011 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 01012 01013 *defined = false; 01014 *minFunctionValue = 0; 01015 *maxFunctionValue = 0; 01016 BOOST_FOREACH (Tuple t, trueInput) { 01017 *maxFunctionValue = t[0].address < *maxFunctionValue || !(*defined) ? t[0].address : *maxFunctionValue; 01018 *defined = true; 01019 } 01020 *minFunctionValue = *maxFunctionValue; 01021 BOOST_FOREACH (Tuple t, mightBeTrueInput) { 01022 *minFunctionValue = t[0].address < *minFunctionValue || !(*defined) ? t[0].address : *minFunctionValue; 01023 *defined = true; 01024 } 01025 } 01026 01027 public: 01028 MinAtom(bool booleanAtom = false) : AggAtom("min", booleanAtom) {} 01029 }; 01030 01031 class SumAtom : public AggAtom 01032 { 01033 private: 01034 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 01035 01036 *defined = true; 01037 *minFunctionValue = 0; 01038 *maxFunctionValue = 0; 01039 int nfv = 0; 01040 int xfv = 0; 01041 BOOST_FOREACH (Tuple t, trueInput) { 01042 if (t[0].isConstantTerm()){ 01043 nfv--; 01044 xfv--; 01045 }else{ 01046 nfv += t[0].address; 01047 xfv += t[0].address; 01048 } 01049 } 01050 BOOST_FOREACH (Tuple t, mightBeTrueInput) { 01051 if (t[0].isConstantTerm()){ 01052 nfv--; 01053 }else{ 01054 xfv += t[0].address; 01055 } 01056 } 01057 *minFunctionValue = (nfv >= 0 ? nfv : 0); 01058 *maxFunctionValue = (xfv >= 0 ? xfv : 0); 01059 } 01060 01061 public: 01062 SumAtom(bool booleanAtom = false) : AggAtom("sum", booleanAtom) {} 01063 }; 01064 01065 class TimesAtom : public AggAtom 01066 { 01067 private: 01068 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 01069 01070 *defined = false; 01071 *minFunctionValue = 1; 01072 *maxFunctionValue = 1; 01073 BOOST_FOREACH (Tuple t, trueInput) { 01074 *minFunctionValue *= t[0].address; 01075 *maxFunctionValue *= t[0].address; 01076 *defined = true; 01077 } 01078 BOOST_FOREACH (Tuple t, mightBeTrueInput) { 01079 if (t[0].address == 0) *minFunctionValue = 0; 01080 if (t[0].address != 0) *maxFunctionValue *= t[0].address; 01081 *defined = true; 01082 } 01083 } 01084 01085 public: 01086 TimesAtom(bool booleanAtom = false) : AggAtom("times", booleanAtom) {} 01087 }; 01088 01089 class AvgAtom : public AggAtom 01090 { 01091 private: 01092 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 01093 01094 *defined = false; 01095 *minFunctionValue = 0; 01096 *maxFunctionValue = 0; 01097 int cnt = 0; 01098 BOOST_FOREACH (Tuple t, trueInput) { 01099 *minFunctionValue += t[0].address; 01100 *maxFunctionValue += t[0].address; 01101 cnt++; 01102 *defined = true; 01103 } 01104 01105 int smallest = -1; 01106 int largest = -1; 01107 int smallestcnt = 0; 01108 int largestcnt = 0; 01109 BOOST_FOREACH (Tuple t, mightBeTrueInput) { 01110 if (t[0].address == smallest) { smallestcnt++; } 01111 if (t[0].address == largest) { largestcnt++; } 01112 if (smallest == -1 || t[0].address < smallest) { smallest = t[0].address; smallestcnt = 1; } 01113 if (largest == -1 || t[0].address > largest) { largest = t[0].address; largestcnt = 1; } 01114 } 01115 if (*defined) { 01116 if (smallest != -1) { 01117 if ((*minFunctionValue + smallest * smallestcnt) / (cnt + smallestcnt) < (*minFunctionValue / cnt)){ 01118 *minFunctionValue += smallest * smallestcnt; 01119 cnt += smallestcnt; 01120 } 01121 } 01122 if (largest != -1) { 01123 if ((*maxFunctionValue + largest * largestcnt) / (cnt + largestcnt) > (*maxFunctionValue / cnt)){ 01124 *maxFunctionValue += largest * largestcnt; 01125 cnt += largestcnt; 01126 } 01127 } 01128 *minFunctionValue /= cnt; 01129 *maxFunctionValue /= cnt; 01130 }else{ 01131 if (smallest != -1) *minFunctionValue = smallest; 01132 if (largest != -1) *maxFunctionValue = largest; 01133 } 01134 } 01135 01136 public: 01137 AvgAtom(bool booleanAtom = false) : AggAtom("avg", booleanAtom) {} 01138 }; 01139 01140 class CountAtom : public AggAtom 01141 { 01142 private: 01143 virtual std::string aggFunction(){ return "count"; } 01144 01145 virtual void compute(const std::vector<Tuple>& trueInput, const std::vector<Tuple>& mightBeTrueInput, unsigned int* minFunctionValue, unsigned int* maxFunctionValue, bool* defined) { 01146 01147 *defined = true; 01148 *minFunctionValue = trueInput.size(); 01149 *maxFunctionValue = trueInput.size() + mightBeTrueInput.size(); 01150 } 01151 01152 public: 01153 CountAtom(bool booleanAtom = false) : AggAtom("count", booleanAtom) {} 01154 }; 01155 01156 } 01157 01158 01159 std::vector<PluginAtomPtr> AggregatePlugin::createAtoms(ProgramCtx& ctx) const 01160 { 01161 std::vector<PluginAtomPtr> ret; 01162 01163 // we have to do the program rewriting already here because it creates some side information that we need 01164 AggregatePlugin::CtxData& ctxdata = ctx.getPluginData<AggregatePlugin>(); 01165 01166 // return smart pointer with deleter (i.e., delete code compiled into this plugin) 01167 DBGLOG(DBG, "Adding aggregate external atoms"); 01168 ret.push_back(PluginAtomPtr(new MaxAtom(), PluginPtrDeleter<PluginAtom>())); 01169 ret.push_back(PluginAtomPtr(new MinAtom(), PluginPtrDeleter<PluginAtom>())); 01170 ret.push_back(PluginAtomPtr(new SumAtom(), PluginPtrDeleter<PluginAtom>())); 01171 ret.push_back(PluginAtomPtr(new TimesAtom(), PluginPtrDeleter<PluginAtom>())); 01172 ret.push_back(PluginAtomPtr(new AvgAtom(), PluginPtrDeleter<PluginAtom>())); 01173 ret.push_back(PluginAtomPtr(new CountAtom(), PluginPtrDeleter<PluginAtom>())); 01174 ret.push_back(PluginAtomPtr(new MaxAtom(true), PluginPtrDeleter<PluginAtom>())); 01175 ret.push_back(PluginAtomPtr(new MinAtom(true), PluginPtrDeleter<PluginAtom>())); 01176 ret.push_back(PluginAtomPtr(new SumAtom(true), PluginPtrDeleter<PluginAtom>())); 01177 ret.push_back(PluginAtomPtr(new TimesAtom(true), PluginPtrDeleter<PluginAtom>())); 01178 ret.push_back(PluginAtomPtr(new AvgAtom(true), PluginPtrDeleter<PluginAtom>())); 01179 ret.push_back(PluginAtomPtr(new CountAtom(true), PluginPtrDeleter<PluginAtom>())); 01180 01181 return ret; 01182 } 01183 01184 01185 DLVHEX_NAMESPACE_END 01186 01187 // this would be the code to use this plugin as a "real" plugin in a .so file 01188 // but we directly use it in dlvhex.cpp 01189 #if 0 01190 AggregatePlugin theAggregatePlugin; 01191 01192 // return plain C type s.t. all compilers and linkers will like this code 01193 extern "C" 01194 void * PLUGINIMPORTFUNCTION() 01195 { 01196 return reinterpret_cast<void*>(& DLVHEX_NAMESPACE theAggregatePlugin); 01197 } 01198 #endif 01199 01200 01201 // vim:expandtab:ts=4:sw=4: 01202 // mode: C++ 01203 // End: