/*/////////////////////////////////////////////////////////////////////////////
Name: Bo Bayles
Assignment: CS 285, FTP Project
File: server.cpp
Compile: g++ -o server server.cpp
/*/////////////////////////////////////////////////////////////////////////////

/*/////////////////////////////////////////////////////////////////////////////
    Compiler Directives
/////////////////////////////////////////////////////////////////////////////*/
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <signal.h>
#include <unistd.h>
using namespace std;

#define SERVER_PORT 2221

/*/////////////////////////////////////////////////////////////////////////////
    Global variables
/////////////////////////////////////////////////////////////////////////////*/
volatile sig_atomic_t controlSocket = -1, socketDescriptor = -1;
//Descriptors shared with the SIGINT handler. volatile sig_atomic_t is the
//  object type the standard guarantees a signal handler may safely access.
//  -1 means "no socket yet", so the handler does not mistake the default
//  zero for a real descriptor (fd 0 is stdin).

/*/////////////////////////////////////////////////////////////////////////////
    Function declarations
/////////////////////////////////////////////////////////////////////////////*/
//True if fileName can be opened for reading.
bool checkFile( string &fileName );
//Handle one client's commands on controlSocket, then close it.
void connection(int controlSocket);
//Replace the ciphertext in password with its plaintext, using username as
//  the OpenSSL passphrase.
void decryptPassword(string &username, string &password);
//Encrypt the named file into server_cipher.tmp, using password as the
//  OpenSSL passphrase.
void encryptFile(string &filename, string &password);
//SIGINT handler: notify the client, close the sockets and exit.
void interruptTrap(int signal);
//True if username and password appear together in db.txt.
bool validCredentials(string &username, string &password);
//Fill client_addr with the control connection's peer address and dataPort,
//  and create (but do not connect) dataSocket.
//  Returns 0 on success, -1 on failure.
int makeDataSocket( const int controlSocket,
                    const int dataPort,
                    int &dataSocket,
                    sockaddr_in &client_addr );
//Connect dataSocket to client_addr, send fileName encrypted with password,
//  then close dataSocket. Returns 0 on success, 1 on failure.
int sendFile( int &dataSocket,
              sockaddr_in &client_addr,
              string &fileName,
              string &password );

/*/////////////////////////////////////////////////////////////////////////////
    Function definitions
/////////////////////////////////////////////////////////////////////////////*/
//Entry point: listen on SERVER_PORT and serve clients one at a time, handing
//  each accepted control connection to connection(). Runs until SIGINT
//  (handled by interruptTrap) or a fatal socket error.
int main()
{
  signal(SIGINT, interruptTrap);
  //Register the interrupt signal handler
  signal(SIGPIPE, SIG_IGN);
  //Writing to a client that has already disconnected raises SIGPIPE, whose
  //  default action would kill the whole server. Ignoring it turns that into
  //  an ordinary EPIPE error return from write().

  sockaddr_in serverAddress = {};
  serverAddress.sin_family = AF_INET;
  serverAddress.sin_port = htons(SERVER_PORT);
  serverAddress.sin_addr.s_addr = htonl(INADDR_ANY);
  sockaddr_in clientAddress = {};
  socklen_t client_len;
  //Set up connection parameters with the Berkeley Sockets API. Fields are
  //  set by name because positional {AF_INET, port} initialization relies
  //  on the platform's member order (BSD-derived systems put sin_len first).

  socketDescriptor = socket( AF_INET, SOCK_STREAM, 0 );
  if( socketDescriptor == -1 )
  {
    cout << "Server: socket() failed"  << endl
         << "Exiting...\n" << endl;
    exit(1);
  }
  //Create a stream socket with socket()
  cout << "Server: Socket " << socketDescriptor << " was created." << endl;

  int reuse = 1;
  if( setsockopt( socketDescriptor, SOL_SOCKET, SO_REUSEADDR,
                  &reuse, sizeof(reuse) ) == -1 )
    cout << "Server: setsockopt(SO_REUSEADDR) failed; continuing." << endl;
  //Let a restarted server bind the port again while connections from the
  //  previous run are still in TIME_WAIT

  if( bind( socketDescriptor,
            reinterpret_cast<struct sockaddr*>(&serverAddress),
            sizeof(serverAddress) ) == -1 )
  {
    cout << "Server: bind() failed"  << endl
         << "Exiting...\n" << endl;
    exit(1);
  }
  //Bind the socket to an internet port with bind()
  cout << "Server: bind() was successful." << endl;

  if( listen( socketDescriptor, 1 ) == -1 )
  {
    cout << "Server: listen() failed"  << endl
         << "Exiting...\n" << endl;
    exit(1);
  }
  //Listen for clients with listen()
  cout << "Server: Listening for clients..." << endl;


  while(1)
  //Wait for a client to accept
  {
    client_len = sizeof(clientAddress);
    //accept() overwrites client_len with the actual address size, so it
    //  must be reset before every call
    controlSocket = accept( socketDescriptor,
                            reinterpret_cast<struct sockaddr*>(&clientAddress),
                            &client_len );
    if( controlSocket < 0 )
    {
      if( errno == EINTR || errno == ECONNABORTED )
        continue; //Transient: interrupted, or the client gave up; keep serving
      cout << "Server: accept() failed" << endl
           << "Exiting...\n" << endl;
      exit(1);
    }
    connection(controlSocket);  //Hand off to the connection function
  }

  cout << "Exiting...\n" << endl;
  return 0;
}

