---------------------------------------------------------------------------------------------------------------------------------------------------
Program Listing for:  ComSocket.cpp
Project:  comsocket
Namespace:  c++
----------------------------------------------------------------------------------------------------------------------------------------------------

// ComSocket.cpp : Implementation of CComSocket
#include "stdafx.h"
#include "Sock_atl.h"
#include "comdef.h"
#include "atlconv.h"
#include "comutil.h"

#include "ComSocket.h"

#include <windows.h>
#include <winsock.h>

/////////////////////////////////////////////////////////////////////////////
// CComSocket

// Process-wide Winsock state shared by every CComSocket object.
// wsaData is filled in by WSAStartup (see sockinit_a); wsa_inited is set to 1
// by sockinit_a once WSAStartup has succeeded, so Winsock is started only once.
WSADATA wsaData;
int wsa_inited(0);

// Constructor: make sure Winsock has been started (via sockinit_a; its result
// is not checked here) and put the object into its idle state: no socket
// handle (-1), not connected, and no receive error recorded.
CComSocket::CComSocket() {

	sockinit_a();

	// connection state
	this->nsocket = -1;
	this->bconnected = false;

	// receive-error state reported to callers
	this->breceive_error = false;
	this->sreceive_error_text = "";
}

// Destructor: best-effort close of the socket if one is still open.
// sockclose always hands back a newly allocated BSTR through its out
// parameter (empty on success, an error message otherwise). We own that
// string and must free it; the previous version leaked it on every
// destruction. The close result itself is deliberately ignored because there
// is nobody left to report it to.
CComSocket::~CComSocket() {

	BSTR dummy = NULL;
	sockclose(&dummy);

	// SysFreeString accepts NULL, so this is safe even if nothing was returned
	::SysFreeString(dummy);
}

// sockclose: close the socket if one is open.
// Callers mostly reach this through IDispatch (FoxPro, VBScript), where an
// S_FALSE return cannot be seen. The string written to *pVal therefore carries
// the outcome: empty means the socket was closed, non-empty is the error text.
// The HRESULT is still S_OK on success and S_FALSE on either failure path
// ("invalid socket." when nothing is open, or a failed closesocket call).
STDMETHODIMP CComSocket::sockclose(BSTR *pVal)
{
	CComBSTR message;
	message = "";
	HRESULT hr = S_FALSE;

	if (nsocket == -1) {
		// nothing to close; an already closed socket is not closed again
		message = "invalid socket.";
	} else if (closesocket(nsocket) != 0) {
		message = "Error with socket close.";
	} else {
		// mark the object as having no socket
		nsocket = -1;
		this->bconnected = false;
		hr = S_OK;
	}

	*pVal = message.Detach();
	return hr;
}

// alltrim_a: strip trailing whitespace from a wide (BSTR) string.
// Despite the FoxPro-style name, this trims only the END of the string, and
// only the characters ' ', '\r' and '\n'. Leading whitespace and tabs are
// kept. The C++ standard library has no trim function for narrow or wide
// strings, so the work is done with std::wstring::find_last_not_of.
//
// inVal  : string to trim. A NULL BSTR is a valid empty string in COM and is
//          treated as "".
// outVal : receives a newly allocated BSTR that the caller owns.
// Returns S_OK, or E_POINTER if outVal is NULL.
STDMETHODIMP CComSocket::alltrim_a(BSTR inVal, BSTR *outVal)
{
	if (outVal == NULL)
		return E_POINTER;

	// constructing std::wstring from a NULL pointer would crash
	std::wstring text(inVal != NULL ? inVal : L"");

	std::wstring::size_type last = text.find_last_not_of(L" \r\n");
	if (last == std::wstring::npos)
		text.clear();            // empty, or nothing but trimmable characters
	else
		text.erase(last + 1);    // drop everything after the last kept char

	CComBSTR str_back;
	str_back = text.c_str();
	*outVal = str_back.Detach();

	return S_OK;
}

