Logo Search packages:      
Sourcecode: bayonne version File versions  Download package

postgres.cpp

// Copyright (C) 2000 Open Source Telecom Corporation.
//  
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
// 
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software 
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

#include <server.h>
#include <cc++/process.h>

#ifdef      HAVE_PGSQL_POSTGRES
#define     DLLIMPORT
#define     HAVE_NAMESPACE_STD
#define     HAVE_CXX_STRING_HEADER
#include <pgsql/libpq++.h>
#else
#include <libpq++.h>
#endif

#ifdef      CCXX_NAMESPACES
namespace ost {
using namespace std;
#endif

// Script symbol names this module creates/updates on each trunk.
// connect() creates them; loader() fills in rows/cols/error per query.
// sql.driver: set to "postgres" by connect(); detach()/loader() check it.
#define     SYM_SQLDRIVER     "sql.driver"
// sql.rows: row count of the last result (or affected-row count for commands).
#define     SYM_ROWS    "sql.rows"
// sql.cols: column count of the last result set.
#define     SYM_COLS    "sql.cols"
// sql.database: the configured "database" key, set as a constant by connect().
#define     SYM_DATABASE      "sql.database"
// sql.error: empty on success, otherwise an error tag or the server message.
#define     SYM_SQLERROR      "sql.error"

class PostgresTrunk : public TrunkImage
{
private:
      friend class PostgresModule;

      // Instances are only created by PostgresModule::getXML().
      PostgresTrunk();
      ~PostgresTrunk();

      // Executes the query stored in data->load.url against the shared
      // connection and compiles any result set into script sections.
      bool loader(Trunk *trk, trunkdata_t *data);

      // Returns a pager-allocated copy of temp without trailing whitespace.
      char *mystr(const char *temp);

      // XML parser callbacks: intentionally empty, the image content is
      // produced by loader() rather than parsed.
      void characters(const unsigned char *text, unsigned len)
            {}
      void startElement(const unsigned char *name, const unsigned char **attrib)
            {}
      void endElement(const unsigned char *name)
            {}
};

class PostgresModule : private Module, public Keydata, public Mutex
{
private:
      friend class PostgresTrunk;

      unsigned count;         // references taken by connect(), released by detach()
      PgDatabase *dbase;      // shared connection; NULL until connect() opens it

      // Module identification / factory hooks.
      modtype_t getType(void)
            { return MODULE_SQL; }
      char *getName(void)
            { return "sql"; }
      TrunkImage *getXML(void)
            { return (TrunkImage *)new PostgresTrunk; }

      // Prepares a query for the trunk; returns an error tag or NULL.
      char *dispatch(Trunk *trunk);