//Echo a reply on the server console and send it to the client over the
//  control connection. The trailing '\0' is sent too, as with every other
//  reply this server writes.
static void sendReply(int controlSocket, const string &message)
{
  cout << "Server: " << message << endl;
  write( controlSocket, message.c_str(), message.size()+1 );
}

//Run one client's command session on controlSocket. The session ends when
//  the client closes the connection, a read fails, or a logged-in client
//  sends QUIT. Each command is a 4-letter verb, a separator and an argument:
//    USER <name>    choose a user; logs out any previous user
//    PASS <cipher>  encrypted password, decrypted with <name> as the key
//    PORT <port>    client port to connect to for the next RETR
//    RETR <file>    send <file>, encrypted with the plaintext password
//    QUIT           end the session
//  PORT, RETR and QUIT are accepted only after a successful PASS.
//  Both controlSocket and any open data socket are closed before returning.
void connection(int controlSocket)
{
  bool loggedIn = false;
  char msgBuffer[512];
  int dataSocket = -1;
  int dataPort = 0;
  int k;
  string username, password, tempString;
  struct sockaddr_in client_addr;

  cout << "Server: Client " << controlSocket << " has connected." << endl;

  while(1)
  {
    k = read( controlSocket, msgBuffer, sizeof(msgBuffer) - 1 );
    if( k <= 0 ) //Connection closed (0) or read error (-1)
      break;
    msgBuffer[k] = '\0';
    //read() does not null-terminate, and a client need not send a '\0'.
    //  Terminating explicitly keeps the string operations below in bounds.

    cout << "Client: " << msgBuffer << endl;
    //Report what the client said

/*/////////////////////////////////////////////////////////////////////////////
    USER Command
/////////////////////////////////////////////////////////////////////////////*/
    if( !strncmp(msgBuffer, "USER", 4) )
    {
      tempString = msgBuffer;
      username = tempString.erase(0,5); //Remove the "USER " prefix
      loggedIn = false;  //Log the user out
      password.clear();  //Invalidate the password, just in case.

      sendReply( controlSocket, "300 - Send password for " + username );
    }
/*/////////////////////////////////////////////////////////////////////////////
    PASS Command
/////////////////////////////////////////////////////////////////////////////*/
    else if( !strncmp(msgBuffer, "PASS", 4) )
    {
      string cipherText( msgBuffer, k-1 );
      //The encrypted password may contain null bytes, so copy exactly the
      //  k-1 bytes received (dropping the client's trailing terminator)
      //  instead of stopping at the first '\0'.
      password = cipherText.erase(0,5); //Remove the "PASS " prefix
      decryptPassword(username, password); //Convert the ciphertext to plain

      //NOTE(review): this logs the plaintext password to the console.
      cout << "Server: Plaintext password is '" << password << "'" <<   endl;

      if( validCredentials(username, password) )
      {
        sendReply( controlSocket,
                   "200 - " + username + " logged in successfully." );
        loggedIn = true;
      }
      else
        sendReply( controlSocket, "400 - Password does not match username." );
    }
/*/////////////////////////////////////////////////////////////////////////////
    PORT Command
/////////////////////////////////////////////////////////////////////////////*/
    else if( !strncmp(msgBuffer, "PORT", 4) && loggedIn )
    {
      tempString = msgBuffer;
      tempString.erase(0,5); //Remove the "PORT " prefix
      dataPort = atoi( tempString.c_str() ); //Get the port number

      if( dataSocket >= 0 )
        close(dataSocket);
      dataSocket = -1;
      //A second PORT before RETR would otherwise leak the earlier socket

      if( makeDataSocket(controlSocket, dataPort, dataSocket, client_addr) == 0 )
        sendReply( controlSocket, "200 - Data socket created." );
      else
      {
        dataPort = 0; //Make RETR ask for a new PORT rather than use a dead socket
        sendReply( controlSocket, "400 - Data socket could not be created." );
      }
    }
/*/////////////////////////////////////////////////////////////////////////////
    RETR Command
/////////////////////////////////////////////////////////////////////////////*/
    else if( !strncmp(msgBuffer, "RETR", 4) && loggedIn )
    {
      tempString = msgBuffer;
      tempString.erase(0,5); //Remove the "RETR " prefix
      if( dataPort == 0 )
        sendReply( controlSocket,
                   "400 - Cannot retrieve file without first sending PORT." );
      else if( !checkFile(tempString) )
        sendReply( controlSocket, "400 - File does not exist." );
      else
      {
        sendReply( controlSocket, "200 - File to begin." );

        sleep(2); //Presumably gives the client time to start listening
        int r = sendFile( dataSocket, client_addr, tempString, password );
        dataSocket = -1;
        //sendFile() closes the data socket on every path. Forget the
        //  descriptor so it is not closed a second time below.

        if( r != 1 )
          sendReply( controlSocket, "200 - File transfer completed succesfully." );
        else
          sendReply( controlSocket, "400 - Failure sending file." );
        dataPort = 0;
      }
    }
/*/////////////////////////////////////////////////////////////////////////////
    QUIT Command
/////////////////////////////////////////////////////////////////////////////*/
    else if( !strncmp(msgBuffer, "QUIT", 4) && loggedIn )
    {
      sendReply( controlSocket, "200 - QUIT received. Disconnecting client." );
      break;
    }
/*/////////////////////////////////////////////////////////////////////////////
    Invalid command
/////////////////////////////////////////////////////////////////////////////*/
    else
      sendReply( controlSocket, "400 - Invalid command." );

  } //End infinite loop.

  if( dataSocket >= 0 )
    close(dataSocket); //Close the data connection if one is still open
  close(controlSocket);
}