// this is the standard connection piece, tries to connect to a dotted address
// or urlname, at given port, empty string returned means success, otherwise the
// string returned is the error
STDMETHODIMP CComSocket::sockconnect(BSTR *pURL, int port, BSTR *pResult)
{

         struct sockaddr_in a;
         struct hostent *h;
         CComBSTR bstr_resp;
         bstr_resp = "";
         bool is_bad = false;
         int s;

         this->breceive_error = false;
         this->sreceive_error_text = "";

         std::wstring in_string(*pURL);

         TString in_string2;

         //TString is a typedef of my own in the header files:
         //typedef std::basic_string<_TCHAR> TString;

         // convert strings to right format
         h = gethostbyname( _com_util:: ConvertBSTRToString( *pURL)  );
         in_string2 = _com_util:: ConvertBSTRToString( *pURL) ;

         // the code below does another try for a ipv4 x.x.x.x format string
         // however I think gethostbyname now actually works with this in a normal 
         // format so the below code may be unnecessary. also there is a better way to do this
         // via inet_addr function I think.

         // i'm using an ifdef as a way to comment this stuff out, for java if (false) is a good
         // construct but not for vc++ lol. (java if (false) will not compile the code but ignore it)

#ifdef IPV4TEST

         if (false) { //h==NULL) {

                  // try connecting to IP string if failure to connect to NAME
                  int iPeer[4] ;
                  iPeer[3] = iPeer[2] = iPeer[1] = iPeer[0] = 0;

                  //in_string2 = alltrim(TString(in_string));
                  int npos;
                  int ncnt = 0;
                  int nlast = 0;

                  while (1) {

                           if (nlast > in_string2.length() - 1) {

                             is_bad = true;
                             break;

                           }
                           npos = in_string2.find(".", nlast);
                           if (npos == std::wstring::npos) {

                                    if (ncnt < 3) {
                                             is_bad = true;
                                             break;
                                    } else {
                                             iPeer[ncnt++] = atoi( in_string2.substr(nlast).c_str() );
                                             if (iPeer[ncnt] > 255) {
                                                      is_bad = true;
                                                      break;
                                             }
                                    }

                           }

                           else
                                    iPeer[ncnt++] = atoi( in_string2.substr(nlast, npos ).c_str() );

                           nlast = npos + 1;
                           if (ncnt >= 4)
                                    break;

                  }

                  /*  old test code here for old routine
                  char *tests;
                  tests = new char[12];
                  ::itoa(iPeer[0], tests, 10);
                  MessageBox(NULL, tests, "", 0);
                  ::itoa(iPeer[1], tests, 10);
                  MessageBox(NULL, tests, "", 0);
                  ::itoa(iPeer[2], tests, 10);
                  MessageBox(NULL, tests, "", 0);
                  ::itoa(iPeer[3], tests, 10);
                  MessageBox(NULL, tests, "", 0);
                  delete [] tests;

                  if (is_bad)
                           MessageBox(NULL, "is bad", "", 0);
                  */

                  if (! is_bad ) {

                           char cPeer[5] ;
                           cPeer[0] = (unsigned char)iPeer[0] ;
                           cPeer[1] = (unsigned char)iPeer[1] ;
                           cPeer[2] = (unsigned char)iPeer[2] ;
                           cPeer[3] = (unsigned char)iPeer[3] ;
                           cPeer[4] = (char)0 ;

                           
                           // test see if it connects for a string of type "63.230.230.145"
                           h = ::gethostbyaddr( cPeer,
                                    4,
                                    PF_INET);
                      
                            if (h==NULL) {

                                    //WSACleanup();
                                    bstr_resp = "URL connection failed";
                                    *pResult = bstr_resp.Detach();
                                    return S_OK;

                           }
                  } else {

                           bstr_resp = "URL connection failed";
                           *pResult = bstr_resp.Detach();
                           return S_OK;

                  }

         }

#endif

         
         // okay we are back to real executable code here

         a.sin_family = AF_INET;
         a.sin_port = htons(port);
         //a.sin_addr.s_addr = inet_addr(in_string);

         memcpy( &(a.sin_addr.s_addr), h->h_addr, sizeof(int));

         s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
         this->nsocket = -1;

         if (s==0) {
                  //WSACleanup();
                  bstr_resp = "Trouble creating socket";
                  *pResult = bstr_resp.Detach();
                  return S_OK;

         }
         if (connect(s, (struct sockaddr *)&a, sizeof(a))) {

                  //WSACleanup();
                  bstr_resp = "1:Trouble with socket: connecting";
                  *pResult = bstr_resp.Detach();
                  return S_OK;

         }

         u_long utemp = 1;

         // I always set non-blocking on my sockets and poll them via various mechanisms
         // blocking under windows is not the best for what we are trying to do
         if (ioctlsocket (s, FIONBIO, &utemp)) {

                  bstr_resp = "Trouble setting non-blocking";
                  *pResult = bstr_resp.Detach();
                  return S_OK;
                  
         } 

         bstr_resp = "";
         *pResult = bstr_resp.Detach();
         this->nsocket = s;
         this->bconnected = true;
         return S_OK;

}

