#include "DBInt.h"
#include <cstdlib>
#include <sstream>

// Start a diagnostic line on cerr prefixed with "DBInt:<line>: ".
// __LINE__ expands where ERROR is used, so the number is the caller's line.
// Usage: ERROR << "message" << endl;
#define ERROR cerr << "DBInt:" << __LINE__ << ": "

// Default constructor: starts with no MySQL handle. Init() must succeed
// before any query method is used.
DBInt::DBInt()
    : fMysql(NULL)
{
    // nothing else to set up
}

// Construct and immediately connect through Init().
// Init() is called unconditionally and only its result is asserted.
// Previously the call lived inside assert(), so NDEBUG builds never
// connected at all.
DBInt::DBInt(const char* db, const char* host, const char* user, 
             const char* password)
    :fMysql(0)
{
    const bool connected = this->Init(db,host,user,password);
    assert(connected);
    (void)connected;   // avoid an unused-variable warning when NDEBUG is set
}

// Destructor: releases the connection handle through Close().
DBInt::~DBInt()
{
    Close();
}

// Close the MySQL connection, if one is open, and forget the handle.
// This runs from the destructor even for never-connected objects, so the
// handle is only passed to mysql_close() when it is non-null. The null
// case is not documented as valid for that API. Calling Close() repeatedly
// is therefore harmless.
void DBInt::Close(void)
{
    if (fMysql) {
        mysql_close(fMysql);
        fMysql = 0;
    }
}

bool DBInt::Init(const char* db, const char* host, const char* user, 
                 const char* password)
{
    if (fMysql) this->Close();

    if (!host) host = "localhost";
    if (!user) user = getenv("USER");

    fMysql = mysql_init(0);
    if (!fMysql) {
        ERROR << "Can't initialize MySQL\n";
        return false;
    }
    else {
//        ERROR << "mysql " << user << "@" << host << ":" << db << " accessed\n";
    }
    if (!mysql_real_connect(fMysql,host,user,password,db,0,0,0)) {
        ERROR << "Can't connect to MySQL db " << db << ", "
              << user << "@" << host << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        this->Close();
        return false;
    }
    return true;
}

