/////////////////////////////////////////////////////////////////////
//
// A small utility that checks whether authentication information
// is accepted by an ssh host for any OS close enough to posix
// to offer some needed facilities.
//
// The utility needs to be able to open a pty (pseudo-terminal)
// and to run ssh for the login to the host in a forked child
// process. It needs to intercept signals (SIGCHLD, and SIGWINCH
// for window resizes) as part of the process.
//
// From testing, this can be done on Linux, FreeBSD, NetBSD, MacOS,
// Solaris, and cygwin. Any changes should be tested on at least
// those systems.
//
/////////////////////////////////////////////////////////////////////

// Please adjust the following HAVE_ definitions to appropriately
// reflect the situation in your build environment.
// HAVE_TERMIOS_H is tested with #ifdef, so any definition of it
// (even 0) includes <termios.h>. HAVE_POSIX_OPENPT and HAVE_PSELECT
// are tested with #if, so defining them as 0 (or removing them)
// compiles the fallback implementations further down this file.

#define HAVE_TERMIOS_H 1
#define HAVE_POSIX_OPENPT 1
#define HAVE_PSELECT 1

// includes for posix_openpt
#include <cstdlib>
// Change fd modes
#include <fcntl.h>
// For setsid()
#include <unistd.h>
// For signals
#include <signal.h>
// For pselect
#include <sys/select.h> // most systems
#include <sys/time.h> // Solaris 8
// For struct timespec
#include <time.h>
// For struct winsize
#include <sys/ioctl.h>
// For waitpid
#include <sys/wait.h>
#include <sys/types.h>
#ifdef HAVE_TERMIOS_H
#include <termios.h>
#endif // HAVE_TERMIOS_H
#include <string>
#include <cstring>
#include <cerrno>

// Number of bytes requested from the pty master by each read()
// in check_output().
#define GT_AUTH_BUF_LEN 60


/////////////////////////////////////////////////
// Fallback for platforms that lack
// posix_openpt(): open /dev/ptmx directly,
// passing the caller's flags through to open().
// Only compiled when HAVE_POSIX_OPENPT is 0 or
// undefined. Modelled on the Linux
// posix_openpt(3) man page.
/////////////////////////////////////////////////
#if !(HAVE_POSIX_OPENPT)
int
posix_openpt( int flags )
{
    static const char l_master_device[] = "/dev/ptmx";
    const int l_fd = open( l_master_device, flags );
    return l_fd;
}
#endif // HAVE_POSIX_OPENPT

/////////////////////////////////////////////////
// Solaris 5.8 and earlier do not have pselect.
// This is a work-around that should only be
// used if there is little to no chance of
// race problems.
// pselect is basically equivalent to doing this
// sequence of steps atomically. Here it is not
// atomic: a signal delivered after the mask
// change but before select() starts is handled
// without interrupting select(), which can then
// wait for its full timeout.
/////////////////////////////////////////////////
#if !(HAVE_PSELECT)
// The two functions use different time structures
typedef struct timeval timestruct;
timestruct g_time;

int
pselect(int nfds, fd_set *readfds, fd_set *writefds,
        fd_set *exceptfds, const timestruct *timeout,
        const sigset_t *sigmask)
{
    // select() may update its timeout argument, while pselect() takes
    // it as const. Work on a private copy so the caller's value (e.g.
    // a global reused on every loop iteration) is never shortened.
    struct timeval l_timeout;
    struct timeval* l_timeout_ptr = NULL;
    if( timeout != NULL )
    {
        l_timeout = *timeout;
        l_timeout_ptr = &l_timeout;
    }

    sigset_t origmask;
    if( sigmask != NULL )
    {
        sigprocmask( SIG_SETMASK, sigmask, &origmask );
    }
    int retval = select( nfds, readfds, writefds, exceptfds, l_timeout_ptr );
    // Report select()'s errno (e.g. EINTR), not sigprocmask()'s
    int saved_errno = errno;
    if( sigmask != NULL )
    {
        sigprocmask( SIG_SETMASK, &origmask, NULL );
    }
    errno = saved_errno;

    return retval;
}
#else
typedef struct timespec timestruct;
timestruct g_time;
#endif