// socksend: send *send_str over the connected socket.
// The BSTR is converted to an ANSI (char) string, so only the text up to the
// first NUL is sent, and a NULL or empty BSTR sends nothing.
// The string written to *resp is empty on success. Otherwise it is the error
// text; the HRESULT is S_OK either way so IDispatch callers can see it.
// The socket is non-blocking (see sockconnect), so send() may accept only part
// of the data. The loop keeps sending the remainder while send() makes
// progress. If the socket cannot take more data right now (e.g. WSAEWOULDBLOCK),
// the send is reported as an error instead of silently truncating the message.
STDMETHODIMP CComSocket::socksend(BSTR *send_str, BSTR *resp)
{
	CComBSTR bstr_resp;
	bstr_resp = "";

	if (nsocket == -1) {
		bstr_resp = "invalid socket.";
		*resp = bstr_resp.Detach();
		return S_OK;
	}

	// ConvertBSTRToString allocates with new[]; we own and must free it
	char *buffer = _com_util::ConvertBSTRToString(*send_str);
	int total = (buffer != NULL) ? static_cast<int>(strlen(buffer)) : 0;
	int sent = 0;
	bool failed = false;

	while (sent < total) {
		int n = send(nsocket, buffer + sent, total - sent, 0);
		if (n == SOCKET_ERROR) {
			failed = true;
			break;
		}
		sent += n;
	}

	delete[] buffer;

	if (failed)
		bstr_resp = "Trouble with socket: sending";

	*resp = bstr_resp.Detach();
	return S_OK;
}

// sockreceive: poll the non-blocking socket and return whatever data is
// currently buffered, up to receive_chunk bytes, converted from ANSI text to
// a BSTR in *pRecStr.
// *pRecStr is an empty string when:
//   - no socket is open;
//   - nothing is pending (WSAEWOULDBLOCK, not an error);
//   - the peer closed the connection (recv returned 0). bconnected is then
//     cleared so isconnected reports it.
//   - a real socket error occurred. breceive_error / sreceive_error_text are
//     then set and bconnected is cleared.
// The HRESULT is always S_OK.
STDMETHODIMP CComSocket::sockreceive(BSTR *pRecStr)
{
	// 5000 is the chunk size the author measured as best: 1000 lost speed on
	// faster connections.
	// NOTE(review): myout is declared in the header; it must hold at least
	// receive_chunk + 1 chars for the terminator written below. Confirm.
	const int receive_chunk = 5000;

	CComBSTR bstr_resp;
	bstr_resp = "";

	if (nsocket == -1) {
		*pRecStr = bstr_resp.Detach();
		return S_OK;
	}

	int back = recv(nsocket, myout, receive_chunk, 0);

	if (back > 0) {
		// back <= receive_chunk, so the terminator stays inside the buffer
		myout[back] = '\0';
		bstr_resp = myout;
	} else if (back == 0) {
		// orderly shutdown by the peer
		this->bconnected = false;
	} else if (WSAGetLastError() != WSAEWOULDBLOCK) {
		// a genuine error (reset, abort, ...): the connection is unusable
		this->breceive_error = true;
		this->sreceive_error_text = "recv call resulted in SOCKET_ERROR receive code";
		this->bconnected = false;
	}
	// WSAEWOULDBLOCK just means no data is pending on the non-blocking socket

	*pRecStr = bstr_resp.Detach();
	return S_OK;
}

// sockinit_a: start Winsock (requested version 0x101) once for the process.
// Returns true when Winsock is started, either earlier or by this call.
// Returns false when WSAStartup fails. wsa_inited then stays 0, so the next
// call tries again.
// NOTE(review): wsa_inited is a plain global with no locking; confirm objects
// are never constructed concurrently on several threads.
bool CComSocket::sockinit_a(void) {

	if (wsa_inited)
		return true;

	if (WSAStartup(0x101, &wsaData) != 0)
		return false;

	wsa_inited = 1;
	return true;
}

// getsocketid: report the raw Winsock handle held by this object through
// *nCurSocket (-1 when no socket is open). Always returns S_OK.
STDMETHODIMP CComSocket::getsocketid(int *nCurSocket)
{
	*nCurSocket = this->nsocket;
	return S_OK;
}

// isconnected: write TRUE to *bConn when the object's connected flag is set,
// FALSE otherwise. Always returns S_OK.
STDMETHODIMP CComSocket::isconnected(BOOL *bConn)
{
	*bConn = this->bconnected ? TRUE : FALSE;
	return S_OK;
}