      // Reference-counted access to the shared connection.
      void connect(Trunk *trunk);
      void detach(Trunk *trunk);

public:
      PostgresModule();
};

// The single module instance; its constructor registers it via addModule().
PostgresModule pgsql;


PostgresModule::PostgresModule() : Module(), Keydata("/bayonne/sql"), Mutex()
{
      // Keys read from /bayonne/sql together with their defaults.
      static Keydata::Define keydefs[] = {
            {"database", "bayonne"},
            {NULL, NULL}};

      // Optional config keys that are exported to the process environment
      // (PGHOST / PGPORT), overwriting any existing value.
      static const char *envmap[][2] = {
            {"host", "PGHOST"},
            {"port", "PGPORT"}};

      slog(Slog::levelDebug) << "load: postgres module" << endl;
      load(keydefs);
      driver->addModule(this);
      addSession();

      dbase = NULL;
      count = 0;

      for(unsigned idx = 0; idx < sizeof(envmap) / sizeof(envmap[0]); ++idx)
      {
            const char *setting = getLast(envmap[idx][0]);

            if(setting)
                  Process::setEnv(envmap[idx][1], setting, true);
      }

      slog(Slog::levelDebug) << "sql: loading postgres driver" << endl;
}

// Takes a reference on the shared database connection for `trunk`, opening
// the connection if it does not exist yet, and creates the sql.* symbols on
// the trunk (sql.driver is set to "postgres", which detach() checks later).
void PostgresModule::connect(Trunk *trunk)
{
      char buf[256];
      const char *dbname, *user, *password;
      size_t len;

      enterMutex();
      if(!count || !dbase)
      {
            dbname = getLast("database");
            user = getLast("user");
            password = getLast("password");

            // Build the libpq conninfo string one optional field at a time:
            // a password configured without a user must not format a NULL
            // user name into the string.
            snprintf(buf, sizeof(buf), "dbname=%s", dbname);
            if(user)
            {
                  len = strlen(buf);
                  snprintf(buf + len, sizeof(buf) - len, " user=%s", user);
            }
            if(password)
            {
                  len = strlen(buf);
                  snprintf(buf + len, sizeof(buf) - len, " password=%s", password);
            }

            dbase = new PgDatabase(buf);
            slog(Slog::levelDebug) << "sql: connecting database" << endl;

            // The object is kept either way (queries will then report the
            // failure), but make the failure visible in the log right away.
            if(dbase->ConnectionBad())
                  slog(Slog::levelError) << "sql: " << dbase->ErrorMessage() << endl;
      }
      ++count;
      leaveMutex();

      trunk->setConst(SYM_SQLDRIVER, "postgres");
      trunk->setConst(SYM_DATABASE, getLast("database"));
      trunk->setSymbol(SYM_ROWS, 10);
      trunk->setSymbol(SYM_COLS, 10);
      trunk->setSymbol(SYM_SQLERROR, 64);
}

// Releases the connection reference held by `trunk` (only trunks tagged
// sql.driver == "postgres" by connect() hold one). The shared connection is
// closed when the last reference goes away.
void PostgresModule::detach(Trunk *trunk)
{
      const char *cp = trunk->getSymbol(SYM_SQLDRIVER);

      if(!cp || stricmp(cp, "postgres"))
            return;

      enterMutex();
      // Never decrement past zero: a wrapped unsigned count could never
      // reach zero again, so the connection would stay open forever.
      if(count && !--count)
      {
            if(dbase)
            {
                  delete dbase;
                  dbase = NULL;
            }

            slog(Slog::levelDebug) << "sql: disconnecting database" << endl;
      }
      leaveMutex();
}

// Writes `value` into buf (capacity `size`, NUL included) as a single-quoted
// SQL string literal. Embedded single quotes and backslashes are doubled, so
// the value cannot terminate the literal early. This holds whether or not the
// server treats backslash as an escape character. The closing quote is
// always written; if the value does not fit, it is truncated instead.
// A NULL value is written as ''.
// Returns the number of characters written, not counting the NUL.
static unsigned pgQuoteValue(char *buf, unsigned size, const char *value)
{
      unsigned pos = 0;

      if(size < 3)      // no room for '' plus the terminating NUL
      {
            if(size)
                  buf[0] = 0;
            return 0;
      }

      buf[pos++] = '\'';
      while(value && *value)
      {
            unsigned need = (*value == '\'' || *value == '\\') ? 2 : 1;

            // keep room for the closing quote and the NUL
            if(pos + need + 1 >= size)
                  break;
            if(need == 2)
                  buf[pos++] = *value;
            buf[pos++] = *value++;
      }
      buf[pos++] = '\'';
      buf[pos] = 0;
      return pos;
}

// Prepares the SQL statement for the current script line and stores the
// request in the trunk's load data for PostgresTrunk::loader() to run.
// The statement is chosen as follows:
//  - a "query" keyword is used verbatim;
//  - for the "insert" member, an INSERT is built from "%symbol" and
//    "=column value" arguments;
//  - otherwise the line's values are concatenated.
// Returns an error tag, or NULL on success.
char *PostgresModule::dispatch(Trunk *trunk)
{
      trunkdata_t *data = getData(trunk);
      const char *key, *sql = trunk->getKeyword("query");
      const char *mem = trunk->getMember();
      const char *value;
      char *table = trunk->getKeyword("table"), *opt, *tag, *cp;
      char cols[256], vals[256];
      Line *line = trunk->getScript();
      unsigned len = 0, clen = 0, vlen = 0, argc = 0;

      key = trunk->getKeyword("maxTime");
      if(!key)
            key = "60s";
      if(!mem)
            mem = "none";

      if(!sql && !stricmp(mem, "insert"))
      {
            if(!table)
                  table = trunk->getValue(NULL);
            if(!table)
                  return "insert-table-missing";

            while(argc < line->argc && clen < sizeof(cols) - 1 && vlen < sizeof(vals) - 1)
            {
                  opt = line->args[argc++];
                  if(*opt == '%')
                  {
                        tag = ++opt;
                        value = trunk->getSymbol(opt);
                  }
                  else if(*opt == '=')
                  {
                        // A keyword's value is always the following argument.
                        if(argc >= line->argc)
                              break;
                        if(!stricmp(opt, "=table") || !stricmp(opt, "=maxTime"))
                        {
                              // Skip the value too, so e.g. "=table %name"
                              // never turns "name" into a column.
                              ++argc;
                              continue;
                        }
                        tag = ++opt;
                        value = trunk->getContent(line->args[argc++]);
                  }
                  else
                        continue;

                  if(clen)
                        cols[clen++] = ',';
                  if(vlen)
                        vals[vlen++] = ',';
                  snprintf(cols + clen, sizeof(cols) - clen, "%s", tag);
                  clen = strlen(cols);
                  // Values come from script symbols (possibly caller input):
                  // quote and escape them instead of pasting them raw.
                  vlen += pgQuoteValue(vals + vlen, sizeof(vals) - vlen, value);
            }
            cols[clen] = 0;
            vals[vlen] = 0;
            snprintf(data->load.filepath, 250,
                  "insert into %s (%s) values (%s)", table, cols, vals);
      }
      else if(!sql)
      {
            // Start from an empty buffer so a line without values never
            // re-executes whatever a previous request left in filepath.
            data->load.filepath[0] = 0;
            while(len < 256 && NULL != (cp = trunk->getValue(NULL)))
            {
                  snprintf(data->load.filepath + len, 256 - len, "%s", cp);
                  len = strlen(data->load.filepath);
            }
      }
      if(!sql)
            sql = data->load.filepath;

      slog(Slog::levelDebug) << "sql: " << sql << endl;
      data->load.attach = false;
      data->load.post = false;
      data->load.section = "";
      data->load.fail = NULL;
      data->load.timeout = getSecTimeout(key);
      data->load.parent = NULL;
      data->load.gosub = false;
      data->load.url = sql;
      data->load.vars = NULL;
      data->load.userid[0] = 0;

      // Take a connection reference only once per trunk, like loader()
      // does. detach() releases exactly one reference per trunk, so
      // connecting on every dispatch kept the shared connection open forever.
      if(!trunk->getSymbol(SYM_SQLDRIVER))
            pgsql.connect(trunk);
      return NULL;
}

// Nothing to set up beyond the TrunkImage base; instances come from
// PostgresModule::getXML().
PostgresTrunk::PostgresTrunk() :
TrunkImage()
{}

// purge() presumably releases the pager allocations made through
// MemPager::alloc (e.g. by mystr) — confirm against TrunkImage.
PostgresTrunk::~PostgresTrunk()
{
      purge();
}

// Runs the statement in data->load.url on the shared connection.
// - Sets sql.rows, sql.cols and sql.error.
// - For a result set, compiles the column names into "#header" and the rows
//   into "#sql", then selects "#sql" as the trunk's data section.
// Returns false on any error, with sql.error set.
bool PostgresTrunk::loader(Trunk *trunk, trunkdata_t *data)
{
      ExecStatusType status;
      const char *sql = data->load.url;
      const char *dvr = trunk->getSymbol(SYM_SQLDRIVER);
      unsigned rows, cols, row, col;
      const char **argv;
      char val[16];
      const char *errmsg;

      // First SQL use on this trunk: take a connection reference, then
      // re-read the driver symbol connect() just created. Without the
      // re-read, dvr stayed NULL and was passed straight to stricmp().
      if(!dvr)
      {
            pgsql.connect(trunk);
            dvr = trunk->getSymbol(SYM_SQLDRIVER);
      }

      trunk->setSymbol(SYM_ROWS, "0");
      trunk->setSymbol(SYM_COLS, "0");
      trunk->setSymbol(SYM_SQLERROR, "");

      if(!dvr || stricmp(dvr, "postgres"))
      {
            trunk->setSymbol(SYM_SQLERROR, "invalid-driver");
            return false;
      }

      pgsql.enterMutex();
      if(!pgsql.dbase)
      {
            trunk->setSymbol(SYM_SQLERROR, "no-database");
            pgsql.leaveMutex();
            return false;
      }

      status = pgsql.dbase->Exec(sql);

      // Every case that returns unlocks the mutex exactly once.
      switch(status)
      {
      case PGRES_FATAL_ERROR:
            errmsg = pgsql.dbase->ErrorMessage();
            slog(Slog::levelCritical) << "sql: " << errmsg << endl;
            trunk->setSymbol(SYM_SQLERROR, errmsg);
            pgsql.leaveMutex();
            return false;
      case PGRES_NONFATAL_ERROR:
      case PGRES_BAD_RESPONSE:
            errmsg = pgsql.dbase->ErrorMessage();
            slog(Slog::levelError) << "sql: " << errmsg << endl;
            trunk->setSymbol(SYM_SQLERROR, errmsg);
            pgsql.leaveMutex();
            return false;
      case PGRES_EMPTY_QUERY:
            // Falls through. This case used to unlock here and then again
            // below, i.e. it unlocked the mutex twice.
      case PGRES_COMMAND_OK:
      case PGRES_COPY_OUT:
      case PGRES_COPY_IN:
            snprintf(val, sizeof(val), "%d", pgsql.dbase->CmdTuples());
            trunk->setSymbol(SYM_ROWS, val);
            pgsql.leaveMutex();
            return true;
      case PGRES_TUPLES_OK:
      default:
            break;      // a result set: compile it below
      }

      rows = pgsql.dbase->Tuples();
      cols = pgsql.dbase->Fields();

      getCompile("#header");
      argv = (const char **)MemPager::alloc(sizeof(char *) * (cols + 1));
      col = 0;
      while(col < cols)
      {
            argv[col] = mystr(pgsql.dbase->FieldName(col));
            ++col;
      }
      argv[col] = NULL;
      addCompile(0, "data", argv);

      putCompile(main);

      getCompile("#sql");

      row = 0;
      while(row < rows)
      {
            col = 0;
            while(col < cols)
            {
                  argv[col] = mystr(pgsql.dbase->GetValue(row, col));
                  ++col;
            }
            argv[col] = NULL;
            addCompile(0, "data", argv);
            ++row;
      }

      putCompile(current);

      pgsql.leaveMutex();
      trunk->setData("#sql");
      snprintf(val, sizeof(val), "%u", rows);
      trunk->setSymbol(SYM_ROWS, val);
      snprintf(val, sizeof(val), "%u", cols);
      trunk->setSymbol(SYM_COLS, val);
      return true;
}

// Returns a NUL-terminated copy of temp, minus trailing whitespace,
// allocated from this image's pager memory. A NULL input yields an empty
// string that is also pager-allocated, so callers always get writable storage.
char *PostgresTrunk::mystr(const char *temp)
{
      char *nt;
      unsigned len = 0;

      if(temp)
      {
            len = strlen(temp);
            // Cast to unsigned char: passing a negative char (e.g. a UTF-8
            // byte from the database) to isspace() is undefined behavior.
            while(len && isspace((unsigned char)temp[len - 1]))
                  --len;
      }

      nt = (char *)MemPager::alloc(len + 1);
      if(len)
            memcpy(nt, temp, len);
      nt[len] = 0;
      return nt;
}

#ifdef      CCXX_NAMESPACES
};
#endif

Generated by  Doxygen 1.6.0   Back to index