//Quote text so /bin/sh treats it as one literal word: wrap it in single
//  quotes and write each embedded ' as '\'' .
static string shellQuote(const string &text)
{
  string quoted = "'";
  for( string::size_type i = 0; i < text.size(); ++i )
  {
    if( text[i] == '\'' )
      quoted += "'\\''";
    else
      quoted += text[i];
  }
  quoted += "'";
  return quoted;
}

//Encrypt fileName with DES3 (via the openssl command-line tool) into
//  server_cipher.tmp, using password as the passphrase. If encryption
//  fails, server_cipher.tmp is left absent. It never holds a plaintext copy.
void encryptFile(string &fileName, string &password)
{
  std::remove( "server_cipher.tmp" );
  //Discard leftovers so a failed run cannot leave stale or unencrypted data
  //  for the caller to transmit.

  setenv( "FTP_SERVER_PASS", password.c_str(), 1 );
  string argument = "openssl des3 -in " + shellQuote(fileName) +
                    " -out server_cipher.tmp -pass env:FTP_SERVER_PASS";
  //Call out to OpenSSL, which will do the DES3 encryption for us.
  //  The file name comes from the client, so it is shell-quoted to prevent
  //  command injection. The passphrase is passed through the environment
  //  (OpenSSL treats -pass env:VAR like -k), so it is neither parsed by the
  //  shell nor placed on the command line.
  if( system( argument.c_str() ) != 0 )
    cout << "Server: Encryption of " << fileName << " failed." << endl;
  unsetenv( "FTP_SERVER_PASS" );
}