/////////////////////////////////////////////////
// Two static globals to hold file descriptors.
// Easiest way for them to be visible to the
// signal handlers.
// Both start at -1 ("no descriptor"). A handler
// that runs before main() assigns them then
// acts on an invalid descriptor, so its call
// simply fails, instead of acting on
// descriptor 0 (stdin).
/////////////////////////////////////////////////
static int g_master_fd = -1;
static int g_tty_fd = -1;


/////////////////////////////////////////////////
// SIGCHLD handler. The body is deliberately
// empty: installing a handler means the signal
// is caught rather than left at its default
// disposition, so a child's exit can interrupt
// a wait that has SIGCHLD unblocked.
/////////////////////////////////////////////////
void
sigchld_handler( int /* signum */ )
{
    // Nothing to do.
}


/////////////////////////////////////////////////
// A handler for SIGWINCH.
// Reads the current window size of the terminal
// (g_tty_fd, TIOCGWINSZ) and writes it to the
// pty master (g_master_fd, TIOCSWINSZ), so the
// program on the slave side sees the new size.
// errno is saved and restored because the code
// this signal interrupted may be about to
// inspect it.
/////////////////////////////////////////////////
void
resize_handler( int signum )
{
    (void)signum;
    int l_saved_errno = errno;

    struct winsize ttysize;
    if( ioctl( g_tty_fd, TIOCGWINSZ, &ttysize ) == 0 )
        ioctl( g_master_fd, TIOCSWINSZ, &ttysize );

    errno = l_saved_errno;
}


/////////////////////////////////////////////////
// Create a pseudo-terminal and own its master
// descriptor, the slave device name and
// (optionally, via make_local_slave) a slave
// descriptor held by this process.
//
// Do not have a SIGCHLD handler installed when
// this is constructed. The behavior of grantpt()
// is undefined if a SIGCHLD handler is
// installed.
//
// ptsname() returns a pointer to static storage
// that later calls overwrite, so the name is
// copied into a std::string straight away. If
// several threads create ptys at once, that copy
// can still race.
//
// Though failures of the pty operations set
// errno, the reason is not kept; pty_created()
// only reports whether everything succeeded.
//
// A descriptor value of -1 means "not open".
// Open descriptors are closed by the destructor,
// and the wrapper cannot be copied, so nothing
// is closed twice.
/////////////////////////////////////////////////
class pty_master_wrapper
{
public:
    pty_master_wrapper()
        : m_pty_created_successfully( false ),
          m_master_fd( -1 ),
          m_pty_device(),
          m_local_slave_fd( -1 )
    {
        m_master_fd = posix_openpt( O_RDWR | O_NOCTTY );
        if( m_master_fd == -1 )
        {
            return;
        }

        // Switch the master to non-blocking mode, keeping any status
        // flags that are already set.
        int l_flags = fcntl( m_master_fd, F_GETFL );
        if( l_flags == -1 )
        {
            l_flags = 0;
        }
        fcntl( m_master_fd, F_SETFL, l_flags | O_NONBLOCK );

        char* name_temp = NULL;
        if( grantpt( m_master_fd ) == -1 ||
            unlockpt( m_master_fd ) == -1 ||
          ( name_temp = ptsname( m_master_fd ) ) == NULL )
        {
            // The master stays open; the destructor releases it.
            return;
        }
        m_pty_device.assign( name_temp );
        m_pty_created_successfully = true;
    }

    ~pty_master_wrapper()
    {
        close_slave();
        close_master();
    }

    // Close the master descriptor, if open.
    void
    close_master()
    {
        if( m_master_fd >= 0 )
        {
            close( m_master_fd );
            m_master_fd = -1;
        }
    }

    // Close the local slave descriptor, if open.
    void
    close_slave()
    {
        if( m_local_slave_fd >= 0 )
        {
            close( m_local_slave_fd );
            m_local_slave_fd = -1;
        }
    }

    // Master descriptor, or -1 if not open.
    int
    get_master_fd()
    {
        return m_master_fd;
    }

    // Slave device path; empty if the pty could not be set up.
    std::string
    get_pty_device()
    {
        return m_pty_device;
    }

    // Local slave descriptor, or -1 if not open.
    int
    get_slave_fd()
    {
        return m_local_slave_fd;
    }

    bool
    pty_created()
    {
        return m_pty_created_successfully;
    }

