#include #include #include #include #include #include #include #include "tbb/parallel_for.h" #include "tbb/parallel_for_each.h" #include "tbb/global_control.h" #include "tbb/concurrent_hash_map.h" using namespace std; using namespace std::chrono; using namespace tbb; using Tokens = vector; using Corpus = vector; using Indices = vector; using Map = concurrent_hash_map; typedef struct { int threads; } Args; /// Split a string into tokens Tokens split_string(const string& str, const char* delim = " ") { Tokens tokens; size_t end = 0; for(size_t start = 0; end != string::npos; start = end + 1) { end = str.find(delim, start); tokens.push_back(end == string::npos ? str.substr(start) : str.substr(start, end - start)); } return tokens; } /// Read bags of words from a file Corpus read_words(const char* path) { ifstream file(path); string line; Corpus data; while(getline(file, line)) { data.push_back(split_string(line)); } return data; } /// For each topic, find documents containing all the words from that topic vector match(const Corpus& documents, const Corpus& topics) { Map word_docs; parallel_for(0ul, documents.size(), [&](auto i) { for(auto& w: documents[i]) { Map::accessor acc; word_docs.insert(acc, w); acc->second.push_back(i); } }); parallel_for_each(word_docs, [](auto& kv) { sort(kv.second.begin(), kv.second.end()); }); vector rslt(topics.size()); Indices empty; parallel_for(0ul, topics.size(), [&](auto i) { auto& x = rslt[i]; size_t j = 0; for(auto& w: topics[i]) { auto range = word_docs.equal_range(w); auto& y = range.first == range.second ? empty : range.first->second; if(j++ == 0) { x = y; } else { Indices buf(min(x.size(), y.size())); auto end = set_intersection(x.begin(), x.end(), y.begin(), y.end(), buf.begin()); x = Indices(buf.begin(), end); } } }); return rslt; } /// Parse command line arguments (no validity checks are done) Args parse_args(int argc, char* argv[]) { Args args; args.threads = 1; int c; while((c = getopt(argc, argv, "n:")) != -1) { switch(c) { case 'n': args.threads = atoi(optarg); break; default: exit(1); } } return args; } int main(int argc, char* argv[]) { // Parse command line arguments auto args = parse_args(argc, argv); // Load data auto topics = read_words("topics.txt"); auto documents = read_words("documents.txt"); // Limit maximum number of threads (including master thread) global_control gc(global_control::max_allowed_parallelism, args.threads); // Match data auto start = system_clock::now(); auto rslt = match(documents, topics); auto end = system_clock::now(); auto elapsed = duration(end - start).count(); printf("Elapsed time: %0.2fs\n", elapsed); // Print a few summary statistics size_t s1 = 0; size_t s2 = 0; for(auto& d: rslt) { s1 += d.size() > 0 ? 1 : 0; s2 = max(s2, d.size()); } printf("%i\n%i\n", s1, s2); return 0; }