//Decrypt the ciphertext held in password with DES3 (via the openssl
//  command-line tool), using username as the passphrase. password is
//  replaced by the first line of the plaintext, or by "" if decryption fails.
void decryptPassword(string &username, string &password)
{
  ofstream tempCipher( "server_cipher.tmp",
                       ios::out | ios::trunc | ios::binary );
  tempCipher.write( password.data(), static_cast<streamsize>(password.size()) );
  //Save the raw ciphertext; it may contain any byte, including '\0'
  tempCipher.close();

  std::remove( "server_plain.tmp" );
  //A failed decryption must not pick up an old plaintext file

  setenv( "FTP_SERVER_KEY", username.c_str(), 1 );
  string arg =
    "openssl des3 -d -in server_cipher.tmp -out server_plain.tmp"
    " -pass env:FTP_SERVER_KEY";
  //The user name comes straight from the client. Passing it through the
  //  environment (OpenSSL treats -pass env:VAR like -k) keeps it out of the
  //  shell command, so characters such as ';' or '$(' cannot inject commands.
  cout << "SYSTEM: " << arg << endl;
  system( arg.c_str() );
  unsetenv( "FTP_SERVER_KEY" );

  password.clear();
  ifstream tempPlain( "server_plain.tmp" );
  getline( tempPlain, password );
  //Only the first line is the password. If openssl failed the file is
  //  missing, getline() reads nothing and password stays empty.
  tempPlain.close();

  std::remove( "server_cipher.tmp" );
  std::remove( "server_plain.tmp" );
  //Remove the temporary files
}

//Look username up in db.txt, a whitespace-separated sequence of
//  "username password" pairs. Returns true only if the first entry for
//  username has exactly this password. A missing or unreadable db.txt
//  rejects everyone.
bool validCredentials(string &username, string &password)
{
  bool loggedIn = false;  //Logged out to start
  bool userFound = false;
  string storedUser, storedPass;
  ifstream inFile("db.txt"); //Password database is db.txt

  //Read both fields in the loop condition so the loop stops as soon as a read
  //  fails. Testing good() and then reading would compare a stale token after
  //  end of file. That lets an empty user name with an empty password log in,
  //  and lets a user listed without a password log in with the user name.
  while( inFile >> storedUser >> storedPass )
  {
    if( storedUser != username )
      continue;

    userFound = true;
    if( storedPass == password )
      loggedIn = true;  //Logged in!
    else
      cout << "Wrong password. Nice try, buster." << endl;
      //Good username, bad password.
    break;  //Only the first entry for a user counts
  }

  //NOTE(review): the success message logs the plaintext password.
  if( loggedIn )
    cout << "Server: " << endl << username << " logged in with " << password << endl;
  else if( !userFound )
    cout << endl << "Server: " << username << " not found in database." << endl;

  return loggedIn;
}

//Prepare, but do not connect, a TCP socket for the data connection.
//  client_addr receives the IPv4 address of controlSocket's peer, with the
//  port replaced by dataPort. On success dataSocket holds the new socket and
//  0 is returned. On any failure dataSocket is -1 and -1 is returned.
int makeDataSocket( const int controlSocket,
                    const int dataPort,
                    int &dataSocket,
                    sockaddr_in &client_addr )
{
  cout << "Server: Attempting to connect to client's Data socket " << dataPort << "." << endl;
  dataSocket = -1;

  if( dataPort <= 0 || dataPort > 65535 )
  {
    cout << "Server: Invalid data port " << dataPort << "." << endl;
    return -1;
  }
  //htons() would silently truncate an out-of-range port

  socklen_t length = sizeof(client_addr);
  if( getpeername( controlSocket,
                   reinterpret_cast<struct sockaddr*>(&client_addr),
                   &length ) == -1
      || client_addr.sin_family != AF_INET )
  {
    cout << "Server: Could not determine the client's address." << endl;
    return -1;
  }
  //Reuse the client's address from the control connection...
  client_addr.sin_port = htons(dataPort);
  //...but with the data connection's port

  if( ( dataSocket = socket( AF_INET, SOCK_STREAM, 0 ) ) == -1 )
  {
    cout << "Server: Data socket creation failed." << endl;
    return -1;
  }

  return 0;
}