    // Open the slave device in this process without making it the
    // controlling tty. A slave opened by an earlier call is closed
    // first so it is not leaked. Returns false if the open fails.
    bool
    make_local_slave()
    {
        close_slave();
        m_local_slave_fd = open( m_pty_device.c_str(), O_RDWR | O_NOCTTY );
        return m_local_slave_fd != -1;
    }

private:
    // Not copyable: a copy would close the same descriptors again.
    // Declared but never defined (pre-C++11 idiom).
    pty_master_wrapper( const pty_master_wrapper& );
    pty_master_wrapper& operator=( const pty_master_wrapper& );

    bool m_pty_created_successfully;
    int m_master_fd;
    std::string m_pty_device;
    int m_local_slave_fd;
};

/////////////////////////////////////////////////
// Open the slave side of the pty named by
// pty_device in the current process, after
// calling setsid() to start a new session with
// no controlling terminal. The intent is for
// the slave to become the controlling tty of
// this process (the open() below deliberately
// omits O_NOCTTY).
// NOTE(review): on BSD-derived systems open()
// alone may not acquire a controlling terminal
// (TIOCSCTTY is the usual way) - confirm on the
// platforms listed at the top of this file.
//
// Meant for the forked child: setsid() detaches
// the calling process from its current terminal.
//
// A descriptor value of -1 means "not open". The
// descriptor is closed by the destructor, and
// the wrapper cannot be copied, so it is never
// closed twice.
/////////////////////////////////////////////////
class pty_slave_wrapper
{
public:
    explicit
    pty_slave_wrapper(std::string pty_device)
        : m_slave_fd( -1 ),
          m_slave_created_successfully( false ),
          m_device_name()
    {
        // Make it the controlling tty
        setsid();
        m_slave_fd = open( pty_device.c_str(), O_RDWR );
        if( m_slave_fd == -1 )
            return;

        m_slave_created_successfully = true;
        m_device_name = pty_device;
    }

    ~pty_slave_wrapper()
    {
        if( m_slave_fd >= 0 )
        {
            close( m_slave_fd );
        }
    }

    // Slave descriptor, or -1 if the open failed.
    int
    get_slave_fd()
    {
        return m_slave_fd;
    }

    bool
    slave_created()
    {
        return m_slave_created_successfully;
    }

    // Device path; empty if the open failed.
    std::string
    get_device_name()
    {
        return m_device_name;
    }

private:
    // Not copyable: a copy would close the descriptor again.
    // Declared but never defined (pre-C++11 idiom).
    pty_slave_wrapper( const pty_slave_wrapper& );
    pty_slave_wrapper& operator=( const pty_slave_wrapper& );

    int m_slave_fd;
    bool m_slave_created_successfully;
    std::string m_device_name;
};

/////////////////////////////////////////////////
// Create and maintain a descriptor for the
// process's terminal (/dev/tty), with a similar
// structure to that used for the pty above.
// Note that it requires the pty be created
// first, since it needs the master descriptor.
//
// If the terminal's window size can be read, it
// is copied onto the pty master and
// resize_handler is installed for SIGWINCH.
// That handler reads g_tty_fd / g_master_fd,
// which main() assigns after this constructor
// returns.
//
// If /dev/tty cannot be opened, tty_created()
// returns false and nothing else is done.
//
// A descriptor value of -1 means "not open". The
// descriptor is closed by the destructor, and
// the wrapper cannot be copied, so it is never
// closed twice.
/////////////////////////////////////////////////
class tty_wrapper
{
public:
    explicit
    tty_wrapper( int l_master_fd )
        : m_tty_fd( -1 ),
          m_tty_created_successfully( false )
    {
        m_tty_fd = open( "/dev/tty", O_RDONLY );
        if( m_tty_fd == -1 )
            return;

        // Tie the pty and tty together
        struct winsize ttysize;
        if( ioctl( m_tty_fd, TIOCGWINSZ, &ttysize ) == 0 )
        {
            signal( SIGWINCH, resize_handler );
            ioctl( l_master_fd, TIOCSWINSZ, &ttysize );
        }
        m_tty_created_successfully = true;
    }

    ~tty_wrapper()
    {
        if( m_tty_fd >= 0 )
        {
            close( m_tty_fd );
        }
    }

    // Terminal descriptor, or -1 if /dev/tty could not be opened.
    int
    get_tty_fd()
    {
        return m_tty_fd;
    }