// Run "DESCRIBE <table>;" and append the first column of every returned row
// (the column name) to `fields`. `fields` is appended to, never cleared.
// Returns false after logging the MySQL error if the query or the result
// retrieval fails.
bool DBInt::GetFields(string table, vector<string>& fields)
{
    stringstream query;
    query << "DESCRIBE " << table << ";";
    const string sql = query.str();

    if (mysql_query(fMysql, sql.c_str()) != 0) {
        ERROR << "Failed to describe table " << table << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    MYSQL_RES* result = mysql_store_result(fMysql);
    if (result == 0) {
        ERROR << "Failed to get result from describe table " << table << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    for (MYSQL_ROW r = mysql_fetch_row(result); r; r = mysql_fetch_row(result))
        fields.push_back(r[0]);

    mysql_free_result(result);
    return true;
}

bool DBInt::GetNthData(int id, int nth,
                       string table_name, string what, string order_by,
                       vector<string>& data)
{
    stringstream ss;

    ss.str("");
    ss << "SELECT " << what << " FROM " << table_name 
       << " WHERE id = " << id;
    if (order_by.size()) ss << " ORDER BY " << order_by;
    ss << ";";

    if (mysql_query(fMysql,ss.str().c_str())) {
        ERROR << "Failed to query table " << table_name 
              << " for id #" << id << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    MYSQL_RES* res;
    if (!(res = mysql_store_result(fMysql))) {
        ERROR << "Failed to store query of table " << table_name 
              << " at id #" << id << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    mysql_data_seek(res,nth);
    ss.str("");
    MYSQL_ROW row = mysql_fetch_row(res);
    int siz = mysql_num_fields(res);
    for (int i=0; i<siz; ++i) {
        data.push_back(row[i]);
    }
    mysql_free_result(res);

    return true;
}

// Run "SELECT <what> FROM <table_name> [WHERE id = <id>] [ORDER BY <order_by>];"
// and append each result row to `data` as a vector of column strings.
// A negative id omits the WHERE clause, so all rows are returned.
// SQL NULL columns are appended as empty strings.
// Returns false after logging the MySQL error if the query or the result
// retrieval fails.
bool DBInt::GetData(int id, string table_name, string what, string order_by,
                    vector<vector<string> >& data)
{
    stringstream ss;
    ss << "SELECT " << what << " FROM " << table_name;
    if (id >= 0) ss << " WHERE id = " << id;
    if (order_by.size()) ss << " ORDER BY " << order_by;
    ss << ";";

    if (mysql_query(fMysql,ss.str().c_str())) {
        ERROR << "Failed to query data for table " << table_name 
              << " for id #" << id << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    MYSQL_RES* res;
    if (!(res = mysql_store_result(fMysql))) {
        ERROR << "Failed to get result data for table " << table_name 
              << " with id = #" << id << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return false;
    }

    MYSQL_ROW row;
    const unsigned int siz = mysql_num_fields(res);
    while ((row = mysql_fetch_row(res))) {
        vector<string> vs;
        vs.reserve(siz);
        for (unsigned int i=0; i<siz; ++i) {
            // SQL NULL comes back as a null pointer, and building a
            // std::string from one is undefined behaviour.
            vs.push_back(row[i] ? row[i] : "");
        }
        data.push_back(vs);
    }
    mysql_free_result(res);

    return true;
}

bool DBInt::SetData(int id, string table_name, string what, 
                    vector<string>& data)
{
    stringstream ss;

    ss.str("");
    ss << "DELETE FROM " << table_name << " WHERE id = " << id << ";";
    mysql_query(fMysql,ss.str().c_str());

    int siz = data.size();
    for (int i = 0; i < siz; ++i) {
        ss.str("");
        ss << "INSERT INTO " << table_name
           << " (id," << what << ") "
           << " VALUES( " << id << "," << data[i] << ");";
        if (mysql_query(fMysql,ss.str().c_str())) {
            ERROR << "Failed to insert into table " << table_name 
                  << " at id " << id << " and index " << i << "the values:\n"
                  << data[i] << endl;
            ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
            return false;
        }
    }
    return true;


}

// Return the id from the first row of
// "SELECT id FROM <table_name> WHERE <query>;", or 0 on any failure
// (query error, no result set, no matching row, or a NULL id).
// `query` is spliced into the WHERE clause verbatim.
int DBInt::GetId(string table_name, string query)
{
    stringstream ss;
    ss << "SELECT id FROM " << table_name << " WHERE " << query << ";";

    if (mysql_query(fMysql,ss.str().c_str())) {
        ERROR << "Failed to query id for table " << table_name
              << " using query: \"" << query << "\"\n";
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return 0;
    }

    MYSQL_RES* res;
    if (!(res = mysql_store_result(fMysql))) {
        ERROR << "Failed to get id for table " << table_name 
              << " using query: \"" << query << "\"\n";
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        return 0;
    }

    MYSQL_ROW row = mysql_fetch_row(res);
    if (!row) {
        ERROR << "Failed to get row from result of querying table "
              << table_name << " with query:\n"
              << query << endl;
        ERROR << "MySQL error: " << mysql_error(fMysql) << endl;
        mysql_free_result(res);   // this path previously leaked the result set
        return 0;
    }
    // A NULL id comes back as a null pointer, and atoi() on one is undefined.
    const int id = row[0] ? atoi(row[0]) : 0;
    mysql_free_result(res);
    return id;
}

// Build an SQL condition " <name> > lo and <name> < hi " matching values
// within about +/-1% of `value`. When value is exactly 0 the bracket is
// +/-0.01, because a relative bracket around zero would be empty.
// For negative values value*0.99 is larger than value*1.01, so the two
// bounds are ordered explicitly. Otherwise the condition could never match.
std::string bracket_double(const char* name, double value)
{
    double lo;
    double hi;
    if (value == 0.0) {
        lo = -0.01;
        hi = +0.01;
    }
    else {
        lo = value*0.99;
        hi = value*1.01;
        if (lo > hi) {
            const double tmp = lo;
            lo = hi;
            hi = tmp;
        }
    }

    std::stringstream ss;
    ss << " " << name << " > " << lo
       << " and " << name << " < " << hi << " ";
    return ss.str();
}