//Report whether fileName names something this process can open for reading.
bool checkFile( string &fileName )
{
  FILE *probe = fopen( fileName.c_str(), "r" );
  if( probe == NULL )
    return false;  //fopen() failed, e.g. the file is missing or unreadable

  fclose( probe );  //Only probing, so release the handle immediately
  return true;
}

//Write all len bytes of data to fd, retrying after partial writes and
//  signal interruptions. Returns false if write() reports an error.
static bool writeAll(int fd, const char *data, size_t len)
{
  while( len > 0 )
  {
    ssize_t n = write( fd, data, len );
    if( n < 0 )
    {
      if( errno == EINTR )
        continue;
      return false;
    }
    data += n;
    len -= static_cast<size_t>(n);
  }
  return true;
}

//Connect dataSocket to client_addr, encrypt fileName with password into
//  server_cipher.tmp (via encryptFile), and stream that file over the data
//  connection. dataSocket is closed and reset to -1 on every path.
//  Returns 0 on success, 1 on failure.
int sendFile( int &dataSocket, sockaddr_in &client_addr , string &fileName, string &password )
{
  if( connect( dataSocket,
               reinterpret_cast<struct sockaddr*>(&client_addr),
               sizeof(client_addr) ) == -1 )
  {
    cout << "Server: Data connection failed." << endl;
    close(dataSocket);
    dataSocket = -1;
    return 1;
  }
  cout << "Server: Connected to client's data socket" << endl;

/*/////////////////////////////////////////////////////////////////////////////
    File transfer
/////////////////////////////////////////////////////////////////////////////*/
  encryptFile(fileName, password);
  ifstream inFile( "server_cipher.tmp", ios::in | ios::binary );

  int result = 0;
  if( !inFile )
  {
    cout << "Server: Could not read the encrypted file." << endl;
    result = 1;  //Nothing to send, so do not report success
  }
  else
  {
    char buffer[512];
    //Forward the ciphertext chunk by chunk. The last read may be partial, so
    //  send exactly gcount() bytes each time. The data is binary, so no
    //  terminator is added (and none could fit after a full chunk).
    while( inFile.read( buffer, sizeof(buffer) ) || inFile.gcount() > 0 )
    {
      if( !writeAll( dataSocket, buffer,
                     static_cast<size_t>( inFile.gcount() ) ) )
      {
        cout << "Server: Data connection lost while sending." << endl;
        result = 1;
        break;
      }
    }
  }

  inFile.close();
  std::remove( "server_cipher.tmp" );

  close(dataSocket); //Close the data connection
  dataSocket = -1;   //The caller's descriptor is now invalid

  return result;
}

//SIGINT handler: tell the connected client (if any) that the server is
//  shutting down, wait a few seconds, close both sockets and terminate.
//  Only async-signal-safe calls (write, sleep, close, _exit) are used.
//  iostreams, std::string allocation and exit() can deadlock or corrupt
//  state if the signal arrives while the main code is inside them.
void interruptTrap(int signal)
{
  static const char notice[]  = "Server: Interrupt received.\n";
  static const char message[] = "666 - Server is shutting down.";

  write( STDOUT_FILENO, notice, sizeof(notice) - 1 );

  if( controlSocket >= 0 )
    write( controlSocket, message, sizeof(message) );
  //sizeof includes the trailing '\0', matching every other server reply

  sleep(5); //Wait a few seconds.

  close(controlSocket);
  close(socketDescriptor);
  //Close the connections

  _exit( 0 );
}