    bool
    tty_created()
    {
        return m_tty_created_successfully;
    }

private:
    // Not copyable: a copy would close the descriptor again.
    // Declared but never defined (pre-C++11 idiom).
    tty_wrapper( const tty_wrapper& );
    tty_wrapper& operator=( const tty_wrapper& );

    int m_tty_fd;
    bool m_tty_created_successfully;
};

/////////////////////////////////////////////////
// Fill the three output strings from the
// environment:
//   GT_USER_NAME  -> u_name (required)
//   GT_PASSWORD   -> pswrd  (required)
//   GT_LOGIN_DEST -> dest   (defaults to
//                            "localhost")
// Returns false as soon as a required variable
// is missing; outputs filled in before that
// point keep their new values.
/////////////////////////////////////////////////
bool
get_arguments_from_env( std::string& u_name,
                        std::string& pswrd,
                        std::string& dest )
{
    const char* l_user = std::getenv( "GT_USER_NAME" );
    if( !l_user )
        return false;
    u_name = l_user;

    const char* l_pass = std::getenv( "GT_PASSWORD" );
    if( !l_pass )
        return false;
    pswrd = l_pass;

    const char* l_where = std::getenv( "GT_LOGIN_DEST" );
    dest = ( l_where != NULL ) ? l_where : "localhost";
    return true;
}

/////////////////////////////////////////////////
// Return the position of the first occurrence
// of target in line, or -1 if it does not
// occur. An empty target matches at 0.
/////////////////////////////////////////////////
ssize_t
str_match( const std::string& target,
           const std::string& line )
{
    const std::string::size_type pos = line.find( target );
    if( pos == std::string::npos )
    {
        return -1;
    }
    // Convert straight to the return type; narrowing through int
    // could truncate positions in long strings.
    return static_cast<ssize_t>( pos );
}

/////////////////////////////////////////////////
// Actually try the password: write pswrd
// followed by a newline to fd (the pty master),
// as if it had been typed at the prompt.
// A partial write is resumed where it stopped,
// and writes interrupted by a signal (EINTR) are
// retried. Any other failure, or a write that
// makes no progress, exits the process with
// status 9.
/////////////////////////////////////////////////
void
try_pswrd( int fd, const std::string& pswrd )
{
    const char* l_next = pswrd.data();
    size_t l_remaining = pswrd.length();
    while( l_remaining > 0 )
    {
        ssize_t res = write( fd, l_next, l_remaining );
        if( res < 0 && errno == EINTR )
        {
            continue;
        }
        if( res <= 0 )
        {
            exit( 9 );
        }
        // Continue after the bytes already written
        l_next += res;
        l_remaining -= static_cast<size_t>( res );
    }

    ssize_t res;
    do
    {
        res = write( fd, "\n", 1 );
    } while( res < 0 && errno == EINTR );
    if( res < 0 )
    {
        exit( 9 );
    }
}

/////////////////////////////////////////////////
// Read the next chunk of ssh output from the pty
// master (fd) and classify it.
//
// Returns
//    0  first password prompt seen; the password
//       has been written back to fd
//    1  nothing recognised in this output
//    2  "The authenticity of host " seen
//       (unknown host key)
//   -1  the password was refused: a second
//       prompt, or "try again"
// Exits with status 10 if the read fails (other
// than being interrupted by a signal).
//
// State lives in function statics, so this
// tracks a single ssh session per process.
// The tail of earlier output is kept between
// calls so a target split across two reads is
// still found. It is discarded after any match,
// so the same text is never matched twice.
/////////////////////////////////////////////////
int
check_output( int fd, std::string& pswrd )
{
    static bool match_wrd_password = false;
    static const std::string target_1("assword:"); // Ignore case for the 'P'
    static const std::string target_2("try again"); //Failed password
    static const std::string target_3("The authenticity of host "); // New host
    // Recent output that did not contain any target
    static std::string l_carry;

    char buf[GT_AUTH_BUF_LEN];
    ssize_t len;
    do
    {
        len = read( fd, buf, sizeof(buf) );
    } while( len < 0 && errno == EINTR );
    if( len < 0 )
    {
        exit( 10 );
    }

    // Build the search text using explicit lengths: pty data is not
    // NUL-terminated, and a full buffer must not be read past its end.
    std::string l_read_line( l_carry );
    l_read_line.append( buf, static_cast<size_t>( len ) );

    // Remember at most GT_AUTH_BUF_LEN trailing bytes for the next call;
    // that is longer than any target, so split targets are still found.
    const size_t l_keep = GT_AUTH_BUF_LEN;
    if( l_read_line.size() > l_keep )
    {
        l_carry.assign( l_read_line, l_read_line.size() - l_keep, l_keep );
    }
    else
    {
        l_carry = l_read_line;
    }

    if( str_match( target_1, l_read_line ) != -1 )
    {
        l_carry.clear();
        if( match_wrd_password )
        {
            // It is asking for the password again - must be wrong
            return -1;
        }
        try_pswrd( fd, pswrd );
        match_wrd_password = true;
        return 0;
    }

    // Check for failed password
    if( str_match( target_2, l_read_line ) != -1 )
    {
        l_carry.clear();
        return -1;
    }

    // Check for unknown host
    if( str_match( target_3, l_read_line ) != -1 )
    {
        l_carry.clear();
        return 2;
    }

    // Just ignore this output
    return 1;
}

