#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import sys
import os


def add_ssh_public_key(username="", ssh_key_data=""):
	"""
	Adds an SSH public key to the given user's SSH authorized keys.

	Two input formats are recognised:
	* openssh format (data starting with "ssh-"), used as-is
	* puttygen keyfile format (data starting with "---- BEGIN"),
	  which is converted to openssh format with convert_keydata() first

	username: the user whose authorized keys are updated ("" for the current user)
	ssh_key_data: the public key data (None is treated as empty)

	Returns a short status string (starting with a space), suitable for
	appending to the winswitch key result.
	"""
	ssh_key_data = ssh_key_data or ""
	if ssh_key_data.startswith("ssh-"):
		#openssh format: can be used directly
		#(username must be passed along: the key data is the second argument)
		return do_add_ssh_key(username, ssh_key_data)
	if ssh_key_data.startswith("---- BEGIN"):
		#puttygen keyfile: convert it first. A failed conversion yields ""
		#which do_add_ssh_key reports as invalid key data.
		return do_add_ssh_key(username, convert_keydata(ssh_key_data))
	#unknown key format...
	return " (unknown SSH public key format)"

def do_add_ssh_key(username="", ssh_key_data=""):
	"""
	Adds the given key to SSH's list of authorized keys.
	This method expects the key in openssh format (starting with "ssh-").

	username: the user whose ~/.ssh/authorized_keys is updated
	          ("" expands to the current user's home directory)
	ssh_key_data: the public key, in openssh format

	Returns a short status string (starting with a space).
	If missing, the ~/.ssh directory is created with mode 0700 and the
	authorized_keys file with mode 0600.
	"""
	import stat
	if not ssh_key_data or not ssh_key_data.startswith("ssh-"):
		#only show the start of the data, it may be very long:
		return " (invalid SSH key data: %s...)" % ((ssh_key_data or "")[:100])
	#trim ending newline
	from winswitch.util.common import no_newlines
	ssh_key_data = no_newlines(ssh_key_data)
	ssh_authorized_keys = "~%s/.ssh/authorized_keys" % username
	ssh_authorized_keyfile = os.path.expanduser(ssh_authorized_keys)
	keyfile_exists = os.path.exists(ssh_authorized_keyfile)
	if keyfile_exists:
		#check for an existing line with the same key:
		f = open(ssh_authorized_keyfile)
		try:
			for line in f:
				if line.startswith(ssh_key_data):
					return " (SSH key already present in %s)" % ssh_authorized_keys
		finally:
			f.close()
	ssh_dir = os.path.expanduser("~%s/.ssh/" % username)
	if not keyfile_exists and not os.path.exists(ssh_dir):
		os.mkdir(ssh_dir, stat.S_IRWXU)
		#mkdir's mode is filtered by the umask, so enforce 0700 explicitly
		#(os.chmod takes a path, os.fchmod would need a file descriptor):
		os.chmod(ssh_dir, stat.S_IRWXU)
	f = open(ssh_authorized_keyfile, "a")
	try:
		if not keyfile_exists:
			#new file: make it private (0600) before writing the key to it
			os.chmod(ssh_authorized_keyfile, stat.S_IRUSR | stat.S_IWUSR)
		f.write("%s\n" % ssh_key_data)
	finally:
		f.close()
	return " (SSH key added - password no longer needed)"

def convert_keydata(input_key_data):
	"""
	Converts a puttygen keyfile ("---- BEGIN" style) public key to openssh format.

	The conversion is done by "ssh-keygen -i", which reads the key from a
	file, so the data is written to a temporary file first.

	Returns the converted key data, or an empty string if the conversion failed.
	"""
	import tempfile
	#NamedTemporaryFile is opened in binary mode, so write bytes:
	if not isinstance(input_key_data, bytes):
		input_key_data = input_key_data.encode("utf-8")
	f = tempfile.NamedTemporaryFile(delete=True)
	try:
		f.write(input_key_data)
		f.flush()
		#then convert it:
		cmd = ["ssh-keygen", "-i", "-f", f.name]
		try:
			from winswitch.util.process_util import get_output
			code, ssh_key_data, _ = get_output(cmd)
		except Exception:
			#best effort: report the problem on stderr, the caller gets ""
			import traceback
			traceback.print_exc()
			return ""
		#a non-zero exit code means the conversion failed:
		if code != 0 or not ssh_key_data:
			return ""
		return ssh_key_data
	finally:
		#the temporary file is deleted when closed (delete=True)
		f.close()


def do_add_winswitch_client(key):
	"""
	Registers the given winswitch public key for the current user.

	The key string is validated first. The client is recorded as
	"unknown@ADDRESS" using the first field of the SSH_CLIENT environment
	variable when it is set, or "unknown@unknown" otherwise.

	Returns "OK: KEY ADDED" or "OK: KEY PRESENT" on success, or a string
	starting with "ERROR:" on failure (a traceback is printed in that case).
	"""
	try:
		from winswitch.util.auth import is_key_str_present, add_key_str, check_key_str
		if not check_key_str(key):
			return "ERROR: invalid key string"
		import getpass
		current_user = getpass.getuser()
		origin = "unknown@unknown"
		ssh_client = os.environ.get("SSH_CLIENT", None)
		if ssh_client:
			#presumably "ADDRESS PORT PORT" (OpenSSH convention): keep the first field
			origin = "unknown@%s" % ssh_client.partition(" ")[0]
		if is_key_str_present(current_user, key):
			return "OK: KEY PRESENT"
		add_key_str(current_user, key, origin)
		return "OK: KEY ADDED"
	except Exception as e:
		import traceback
		traceback.print_exc()
		return "ERROR: %s" % e



if __name__ == "__main__":
	#This command may be called in the form:
	#	add_key WINSWITCHKEYDATA [SSHKEYDATA]
	#The winswitch key will be added to .winswitch/client/authorized_keys (creating it if needed).
	#The SSH key (optional) will be added to .ssh/authorized_keys (creating it if needed).
	#The results of both operations are printed out, concatenated.
	argc = len(sys.argv)
	if argc not in (2, 3):
		print("ERROR: WRONG NUMBER OF ARGUMENTS")
		sys.exit(1)

	from winswitch.util.simple_logger import set_log_to_file, set_log_to_tty
	#presumably disables file and tty logging so only the result below is printed
	set_log_to_file(False)
	set_log_to_tty(False)

	winswitch_key = sys.argv[1]
	ssh_result = ""
	if argc == 3 and sys.argv[2]:
		#SSH key data supplied (with escaped newlines): try to add it
		from winswitch.util.common import unescape_newlines
		ssh_result = add_ssh_public_key("", unescape_newlines(sys.argv[2]))

	print(do_add_winswitch_client(winswitch_key) + ssh_result)
