OILS / vendor / souffle / io / ReadStreamSQLite.h View on Github | oils.pub

205 lines, 124 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2021, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file ReadStreamSQLite.h
12 *
13 ***********************************************************************/
14
15#pragma once
16
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/ReadStream.h"
#include "souffle/utility/MiscUtil.h"
#include "souffle/utility/StringUtil.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <sqlite3.h>
32
33namespace souffle {
34
35class ReadStreamSQLite : public ReadStream {
36public:
37 ReadStreamSQLite(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
38 RecordTable& recordTable)
39 : ReadStream(rwOperation, symbolTable, recordTable), dbFilename(getFileName(rwOperation)),
40 relationName(rwOperation.at("name")) {
41 openDB();
42 checkTableExists();
43 prepareSelectStatement();
44 }
45
46 ~ReadStreamSQLite() override {
47 sqlite3_finalize(selectStatement);
48 sqlite3_close(db);
49 }
50
51protected:
52 /**
53 * Read and return the next tuple.
54 *
55 * Returns nullptr if no tuple was readable.
56 * @return
57 */
58 Own<RamDomain[]> readNextTuple() override {
59 if (sqlite3_step(selectStatement) != SQLITE_ROW) {
60 return nullptr;
61 }
62
63 Own<RamDomain[]> tuple = mk<RamDomain[]>(arity + auxiliaryArity);
64
65 uint32_t column;
66 for (column = 0; column < arity; column++) {
67 std::string element;
68 if (0 == sqlite3_column_bytes(selectStatement, column)) {
69 element = "";
70 } else {
71 element = reinterpret_cast<const char*>(sqlite3_column_text(selectStatement, column));
72
73 if (element.empty()) {
74 element = "";
75 }
76 }
77
78 try {
79 auto&& ty = typeAttributes.at(column);
80 switch (ty[0]) {
81 case 's': tuple[column] = symbolTable.encode(element); break;
82 case 'f': tuple[column] = ramBitCast(RamFloatFromString(element)); break;
83 case 'i':
84 case 'u':
85 case 'r': tuple[column] = RamSignedFromString(element); break;
86 default: fatal("invalid type attribute: `%c`", ty[0]);
87 }
88 } catch (...) {
89 std::stringstream errorMessage;
90 errorMessage << "Error converting number in column " << (column) + 1;
91 throw std::invalid_argument(errorMessage.str());
92 }
93 }
94
95 return tuple;
96 }
97
98 void executeSQL(const std::string& sql) {
99 assert(db && "Database connection is closed");
100
101 char* errorMessage = nullptr;
102 /* Execute SQL statement */
103 int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errorMessage);
104 if (rc != SQLITE_OK) {
105 std::stringstream error;
106 error << "SQLite error in sqlite3_exec: " << sqlite3_errmsg(db) << "\n";
107 error << "SQL error: " << errorMessage << "\n";
108 error << "SQL: " << sql << "\n";
109 sqlite3_free(errorMessage);
110 throw std::invalid_argument(error.str());
111 }
112 }
113
114 void throwError(const std::string& message) {
115 std::stringstream error;
116 error << message << sqlite3_errmsg(db) << "\n";
117 throw std::invalid_argument(error.str());
118 }
119
120 void prepareSelectStatement() {
121 std::stringstream selectSQL;
122 selectSQL << "SELECT * FROM '" << relationName << "'";
123 const char* tail = nullptr;
124 if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &selectStatement, &tail) != SQLITE_OK) {
125 throwError("SQLite error in sqlite3_prepare_v2: ");
126 }
127 }
128
129 void openDB() {
130 sqlite3_config(SQLITE_CONFIG_URI, 1);
131 if (sqlite3_open(dbFilename.c_str(), &db) != SQLITE_OK) {
132 throwError("SQLite error in sqlite3_open: ");
133 }
134 sqlite3_extended_result_codes(db, 1);
135 executeSQL("PRAGMA synchronous = OFF");
136 executeSQL("PRAGMA journal_mode = MEMORY");
137 }
138
139 void checkTableExists() {
140 sqlite3_stmt* tableStatement;
141 std::stringstream selectSQL;
142 selectSQL << "SELECT count(*) FROM sqlite_master WHERE type IN ('table', 'view') AND ";
143 selectSQL << " name = '" << relationName << "';";
144 const char* tail = nullptr;
145
146 if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &tableStatement, &tail) != SQLITE_OK) {
147 throwError("SQLite error in sqlite3_prepare_v2: ");
148 }
149
150 if (sqlite3_step(tableStatement) == SQLITE_ROW) {
151 int count = sqlite3_column_int(tableStatement, 0);
152 if (count > 0) {
153 sqlite3_finalize(tableStatement);
154 return;
155 }
156 }
157 sqlite3_finalize(tableStatement);
158 throw std::invalid_argument(
159 "Required table or view does not exist in " + dbFilename + " for relation " + relationName);
160 }
161
162 /**
163 * Return given filename or construct from relation name.
164 * Default name is [configured path]/[relation name].sqlite
165 *
166 * @param rwOperation map of IO configuration options
167 * @return input filename
168 */
169 static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
170 // legacy support for SQLite prior to 2020-03-18
171 // convert dbname to filename
172 auto name = getOr(rwOperation, "dbname", rwOperation.at("name") + ".sqlite");
173 name = getOr(rwOperation, "filename", name);
174
175 if (name.rfind("file:", 0) == 0 || name.rfind(":memory:", 0) == 0) {
176 return name;
177 }
178
179 if (name.front() != '/') {
180 name = getOr(rwOperation, "fact-dir", ".") + "/" + name;
181 }
182 return name;
183 }
184
185 const std::string dbFilename;
186 const std::string relationName;
187 sqlite3_stmt* selectStatement = nullptr;
188 sqlite3* db = nullptr;
189};
190
191class ReadSQLiteFactory : public ReadStreamFactory {
192public:
193 Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
194 RecordTable& recordTable) override {
195 return mk<ReadStreamSQLite>(rwOperation, symbolTable, recordTable);
196 }
197
198 const std::string& getName() const override {
199 static const std::string name = "sqlite";
200 return name;
201 }
202 ~ReadSQLiteFactory() override = default;
203};
204
205} /* namespace souffle */