/////////////////////////////////////////////////
// Entry point. Takes its inputs from the
// environment (see get_arguments_from_env),
// runs
//   ssh -o ... user@dest echo "LoggedIn"
// on the slave side of a new pty, answers the
// password prompt and reports the outcome as
// the exit status:
//    0  the password appears to have been
//       accepted
//    2  the password was rejected, or ssh had to
//       be killed
//    3  ssh did not recognise the host key
//   -1  no conclusion reached (the shell sees
//       255)
//    4, 5, 8, 9, 10, 11  a setup or I/O step
//       failed (see the individual exit() calls)
//
// NOTE(review): success is inferred from any
// unrecognised output after the password has
// been sent (check_output() returning 1), which
// may be no more than ssh's own newline after
// the prompt. A rejected password is then caught
// by the exit-status check at the end or, when
// ssh keeps waiting, by the kill after the 20 s
// wait.
/////////////////////////////////////////////////
int
main( int argc, char* argv[] )
{
    std::string u_name;
    std::string pswrd;
    std::string dest;
    if(!get_arguments_from_env(u_name, pswrd, dest))
    {
        exit( 11 );
    }

    pty_master_wrapper l_pty;
    if( !l_pty.pty_created() )
    {
        exit( 4 );
    }

    // pselect hangs on Linux if we don't make a local slave
    if( !l_pty.make_local_slave() )
    {
        exit( 5 );
    }

    std::string device_name = l_pty.get_pty_device();
    g_master_fd = l_pty.get_master_fd();

    // Install a SIGCHLD handler
    // Has to wait until after the pty has been made (grantpt)
    signal( SIGCHLD, sigchld_handler );

    tty_wrapper l_tty( l_pty.get_master_fd() );
    g_tty_fd = l_tty.get_tty_fd();

    // Fork a child process
    pid_t child_pid = fork();
    if( child_pid == 0 ) // In the child
    {
        // Drop the master and local slave descriptors inherited from
        // the parent. The parent closes its master when it gives up on
        // ssh (see below); that only hangs up the slave if no other
        // descriptor for the master - in particular one held by ssh
        // itself - is still open.
        l_pty.close_master();
        l_pty.close_slave();

        // Detach from the current tty and attach to a slave of the pty
        pty_slave_wrapper l_pty_slave( device_name );
        if( !l_pty_slave.slave_created() )
        {
            _exit( 6 );
        }

        // Construct command to run
        std::string l_arg_ssh( "ssh" );
        std::string l_arg_1 = u_name;
        l_arg_1.append( "@" );
        l_arg_1.append( dest );
        std::string l_arg_o( "-o" ); // For specifying options
        std::string l_arg_cra( "ChallengeResponseAuthentication=yes" );
        std::string l_arg_pswrd( "PasswordAuthentication=yes" );
        std::string l_arg_pref( "PreferredAuthentications=" );
        l_arg_pref.append( "keyboard-interactive,password,");
        std::string l_arg_2( "echo" );
        std::string l_arg_3( "\"LoggedIn\"" );
        const char* l_argv[] = { l_arg_ssh.c_str(),
                                 l_arg_o.c_str(),
                                 l_arg_cra.c_str(),
                                 l_arg_o.c_str(),
                                 l_arg_pswrd.c_str(),
                                 l_arg_o.c_str(),
                                 l_arg_pref.c_str(),
                                 l_arg_1.c_str(),
                                 l_arg_2.c_str(),
                                 l_arg_3.c_str(),
                                 NULL
                                };
        // execvp() is declared with char* const[] for historical
        // reasons; it does not modify the argument strings.
        execvp( l_arg_ssh.c_str(), const_cast<char* const*>( l_argv ) );

        // Can only be reached if execvp fails
        _exit( 7 );
    }
    else if( child_pid < 0 ) // fork failed
    {
        exit( 8 );
    }
    // Otherwise, we are in the parent process.
    // Keep SIGCHLD blocked except while waiting inside pselect(), so the
    // child exiting interrupts that wait.
    sigset_t sigmask_regular;
    sigset_t sigmask_select;
    sigemptyset( &sigmask_regular );
    sigaddset( &sigmask_regular, SIGCHLD );
    sigprocmask( SIG_SETMASK, &sigmask_regular, NULL );
    sigemptyset( &sigmask_select );

    bool completed = false;
    bool pswrd_passed = false;
    pid_t id_for_wait = -1;
    int res;
    int status = -1;
    int ret_val = -1;
    g_time.tv_sec = 5;
#if !(HAVE_PSELECT)
    g_time.tv_usec = 0;
#else
    g_time.tv_nsec = 0;
#endif

    // Finely crafted unexpected results from ssh can lead to
    // a non-terminating loop - protect against that
    // max_count is far larger than anything that should happen legitimately
    size_t loop_count = 0;
    size_t max_count = 100;
    do
    {
        ++loop_count;
        if( loop_count > max_count )
        {
            completed = true;
        }
        if( !completed )
        {
            fd_set l_fd;
            FD_ZERO( &l_fd );
            FD_SET( g_master_fd, &l_fd );

            res = pselect( g_master_fd+1, &l_fd,
                           NULL, NULL, &g_time,
                           &sigmask_select );

            if( res > 0 )
            {
                if( FD_ISSET( g_master_fd, &l_fd ) )
                {
                    // Return values
                    // 0 means password was passed
                    // 1 means line we ignore
                    // 2 means ssh can't authenticate host
                    // -1 means password failure of some sort
                    res = check_output( g_master_fd, pswrd );
                    if( res == 0 )
                    {
                        if( pswrd_passed )
                        {
                            ret_val = 2;
                            completed = true;
                        }
                        pswrd_passed = true;
                    }
                    else if( res == 1 )
                    {
                        if( pswrd_passed )
                        {
                            ret_val = 0;
                            completed = true;
                        }
                    }
                    else if( res == 2 )
                    {
                        // ssh is presumably waiting for a yes/no answer
                        // that will never come. Hang it up now instead of
                        // spinning through max_count pselect timeouts.
                        l_pty.close_master();
                        l_pty.close_slave();
                        completed = true;
                        ret_val = 3;
                    }
                    else if( res == -1 )
                    {
                        // password failed
                        l_pty.close_master();
                        l_pty.close_slave();
                        completed = true;
                        ret_val = 2;
                    }
                }
            }
            // Wait without blocking - thus the loop
            id_for_wait = waitpid( child_pid, &status, WNOHANG);
        }
        else
        {
            // Give ssh up to 10 x 2 s to finish on its own
            for(int i = 0; i < 10; ++i)
            {
                sleep(2);
                id_for_wait = waitpid( child_pid, &status, WNOHANG );
                if(id_for_wait != 0)
                    break;
            }
            if( id_for_wait == 0 )
            {
                // The child called setsid(), so its pid is its group id
                killpg( child_pid, SIGTERM );
                // Keep the more specific unknown-host result
                if( ret_val != 3 )
                {
                    ret_val = 2;
                }
            }
        }
        // If waitpid() fails for any reason other than EINTR (e.g. ECHILD),
        // status will never change; stop rather than loop forever.
        if( id_for_wait < 0 && errno != EINTR )
        {
            break;
        }
    } while( id_for_wait == 0 ||
             (!WIFEXITED( status ) &&
              !WIFSIGNALED( status ) ) );

    // ssh exits with status 255 when it fails itself (for example when
    // the server drops the connection after a rejected password), so
    // that outcome cannot be a successful login running echo.
    if( ret_val == 0 && id_for_wait == child_pid &&
        WIFEXITED( status ) && WEXITSTATUS( status ) == 255 )
    {
        ret_val = 2;
    }

    return ret_val;
}

