Commit 610a811c authored by Julien Muchembled's avatar Julien Muchembled

sshd: fix generation of authorized_keys

parent 3bdc7451
...@@ -24,8 +24,7 @@ ...@@ -24,8 +24,7 @@
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# #
############################################################################## ##############################################################################
import os import errno, os
import itertools
from slapos.recipe.librecipe import GenericBaseRecipe from slapos.recipe.librecipe import GenericBaseRecipe
class KnownHostsFile(dict): class KnownHostsFile(dict):
...@@ -55,40 +54,6 @@ class KnownHostsFile(dict): ...@@ -55,40 +54,6 @@ class KnownHostsFile(dict):
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self._dump() self._dump()
class AuthorizedKeysFile(object):
def __init__(self, filename):
self.filename = filename
def append(self, key):
"""Append the key to the file if the key's not in the file
"""
# Create the file it it does not exist
try:
file_ = os.open(self.filename, os.O_CREAT | os.O_EXCL)
os.close(file_)
except:
pass
with open(self.filename, 'r') as keyfile:
# itertools.imap avoid loading all the authorized_keys file in
# memory which would be counterproductive.
present = (key.strip() in itertools.imap(lambda k: k.strip(),
keyfile))
try:
keyfile.seek(-1, os.SEEK_END)
ended_by_newline = (keyfile.read() == '\n')
except IOError:
ended_by_newline = True
if not present:
with open(self.filename, 'a') as keyfile:
if not ended_by_newline:
keyfile.write('\n')
keyfile.write(key.strip())
class Recipe(GenericBaseRecipe): class Recipe(GenericBaseRecipe):
def install(self): def install(self):
...@@ -164,37 +129,29 @@ class Client(GenericBaseRecipe): ...@@ -164,37 +129,29 @@ class Client(GenericBaseRecipe):
return [wrapper] return [wrapper]
def keysplit(s):
"""
Split a string like "ssh-rsa AKLFKJSL..... ssh-rsa AAAASAF...."
and return the individual key_type + key strings.
TODO: handle comments in ssh keys, which are generated
by default at key creation.
"""
s = s.replace('\n', ' ')
si = iter(s.split(' '))
while True:
key_type = next(si)
if key_type == '':
continue
try:
key_value = next(si)
except StopIteration:
# odd number of elements, should not happen
break
yield '%s %s' % (key_type, key_value)
class AddAuthorizedKey(GenericBaseRecipe): class AddAuthorizedKey(GenericBaseRecipe):
def install(self): def install(self):
key = self.options['key']
ssh = self.createDirectory(self.options['home'], '.ssh') ssh = self.createDirectory(self.options['home'], '.ssh')
filename = os.path.join(ssh, 'authorized_keys')
authorized_keys = AuthorizedKeysFile(os.path.join(ssh, 'authorized_keys')) try:
for key in keysplit(self.options['key']): with open(filename) as f:
# XXX key might actually be the string 'None' or 'null' if f.read() == key:
authorized_keys.append(key) return [filename]
except IOError as e:
return [authorized_keys.filename] if e.errno != errno.ENOENT:
raise
# Atomic update.
tmp = filename + '.new'
try:
with open(tmp, 'w') as f:
f.write(key)
os.rename(tmp, filename)
finally:
try:
os.remove(tmp)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return [filename]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment