fix: code style / flake8 automation
ZeroDownTime/CloudBender/pipeline/head This commit looks good Details

This commit is contained in:
Stefan Reimer 2022-02-22 11:04:29 +01:00
parent 129d287ae5
commit 7d6135e099
12 changed files with 793 additions and 439 deletions

3
.flake8 Normal file
View File

@ -0,0 +1,3 @@
[flake8]
extend-ignore = E501
exclude = .git,__pycache__,build,dist,report

View File

@ -20,7 +20,7 @@ dev_setup:
pip install -r dev-requirements.txt --user pip install -r dev-requirements.txt --user
pytest: pytest:
flake8 --ignore=E501 cloudbender tests flake8 cloudbender tests
TEST=True pytest --log-cli-level=DEBUG TEST=True pytest --log-cli-level=DEBUG
clean: clean:

View File

@ -17,4 +17,4 @@ class NullHandler(logging.Handler): # pragma: no cover
pass pass
logging.getLogger('cloudbender').addHandler(NullHandler()) logging.getLogger("cloudbender").addHandler(NullHandler())

View File

@ -12,6 +12,7 @@ from .utils import setup_logging
from .exceptions import InvalidProjectDir from .exceptions import InvalidProjectDir
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,8 +28,8 @@ def cli(ctx, debug, directory):
if directory: if directory:
if not os.path.isabs(directory): if not os.path.isabs(directory):
directory = os.path.normpath(os.path.join(os.getcwd(), directory)) directory = os.path.normpath(os.path.join(os.getcwd(), directory))
elif os.getenv('CLOUDBENDER_PROJECT_ROOT'): elif os.getenv("CLOUDBENDER_PROJECT_ROOT"):
directory = os.getenv('CLOUDBENDER_PROJECT_ROOT') directory = os.getenv("CLOUDBENDER_PROJECT_ROOT")
else: else:
directory = os.getcwd() directory = os.getcwd()
@ -50,7 +51,7 @@ def cli(ctx, debug, directory):
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.pass_obj @click.pass_obj
def render(cb, stack_names, multi): def render(cb, stack_names, multi):
""" Renders template and its parameters - CFN only""" """Renders template and its parameters - CFN only"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
_render(stacks) _render(stacks)
@ -61,7 +62,7 @@ def render(cb, stack_names, multi):
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.pass_obj @click.pass_obj
def sync(cb, stack_names, multi): def sync(cb, stack_names, multi):
""" Renders template and provisions it right away """ """Renders template and provisions it right away"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
@ -74,7 +75,7 @@ def sync(cb, stack_names, multi):
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.pass_obj @click.pass_obj
def validate(cb, stack_names, multi): def validate(cb, stack_names, multi):
""" Validates already rendered templates using cfn-lint - CFN only""" """Validates already rendered templates using cfn-lint - CFN only"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
for s in stacks: for s in stacks:
@ -86,11 +87,17 @@ def validate(cb, stack_names, multi):
@click.command() @click.command()
@click.argument("stack_names", nargs=-1) @click.argument("stack_names", nargs=-1)
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.option("--include", default='.*', help="regex matching wanted outputs, default '.*'") @click.option(
@click.option("--values", is_flag=True, help="Only output values, most useful if only one outputs is returned") "--include", default=".*", help="regex matching wanted outputs, default '.*'"
)
@click.option(
"--values",
is_flag=True,
help="Only output values, most useful if only one outputs is returned",
)
@click.pass_obj @click.pass_obj
def outputs(cb, stack_names, multi, include, values): def outputs(cb, stack_names, multi, include, values):
""" Prints all stack outputs """ """Prints all stack outputs"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
for s in stacks: for s in stacks:
@ -110,7 +117,7 @@ def outputs(cb, stack_names, multi, include, values):
@click.option("--graph", is_flag=True, help="Create Dot Graph file") @click.option("--graph", is_flag=True, help="Create Dot Graph file")
@click.pass_obj @click.pass_obj
def create_docs(cb, stack_names, multi, graph): def create_docs(cb, stack_names, multi, graph):
""" Parses all documentation fragments out of rendered templates creating docs/*.md file """ """Parses all documentation fragments out of rendered templates creating docs/*.md file"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
for s in stacks: for s in stacks:
@ -122,7 +129,7 @@ def create_docs(cb, stack_names, multi, graph):
@click.argument("change_set_name") @click.argument("change_set_name")
@click.pass_obj @click.pass_obj
def create_change_set(cb, stack_name, change_set_name): def create_change_set(cb, stack_name, change_set_name):
""" Creates a change set for an existing stack - CFN only""" """Creates a change set for an existing stack - CFN only"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
@ -133,29 +140,33 @@ def create_change_set(cb, stack_name, change_set_name):
@click.argument("stack_name") @click.argument("stack_name")
@click.pass_obj @click.pass_obj
def refresh(cb, stack_name): def refresh(cb, stack_name):
""" Refreshes Pulumi stack / Drift detection """ """Refreshes Pulumi stack / Drift detection"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
if s.mode == 'pulumi': if s.mode == "pulumi":
s.refresh() s.refresh()
else: else:
logger.info('{} uses Cloudformation, refresh skipped.'.format(s.stackname)) logger.info("{} uses Cloudformation, refresh skipped.".format(s.stackname))
@click.command() @click.command()
@click.argument("stack_name") @click.argument("stack_name")
@click.option("--reset", is_flag=True, help="All pending stack operations are removed and the stack will be re-imported") @click.option(
"--reset",
is_flag=True,
help="All pending stack operations are removed and the stack will be re-imported",
)
@click.pass_obj @click.pass_obj
def export(cb, stack_name, reset=False): def export(cb, stack_name, reset=False):
""" Exports a Pulumi stack to repair state """ """Exports a Pulumi stack to repair state"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
if s.mode == 'pulumi': if s.mode == "pulumi":
s.export(reset) s.export(reset)
else: else:
logger.info('{} uses Cloudformation, export skipped.'.format(s.stackname)) logger.info("{} uses Cloudformation, export skipped.".format(s.stackname))
@click.command() @click.command()
@ -165,7 +176,7 @@ def export(cb, stack_name, reset=False):
@click.option("--secret", is_flag=True, help="Value is a secret") @click.option("--secret", is_flag=True, help="Value is a secret")
@click.pass_obj @click.pass_obj
def set_config(cb, stack_name, key, value, secret=False): def set_config(cb, stack_name, key, value, secret=False):
""" Sets a config value, encrypts with stack key if secret """ """Sets a config value, encrypts with stack key if secret"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
@ -177,7 +188,7 @@ def set_config(cb, stack_name, key, value, secret=False):
@click.argument("key") @click.argument("key")
@click.pass_obj @click.pass_obj
def get_config(cb, stack_name, key): def get_config(cb, stack_name, key):
""" Get a config value, decrypted if secret """ """Get a config value, decrypted if secret"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
@ -188,14 +199,18 @@ def get_config(cb, stack_name, key):
@click.argument("stack_name") @click.argument("stack_name")
@click.pass_obj @click.pass_obj
def preview(cb, stack_name): def preview(cb, stack_name):
""" Preview of Pulumi stack up operation """ """Preview of Pulumi stack up operation"""
stacks = _find_stacks(cb, [stack_name]) stacks = _find_stacks(cb, [stack_name])
for s in stacks: for s in stacks:
if s.mode == 'pulumi': if s.mode == "pulumi":
s.preview() s.preview()
else: else:
logger.warning('{} uses Cloudformation, use create-change-set for previews.'.format(s.stackname)) logger.warning(
"{} uses Cloudformation, use create-change-set for previews.".format(
s.stackname
)
)
@click.command() @click.command()
@ -203,7 +218,7 @@ def preview(cb, stack_name):
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.pass_obj @click.pass_obj
def provision(cb, stack_names, multi): def provision(cb, stack_names, multi):
""" Creates or updates stacks or stack groups """ """Creates or updates stacks or stack groups"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
_provision(cb, stacks) _provision(cb, stacks)
@ -214,7 +229,7 @@ def provision(cb, stack_names, multi):
@click.option("--multi", is_flag=True, help="Allow more than one stack to match") @click.option("--multi", is_flag=True, help="Allow more than one stack to match")
@click.pass_obj @click.pass_obj
def delete(cb, stack_names, multi): def delete(cb, stack_names, multi):
""" Deletes stacks or stack groups """ """Deletes stacks or stack groups"""
stacks = _find_stacks(cb, stack_names, multi) stacks = _find_stacks(cb, stack_names, multi)
# Reverse steps # Reverse steps
@ -235,16 +250,16 @@ def delete(cb, stack_names, multi):
@click.command() @click.command()
@click.pass_obj @click.pass_obj
def clean(cb): def clean(cb):
""" Deletes all previously rendered files locally """ """Deletes all previously rendered files locally"""
cb.clean() cb.clean()
def sort_stacks(cb, stacks): def sort_stacks(cb, stacks):
""" Sort stacks by dependencies """ """Sort stacks by dependencies"""
data = {} data = {}
for s in stacks: for s in stacks:
if s.mode == 'pulumi': if s.mode == "pulumi":
data[s.id] = set() data[s.id] = set()
continue continue
@ -253,10 +268,14 @@ def sort_stacks(cb, stacks):
deps = [] deps = []
for d in s.dependencies: for d in s.dependencies:
# For now we assume deps are artifacts so we prepend them with our local profile and region to match stack.id # For now we assume deps are artifacts so we prepend them with our local profile and region to match stack.id
for dep_stack in cb.filter_stacks({'region': s.region, 'profile': s.profile, 'provides': d}): for dep_stack in cb.filter_stacks(
{"region": s.region, "profile": s.profile, "provides": d}
):
deps.append(dep_stack.id) deps.append(dep_stack.id)
# also look for global services # also look for global services
for dep_stack in cb.filter_stacks({'region': 'global', 'profile': s.profile, 'provides': d}): for dep_stack in cb.filter_stacks(
{"region": "global", "profile": s.profile, "provides": d}
):
deps.append(dep_stack.id) deps.append(dep_stack.id)
data[s.id] = set(deps) data[s.id] = set(deps)
@ -267,7 +286,9 @@ def sort_stacks(cb, stacks):
v.discard(k) v.discard(k)
if data: if data:
extra_items_in_deps = functools.reduce(set.union, data.values()) - set(data.keys()) extra_items_in_deps = functools.reduce(set.union, data.values()) - set(
data.keys()
)
data.update({item: set() for item in extra_items_in_deps}) data.update({item: set() for item in extra_items_in_deps})
while True: while True:
@ -283,41 +304,46 @@ def sort_stacks(cb, stacks):
result.append(s) result.append(s)
yield result yield result
data = {item: (dep - ordered) for item, dep in data.items() data = {
if item not in ordered} item: (dep - ordered) for item, dep in data.items() if item not in ordered
}
assert not data, "A cyclic dependency exists amongst %r" % data assert not data, "A cyclic dependency exists amongst %r" % data
def _find_stacks(cb, stack_names, multi=False): def _find_stacks(cb, stack_names, multi=False):
""" search stacks by name """ """search stacks by name"""
stacks = [] stacks = []
for s in stack_names: for s in stack_names:
stacks = stacks + cb.resolve_stacks(s) stacks = stacks + cb.resolve_stacks(s)
if not multi and len(stacks) > 1: if not multi and len(stacks) > 1:
logger.error('Found more than one stack matching name ({}). Please set --multi if that is what you want.'.format(', '.join(stack_names))) logger.error(
"Found more than one stack matching name ({}). Please set --multi if that is what you want.".format(
", ".join(stack_names)
)
)
raise click.Abort() raise click.Abort()
if not stacks: if not stacks:
logger.error('Cannot find stack matching: {}'.format(', '.join(stack_names))) logger.error("Cannot find stack matching: {}".format(", ".join(stack_names)))
raise click.Abort() raise click.Abort()
return stacks return stacks
def _render(stacks): def _render(stacks):
""" Utility function to reuse code between tasks """ """Utility function to reuse code between tasks"""
for s in stacks: for s in stacks:
if s.mode != 'pulumi': if s.mode != "pulumi":
s.render() s.render()
s.write_template_file() s.write_template_file()
else: else:
logger.info('{} uses Pulumi, render skipped.'.format(s.stackname)) logger.info("{} uses Pulumi, render skipped.".format(s.stackname))
def _provision(cb, stacks): def _provision(cb, stacks):
""" Utility function to reuse code between tasks """ """Utility function to reuse code between tasks"""
for step in sort_stacks(cb, stacks): for step in sort_stacks(cb, stacks):
if step: if step:
with ThreadPoolExecutor(max_workers=len(step)) as group: with ThreadPoolExecutor(max_workers=len(step)) as group:
@ -348,5 +374,5 @@ cli.add_command(set_config)
cli.add_command(get_config) cli.add_command(get_config)
cli.add_command(export) cli.add_command(export)
if __name__ == '__main__': if __name__ == "__main__":
cli(obj={}) cli(obj={})

View File

@ -10,7 +10,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BotoConnection(): class BotoConnection:
_sessions = {} _sessions = {}
_clients = {} _clients = {}
@ -27,13 +27,15 @@ class BotoConnection():
# Change the cache path from the default of ~/.aws/boto/cache to the one used by awscli # Change the cache path from the default of ~/.aws/boto/cache to the one used by awscli
session_vars = {} session_vars = {}
if profile: if profile:
session_vars['profile'] = (None, None, profile, None) session_vars["profile"] = (None, None, profile, None)
if region and region != 'global': if region and region != "global":
session_vars['region'] = (None, None, region, None) session_vars["region"] = (None, None, region, None)
session = botocore.session.Session(session_vars=session_vars) session = botocore.session.Session(session_vars=session_vars)
cli_cache = os.path.join(os.path.expanduser('~'), '.aws/cli/cache') cli_cache = os.path.join(os.path.expanduser("~"), ".aws/cli/cache")
session.get_component('credential_provider').get_provider('assume-role').cache = credentials.JSONFileCache(cli_cache) session.get_component("credential_provider").get_provider(
"assume-role"
).cache = credentials.JSONFileCache(cli_cache)
self._sessions[(profile, region)] = session self._sessions[(profile, region)] = session
@ -41,7 +43,9 @@ class BotoConnection():
def _get_client(self, service, profile=None, region=None): def _get_client(self, service, profile=None, region=None):
if self._clients.get((profile, region, service)): if self._clients.get((profile, region, service)):
logger.debug("Reusing boto session for {} {} {}".format(profile, region, service)) logger.debug(
"Reusing boto session for {} {} {}".format(profile, region, service)
)
return self._clients[(profile, region, service)] return self._clients[(profile, region, service)]
session = self._get_session(profile, region) session = self._get_session(profile, region)
@ -59,8 +63,12 @@ class BotoConnection():
return getattr(client, command)(**kwargs) return getattr(client, command)(**kwargs)
except botocore.exceptions.ClientError as e: except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == 'Throttling': if e.response["Error"]["Code"] == "Throttling":
logger.warning("Throttling exception occured during {} - retry after 3s".format(command)) logger.warning(
"Throttling exception occured during {} - retry after 3s".format(
command
)
)
time.sleep(3) time.sleep(3)
pass pass
else: else:

View File

@ -9,7 +9,8 @@ logger = logging.getLogger(__name__)
class CloudBender(object): class CloudBender(object):
""" Config Class to handle recursive conf/* config tree """ """Config Class to handle recursive conf/* config tree"""
def __init__(self, root_path): def __init__(self, root_path):
self.root = pathlib.Path(root_path) self.root = pathlib.Path(root_path)
self.sg = None self.sg = None
@ -20,28 +21,39 @@ class CloudBender(object):
"hooks_path": self.root.joinpath("hooks"), "hooks_path": self.root.joinpath("hooks"),
"docs_path": self.root.joinpath("docs"), "docs_path": self.root.joinpath("docs"),
"outputs_path": self.root.joinpath("outputs"), "outputs_path": self.root.joinpath("outputs"),
"artifact_paths": [self.root.joinpath("artifacts")] "artifact_paths": [self.root.joinpath("artifacts")],
} }
if not self.ctx['config_path'].is_dir(): if not self.ctx["config_path"].is_dir():
raise InvalidProjectDir("Check '{0}' exists and is a valid CloudBender project folder.".format(self.ctx['config_path'])) raise InvalidProjectDir(
"Check '{0}' exists and is a valid CloudBender project folder.".format(
self.ctx["config_path"]
)
)
def read_config(self): def read_config(self):
"""Load the <path>/config.yaml, <path>/*.yaml as stacks, sub-folders are sub-groups """ """Load the <path>/config.yaml, <path>/*.yaml as stacks, sub-folders are sub-groups"""
# Read top level config.yaml and extract CloudBender CTX # Read top level config.yaml and extract CloudBender CTX
_config = read_config_file(self.ctx['config_path'].joinpath('config.yaml')) _config = read_config_file(self.ctx["config_path"].joinpath("config.yaml"))
# Legacy naming # Legacy naming
if _config and _config.get('CloudBender'): if _config and _config.get("CloudBender"):
self.ctx.update(_config.get('CloudBender')) self.ctx.update(_config.get("CloudBender"))
if _config and _config.get('cloudbender'): if _config and _config.get("cloudbender"):
self.ctx.update(_config.get('cloudbender')) self.ctx.update(_config.get("cloudbender"))
# Make sure all paths are abs # Make sure all paths are abs
for k, v in self.ctx.items(): for k, v in self.ctx.items():
if k in ['config_path', 'template_path', 'hooks_path', 'docs_path', 'artifact_paths', 'outputs_path']: if k in [
"config_path",
"template_path",
"hooks_path",
"docs_path",
"artifact_paths",
"outputs_path",
]:
if isinstance(v, list): if isinstance(v, list):
new_list = [] new_list = []
for p in v: for p in v:
@ -56,7 +68,7 @@ class CloudBender(object):
if not v.is_absolute(): if not v.is_absolute():
self.ctx[k] = self.root.joinpath(v) self.ctx[k] = self.root.joinpath(v)
self.sg = StackGroup(self.ctx['config_path'], self.ctx) self.sg = StackGroup(self.ctx["config_path"], self.ctx)
self.sg.read_config() self.sg.read_config()
self.all_stacks = self.sg.get_stacks() self.all_stacks = self.sg.get_stacks()
@ -77,15 +89,15 @@ class CloudBender(object):
token = token[7:] token = token[7:]
# If path ends with yaml we look for stacks # If path ends with yaml we look for stacks
if token.endswith('.yaml'): if token.endswith(".yaml"):
stacks = self.sg.get_stacks(token, match_by='path') stacks = self.sg.get_stacks(token, match_by="path")
# otherwise assume we look for a group, if we find a group return all stacks below # otherwise assume we look for a group, if we find a group return all stacks below
else: else:
# Strip potential trailing slash # Strip potential trailing slash
token = token.rstrip('/') token = token.rstrip("/")
sg = self.sg.get_stackgroup(token, match_by='path') sg = self.sg.get_stackgroup(token, match_by="path")
if sg: if sg:
stacks = sg.get_stacks() stacks = sg.get_stacks()

View File

@ -8,6 +8,7 @@ from functools import wraps
from .exceptions import InvalidHook from .exceptions import InvalidHook
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,7 +41,9 @@ def pulumi_ws(func):
@wraps(func) @wraps(func)
def decorated(self, *args, **kwargs): def decorated(self, *args, **kwargs):
# setup temp workspace # setup temp workspace
self.work_dir = tempfile.mkdtemp(dir=tempfile.gettempdir(), prefix="cloudbender-") self.work_dir = tempfile.mkdtemp(
dir=tempfile.gettempdir(), prefix="cloudbender-"
)
response = func(self, *args, **kwargs) response = func(self, *args, **kwargs)
@ -63,4 +66,4 @@ def cmd(stack, arguments):
hook = subprocess.run(arguments, stdout=subprocess.PIPE) hook = subprocess.run(arguments, stdout=subprocess.PIPE)
logger.info(hook.stdout.decode("utf-8")) logger.info(hook.stdout.decode("utf-8"))
except TypeError: except TypeError:
raise InvalidHook('Invalid argument {}'.format(arguments)) raise InvalidHook("Invalid argument {}".format(arguments))

View File

@ -24,10 +24,10 @@ logger = logging.getLogger(__name__)
@jinja2.contextfunction @jinja2.contextfunction
def option(context, attribute, default_value=u'', source='options'): def option(context, attribute, default_value="", source="options"):
""" Get attribute from options data structure, default_value otherwise """ """Get attribute from options data structure, default_value otherwise"""
environment = context.environment environment = context.environment
options = environment.globals['_config'][source] options = environment.globals["_config"][source]
if not attribute: if not attribute:
return default_value return default_value
@ -48,7 +48,7 @@ def option(context, attribute, default_value=u'', source='options'):
@jinja2.contextfunction @jinja2.contextfunction
def include_raw_gz(context, files=None, gz=True, remove_comments=False): def include_raw_gz(context, files=None, gz=True, remove_comments=False):
jenv = context.environment jenv = context.environment
output = '' output = ""
# For shell script we can even remove whitespaces so treat them individually # For shell script we can even remove whitespaces so treat them individually
# sed -e '2,$ {/^ *$/d ; /^ *#/d ; /^[ \t] *#/d ; /*^/d ; s/^[ \t]*// ; s/*[ \t]$// ; s/ $//}' # sed -e '2,$ {/^ *$/d ; /^ *#/d ; /^[ \t] *#/d ; /*^/d ; s/^[ \t]*// ; s/*[ \t]$// ; s/ $//}'
@ -57,33 +57,35 @@ def include_raw_gz(context, files=None, gz=True, remove_comments=False):
if remove_comments: if remove_comments:
# Remove full line comments but not shebang # Remove full line comments but not shebang
_re_comment = re.compile(r'^\s*#[^!]') _re_comment = re.compile(r"^\s*#[^!]")
_re_blank = re.compile(r'^\s*$') _re_blank = re.compile(r"^\s*$")
_re_keep = re.compile(r'^## template: jinja$') _re_keep = re.compile(r"^## template: jinja$")
stripped_output = '' stripped_output = ""
for curline in output.splitlines(): for curline in output.splitlines():
if re.match(_re_blank, curline): if re.match(_re_blank, curline):
continue continue
elif re.match(_re_keep, curline): elif re.match(_re_keep, curline):
stripped_output = stripped_output + curline + '\n' stripped_output = stripped_output + curline + "\n"
elif re.match(_re_comment, curline): elif re.match(_re_comment, curline):
logger.debug("Removed {}".format(curline)) logger.debug("Removed {}".format(curline))
else: else:
stripped_output = stripped_output + curline + '\n' stripped_output = stripped_output + curline + "\n"
output = stripped_output output = stripped_output
if not gz: if not gz:
return(output) return output
buf = io.BytesIO() buf = io.BytesIO()
f = gzip.GzipFile(mode='w', fileobj=buf, mtime=0) f = gzip.GzipFile(mode="w", fileobj=buf, mtime=0)
f.write(output.encode()) f.write(output.encode())
f.close() f.close()
# MaxSize is 21847 # MaxSize is 21847
logger.info("Compressed user-data from {} to {}".format(len(output), len(buf.getvalue()))) logger.info(
return base64.b64encode(buf.getvalue()).decode('utf-8') "Compressed user-data from {} to {}".format(len(output), len(buf.getvalue()))
)
return base64.b64encode(buf.getvalue()).decode("utf-8")
@jinja2.contextfunction @jinja2.contextfunction
@ -92,33 +94,33 @@ def raise_helper(context, msg):
# Custom tests # Custom tests
def regex(value='', pattern='', ignorecase=False, match_type='search'): def regex(value="", pattern="", ignorecase=False, match_type="search"):
''' Expose `re` as a boolean filter using the `search` method by default. """Expose `re` as a boolean filter using the `search` method by default.
This is likely only useful for `search` and `match` which already This is likely only useful for `search` and `match` which already
have their own filters. have their own filters.
''' """
if ignorecase: if ignorecase:
flags = re.I flags = re.I
else: else:
flags = 0 flags = 0
_re = re.compile(pattern, flags=flags) _re = re.compile(pattern, flags=flags)
if getattr(_re, match_type, 'search')(value) is not None: if getattr(_re, match_type, "search")(value) is not None:
return True return True
return False return False
def match(value, pattern='', ignorecase=False): def match(value, pattern="", ignorecase=False):
''' Perform a `re.match` returning a boolean ''' """Perform a `re.match` returning a boolean"""
return regex(value, pattern, ignorecase, 'match') return regex(value, pattern, ignorecase, "match")
def search(value, pattern='', ignorecase=False): def search(value, pattern="", ignorecase=False):
''' Perform a `re.search` returning a boolean ''' """Perform a `re.search` returning a boolean"""
return regex(value, pattern, ignorecase, 'search') return regex(value, pattern, ignorecase, "search")
# Custom filters # Custom filters
def sub(value='', pattern='', replace='', ignorecase=False): def sub(value="", pattern="", replace="", ignorecase=False):
if ignorecase: if ignorecase:
flags = re.I flags = re.I
else: else:
@ -129,9 +131,16 @@ def sub(value='', pattern='', replace='', ignorecase=False):
def pyminify(source, obfuscate=False, minify=True): def pyminify(source, obfuscate=False, minify=True):
# pyminifier options # pyminifier options
options = types.SimpleNamespace( options = types.SimpleNamespace(
tabs=False, replacement_length=1, use_nonlatin=0, tabs=False,
obfuscate=0, obf_variables=1, obf_classes=0, obf_functions=0, replacement_length=1,
obf_import_methods=0, obf_builtins=0) use_nonlatin=0,
obfuscate=0,
obf_variables=1,
obf_classes=0,
obf_functions=0,
obf_import_methods=0,
obf_builtins=0,
)
tokens = pyminifier.token_utils.listified_tokenizer(source) tokens = pyminifier.token_utils.listified_tokenizer(source)
@ -141,13 +150,17 @@ def pyminify(source, obfuscate=False, minify=True):
if obfuscate: if obfuscate:
name_generator = pyminifier.obfuscate.obfuscation_machine(use_unicode=False) name_generator = pyminifier.obfuscate.obfuscation_machine(use_unicode=False)
pyminifier.obfuscate.obfuscate("__main__", tokens, options, name_generator=name_generator) pyminifier.obfuscate.obfuscate(
"__main__", tokens, options, name_generator=name_generator
)
# source = pyminifier.obfuscate.apply_obfuscation(source) # source = pyminifier.obfuscate.apply_obfuscation(source)
source = pyminifier.token_utils.untokenize(tokens) source = pyminifier.token_utils.untokenize(tokens)
# logger.info(source) # logger.info(source)
minified_source = pyminifier.compression.gz_pack(source) minified_source = pyminifier.compression.gz_pack(source)
logger.info("Compressed python code from {} to {}".format(len(source), len(minified_source))) logger.info(
"Compressed python code from {} to {}".format(len(source), len(minified_source))
)
return minified_source return minified_source
@ -157,10 +170,12 @@ def inline_yaml(block):
def JinjaEnv(template_locations=[]): def JinjaEnv(template_locations=[]):
LoggingUndefined = jinja2.make_logging_undefined(logger=logger, base=Undefined) LoggingUndefined = jinja2.make_logging_undefined(logger=logger, base=Undefined)
jenv = jinja2.Environment(trim_blocks=True, jenv = jinja2.Environment(
lstrip_blocks=True, trim_blocks=True,
undefined=LoggingUndefined, lstrip_blocks=True,
extensions=['jinja2.ext.loopcontrols', 'jinja2.ext.do']) undefined=LoggingUndefined,
extensions=["jinja2.ext.loopcontrols", "jinja2.ext.do"],
)
if template_locations: if template_locations:
jinja_loaders = [] jinja_loaders = []
@ -171,29 +186,29 @@ def JinjaEnv(template_locations=[]):
else: else:
jenv.loader = jinja2.BaseLoader() jenv.loader = jinja2.BaseLoader()
jenv.globals['include_raw'] = include_raw_gz jenv.globals["include_raw"] = include_raw_gz
jenv.globals['raise'] = raise_helper jenv.globals["raise"] = raise_helper
jenv.globals['option'] = option jenv.globals["option"] = option
jenv.filters['sub'] = sub jenv.filters["sub"] = sub
jenv.filters['pyminify'] = pyminify jenv.filters["pyminify"] = pyminify
jenv.filters['inline_yaml'] = inline_yaml jenv.filters["inline_yaml"] = inline_yaml
jenv.tests['match'] = match jenv.tests["match"] = match
jenv.tests['regex'] = regex jenv.tests["regex"] = regex
jenv.tests['search'] = search jenv.tests["search"] = search
return jenv return jenv
def read_config_file(path, variables={}): def read_config_file(path, variables={}):
""" reads yaml config file, passes it through jinja and returns data structre """reads yaml config file, passes it through jinja and returns data structre
- OS ENV are available as {{ ENV.<VAR> }} - OS ENV are available as {{ ENV.<VAR> }}
- variables defined in parent configs are available as {{ <VAR> }} - variables defined in parent configs are available as {{ <VAR> }}
""" """
jinja_variables = copy.deepcopy(variables) jinja_variables = copy.deepcopy(variables)
jinja_variables['ENV'] = os.environ jinja_variables["ENV"] = os.environ
if path.exists(): if path.exists():
logger.debug("Reading config file: {}".format(path)) logger.debug("Reading config file: {}".format(path))
@ -205,7 +220,8 @@ def read_config_file(path, variables={}):
auto_reload=False, auto_reload=False,
loader=jinja2.FunctionLoader(_sops_loader), loader=jinja2.FunctionLoader(_sops_loader),
undefined=jinja2.StrictUndefined, undefined=jinja2.StrictUndefined,
extensions=['jinja2.ext.loopcontrols']) extensions=["jinja2.ext.loopcontrols"],
)
template = jenv.get_template(str(path)) template = jenv.get_template(str(path))
rendered_template = template.render(jinja_variables) rendered_template = template.render(jinja_variables)
data = yaml.safe_load(rendered_template) data = yaml.safe_load(rendered_template)
@ -220,26 +236,37 @@ def read_config_file(path, variables={}):
def _sops_loader(path): def _sops_loader(path):
""" Tries to loads yaml file """Tries to loads yaml file
If "sops" key is detected the file is piped through sops before returned If "sops" key is detected the file is piped through sops before returned
""" """
with open(path, 'r') as f: with open(path, "r") as f:
config_raw = f.read() config_raw = f.read()
data = yaml.safe_load(config_raw) data = yaml.safe_load(config_raw)
if data and 'sops' in data and 'DISABLE_SOPS' not in os.environ: if data and "sops" in data and "DISABLE_SOPS" not in os.environ:
try: try:
result = subprocess.run([ result = subprocess.run(
'sops', [
'--input-type', 'yaml', "sops",
'--output-type', 'yaml', "--input-type",
'--decrypt', '/dev/stdin' "yaml",
], stdout=subprocess.PIPE, input=config_raw.encode('utf-8'), "--output-type",
env=dict(os.environ, **{"AWS_SDK_LOAD_CONFIG": "1"})) "yaml",
"--decrypt",
"/dev/stdin",
],
stdout=subprocess.PIPE,
input=config_raw.encode("utf-8"),
env=dict(os.environ, **{"AWS_SDK_LOAD_CONFIG": "1"}),
)
except FileNotFoundError: except FileNotFoundError:
logger.exception("SOPS encrypted config {}, but unable to find sops binary! Try eg: https://github.com/mozilla/sops/releases/download/v3.5.0/sops-v3.5.0.linux".format(path)) logger.exception(
"SOPS encrypted config {}, but unable to find sops binary! Try eg: https://github.com/mozilla/sops/releases/download/v3.5.0/sops-v3.5.0.linux".format(
path
)
)
sys.exit(1) sys.exit(1)
return result.stdout.decode('utf-8') return result.stdout.decode("utf-8")
else: else:
return config_raw return config_raw

View File

@ -7,30 +7,38 @@ import pkg_resources
import pulumi import pulumi
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def pulumi_init(stack): def pulumi_init(stack):
# Fail early if pulumi binaries are not available # Fail early if pulumi binaries are not available
if not shutil.which('pulumi'): if not shutil.which("pulumi"):
raise FileNotFoundError("Cannot find pulumi binary, see https://www.pulumi.com/docs/get-started/install/") raise FileNotFoundError(
"Cannot find pulumi binary, see https://www.pulumi.com/docs/get-started/install/"
)
# add all artifact_paths/pulumi to the search path for easier imports in the pulumi code # add all artifact_paths/pulumi to the search path for easier imports in the pulumi code
for artifacts_path in stack.ctx['artifact_paths']: for artifacts_path in stack.ctx["artifact_paths"]:
_path = '{}/pulumi'.format(artifacts_path.resolve()) _path = "{}/pulumi".format(artifacts_path.resolve())
sys.path.append(_path) sys.path.append(_path)
# Try local implementation first, similar to Jinja2 mode # Try local implementation first, similar to Jinja2 mode
_found = False _found = False
try: try:
_stack = importlib.import_module('config.{}.{}'.format(stack.rel_path, stack.template).replace('/', '.')) _stack = importlib.import_module(
"config.{}.{}".format(stack.rel_path, stack.template).replace("/", ".")
)
_found = True _found = True
except ImportError: except ImportError:
for artifacts_path in stack.ctx['artifact_paths']: for artifacts_path in stack.ctx["artifact_paths"]:
try: try:
spec = importlib.util.spec_from_file_location("_stack", '{}/pulumi/{}.py'.format(artifacts_path.resolve(), stack.template)) spec = importlib.util.spec_from_file_location(
"_stack",
"{}/pulumi/{}.py".format(artifacts_path.resolve(), stack.template),
)
_stack = importlib.util.module_from_spec(spec) _stack = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_stack) spec.loader.exec_module(_stack)
_found = True _found = True
@ -39,36 +47,61 @@ def pulumi_init(stack):
pass pass
if not _found: if not _found:
raise FileNotFoundError("Cannot find Pulumi implementation for {}".format(stack.stackname)) raise FileNotFoundError(
"Cannot find Pulumi implementation for {}".format(stack.stackname)
)
project_name = stack.parameters['Conglomerate'] project_name = stack.parameters["Conglomerate"]
# Remove stacknameprefix if equals Conglomerate as Pulumi implicitly prefixes project_name # Remove stacknameprefix if equals Conglomerate as Pulumi implicitly prefixes project_name
pulumi_stackname = re.sub(r'^' + project_name + '-?', '', stack.stackname) pulumi_stackname = re.sub(r"^" + project_name + "-?", "", stack.stackname)
try: try:
pulumi_backend = '{}/{}/{}'.format(stack.pulumi['backend'], project_name, stack.region) pulumi_backend = "{}/{}/{}".format(
stack.pulumi["backend"], project_name, stack.region
)
except KeyError: except KeyError:
raise KeyError('Missing pulumi.backend setting !') raise KeyError("Missing pulumi.backend setting !")
account_id = stack.connection_manager.call('sts', 'get_caller_identity', profile=stack.profile, region=stack.region)['Account'] account_id = stack.connection_manager.call(
"sts", "get_caller_identity", profile=stack.profile, region=stack.region
)["Account"]
# Ugly hack as Pulumi currently doesnt support MFA_TOKENs during role assumptions # Ugly hack as Pulumi currently doesnt support MFA_TOKENs during role assumptions
# Do NOT set them via 'aws:secretKey' as they end up in the stack.json in plain text !!! # Do NOT set them via 'aws:secretKey' as they end up in the stack.json in plain text !!!
if stack.connection_manager._sessions[(stack.profile, stack.region)].get_credentials().token: if (
os.environ['AWS_SESSION_TOKEN'] = stack.connection_manager._sessions[(stack.profile, stack.region)].get_credentials().token stack.connection_manager._sessions[(stack.profile, stack.region)]
.get_credentials()
.token
):
os.environ["AWS_SESSION_TOKEN"] = (
stack.connection_manager._sessions[(stack.profile, stack.region)]
.get_credentials()
.token
)
os.environ['AWS_ACCESS_KEY_ID'] = stack.connection_manager._sessions[(stack.profile, stack.region)].get_credentials().access_key os.environ["AWS_ACCESS_KEY_ID"] = (
os.environ['AWS_SECRET_ACCESS_KEY'] = stack.connection_manager._sessions[(stack.profile, stack.region)].get_credentials().secret_key stack.connection_manager._sessions[(stack.profile, stack.region)]
os.environ['AWS_DEFAULT_REGION'] = stack.region .get_credentials()
.access_key
)
os.environ["AWS_SECRET_ACCESS_KEY"] = (
stack.connection_manager._sessions[(stack.profile, stack.region)]
.get_credentials()
.secret_key
)
os.environ["AWS_DEFAULT_REGION"] = stack.region
# Secrets provider # Secrets provider
try: try:
secrets_provider = stack.pulumi['secretsProvider'] secrets_provider = stack.pulumi["secretsProvider"]
if secrets_provider == 'passphrase' and 'PULUMI_CONFIG_PASSPHRASE' not in os.environ: if (
raise ValueError('Missing PULUMI_CONFIG_PASSPHRASE environment variable!') secrets_provider == "passphrase"
and "PULUMI_CONFIG_PASSPHRASE" not in os.environ
):
raise ValueError("Missing PULUMI_CONFIG_PASSPHRASE environment variable!")
except KeyError: except KeyError:
logger.warning('Missing pulumi.secretsProvider setting, secrets disabled !') logger.warning("Missing pulumi.secretsProvider setting, secrets disabled !")
secrets_provider = None secrets_provider = None
# Set tag for stack file name and version # Set tag for stack file name and version
@ -76,9 +109,11 @@ def pulumi_init(stack):
try: try:
_version = _stack.VERSION _version = _stack.VERSION
except AttributeError: except AttributeError:
_version = 'undefined' _version = "undefined"
_tags['zero-downtime.net/cloudbender'] = '{}:{}'.format(os.path.basename(_stack.__file__), _version) _tags["zero-downtime.net/cloudbender"] = "{}:{}".format(
os.path.basename(_stack.__file__), _version
)
_config = { _config = {
"aws:region": stack.region, "aws:region": stack.region,
@ -90,27 +125,36 @@ def pulumi_init(stack):
# inject all parameters as config in the <Conglomerate> namespace # inject all parameters as config in the <Conglomerate> namespace
for p in stack.parameters: for p in stack.parameters:
_config['{}:{}'.format(stack.parameters['Conglomerate'], p)] = stack.parameters[p] _config["{}:{}".format(stack.parameters["Conglomerate"], p)] = stack.parameters[
p
]
stack_settings = pulumi.automation.StackSettings( stack_settings = pulumi.automation.StackSettings(
config=_config, config=_config,
secrets_provider=secrets_provider, secrets_provider=secrets_provider,
encryption_salt=stack.pulumi.get('encryptionsalt', None), encryption_salt=stack.pulumi.get("encryptionsalt", None),
encrypted_key=stack.pulumi.get('encryptedkey', None) encrypted_key=stack.pulumi.get("encryptedkey", None),
) )
project_settings = pulumi.automation.ProjectSettings( project_settings = pulumi.automation.ProjectSettings(
name=project_name, name=project_name, runtime="python", backend={"url": pulumi_backend}
runtime="python", )
backend={"url": pulumi_backend})
ws_opts = pulumi.automation.LocalWorkspaceOptions( ws_opts = pulumi.automation.LocalWorkspaceOptions(
work_dir=stack.work_dir, work_dir=stack.work_dir,
project_settings=project_settings, project_settings=project_settings,
stack_settings={pulumi_stackname: stack_settings}, stack_settings={pulumi_stackname: stack_settings},
secrets_provider=secrets_provider) secrets_provider=secrets_provider,
)
stack = pulumi.automation.create_or_select_stack(stack_name=pulumi_stackname, project_name=project_name, program=_stack.pulumi_program, opts=ws_opts) stack = pulumi.automation.create_or_select_stack(
stack.workspace.install_plugin("aws", pkg_resources.get_distribution("pulumi_aws").version) stack_name=pulumi_stackname,
project_name=project_name,
program=_stack.pulumi_program,
opts=ws_opts,
)
stack.workspace.install_plugin(
"aws", pkg_resources.get_distribution("pulumi_aws").version
)
return stack return stack

File diff suppressed because it is too large Load Diff

View File

@ -13,19 +13,21 @@ class StackGroup(object):
self.name = None self.name = None
self.ctx = ctx self.ctx = ctx
self.path = path self.path = path
self.rel_path = path.relative_to(ctx['config_path']) self.rel_path = path.relative_to(ctx["config_path"])
self.config = {} self.config = {}
self.sgs = [] self.sgs = []
self.stacks = [] self.stacks = []
if self.rel_path == '.': if self.rel_path == ".":
self.rel_path = '' self.rel_path = ""
def dump_config(self): def dump_config(self):
for sg in self.sgs: for sg in self.sgs:
sg.dump_config() sg.dump_config()
logger.debug("StackGroup {}: {}".format(self.rel_path, pprint.pformat(self.config))) logger.debug(
"StackGroup {}: {}".format(self.rel_path, pprint.pformat(self.config))
)
for s in self.stacks: for s in self.stacks:
s.dump_config() s.dump_config()
@ -35,7 +37,9 @@ class StackGroup(object):
return None return None
# First read config.yaml if present # First read config.yaml if present
_config = read_config_file(self.path.joinpath('config.yaml'), parent_config.get('variables', {})) _config = read_config_file(
self.path.joinpath("config.yaml"), parent_config.get("variables", {})
)
# Stack Group name if not explicit via config is derived from subfolder, or in case of root object the parent folder # Stack Group name if not explicit via config is derived from subfolder, or in case of root object the parent folder
if "stackgroupname" in _config: if "stackgroupname" in _config:
@ -45,19 +49,25 @@ class StackGroup(object):
# Merge config with parent config # Merge config with parent config
self.config = dict_merge(parent_config, _config) self.config = dict_merge(parent_config, _config)
stackname_prefix = self.config.get('stacknameprefix', '') stackname_prefix = self.config.get("stacknameprefix", "")
logger.debug("StackGroup {} added.".format(self.name)) logger.debug("StackGroup {} added.".format(self.name))
# Add stacks # Add stacks
stacks = [s for s in self.path.glob('*.yaml') if not s.name == "config.yaml"] stacks = [s for s in self.path.glob("*.yaml") if not s.name == "config.yaml"]
for stack_path in stacks: for stack_path in stacks:
stackname = stack_path.name.split('.')[0] stackname = stack_path.name.split(".")[0]
template = stackname template = stackname
if stackname_prefix: if stackname_prefix:
stackname = stackname_prefix + stackname stackname = stackname_prefix + stackname
new_stack = Stack(name=stackname, template=template, path=stack_path, rel_path=str(self.rel_path), ctx=self.ctx) new_stack = Stack(
name=stackname,
template=template,
path=stack_path,
rel_path=str(self.rel_path),
ctx=self.ctx,
)
new_stack.read_config(self.config) new_stack.read_config(self.config)
self.stacks.append(new_stack) self.stacks.append(new_stack)
@ -68,22 +78,24 @@ class StackGroup(object):
self.sgs.append(sg) self.sgs.append(sg)
def get_stacks(self, name=None, recursive=True, match_by='name'): def get_stacks(self, name=None, recursive=True, match_by="name"):
""" Returns [stack] matching stack_name or [all] """ """Returns [stack] matching stack_name or [all]"""
stacks = [] stacks = []
if name: if name:
logger.debug("Looking for stack {} in group {}".format(name, self.name)) logger.debug("Looking for stack {} in group {}".format(name, self.name))
for s in self.stacks: for s in self.stacks:
if name: if name:
if match_by == 'name' and s.stackname != name: if match_by == "name" and s.stackname != name:
continue continue
if match_by == 'path' and not s.path.match(name): if match_by == "path" and not s.path.match(name):
continue continue
if self.rel_path: if self.rel_path:
logger.debug("Found stack {} in group {}".format(s.stackname, self.rel_path)) logger.debug(
"Found stack {} in group {}".format(s.stackname, self.rel_path)
)
else: else:
logger.debug("Found stack {}".format(s.stackname)) logger.debug("Found stack {}".format(s.stackname))
stacks.append(s) stacks.append(s)
@ -96,14 +108,20 @@ class StackGroup(object):
return stacks return stacks
def get_stackgroup(self, name=None, recursive=True, match_by='name'): def get_stackgroup(self, name=None, recursive=True, match_by="name"):
""" Returns stack group matching stackgroup_name or all if None """ """Returns stack group matching stackgroup_name or all if None"""
if not name or (self.name == name and match_by == 'name') or (self.path.match(name) and match_by == 'path'): if (
not name
or (self.name == name and match_by == "name")
or (self.path.match(name) and match_by == "path")
):
logger.debug("Found stack_group {}".format(self.name)) logger.debug("Found stack_group {}".format(self.name))
return self return self
if name and self.name != 'config': if name and self.name != "config":
logger.debug("Looking for stack_group {} in group {}".format(name, self.name)) logger.debug(
"Looking for stack_group {} in group {}".format(name, self.name)
)
if recursive: if recursive:
for sg in self.sgs: for sg in self.sgs:

View File

@ -5,7 +5,7 @@ import re
def dict_merge(a, b): def dict_merge(a, b):
""" Deep merge to allow proper inheritance for config files""" """Deep merge to allow proper inheritance for config files"""
if not a: if not a:
return b return b
@ -36,16 +36,14 @@ def setup_logging(debug):
logging.getLogger("botocore").setLevel(logging.INFO) logging.getLogger("botocore").setLevel(logging.INFO)
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="[%(asctime)s] %(name)s %(message)s", fmt="[%(asctime)s] %(name)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
datefmt="%Y-%m-%d %H:%M:%S"
) )
else: else:
our_level = logging.INFO our_level = logging.INFO
logging.getLogger("botocore").setLevel(logging.CRITICAL) logging.getLogger("botocore").setLevel(logging.CRITICAL)
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="[%(asctime)s] %(message)s", fmt="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
datefmt="%Y-%m-%d %H:%M:%S"
) )
log_handler = logging.StreamHandler() log_handler = logging.StreamHandler()
@ -57,8 +55,8 @@ def setup_logging(debug):
def search_refs(template, attributes, mode): def search_refs(template, attributes, mode):
""" Traverses a template and searches for any remote references and """Traverses a template and searches for any remote references and
adds them to the attributes set adds them to the attributes set
""" """
if isinstance(template, dict): if isinstance(template, dict):
for k, v in template.items(): for k, v in template.items():
@ -70,7 +68,7 @@ def search_refs(template, attributes, mode):
# CloudBender::StackRef # CloudBender::StackRef
if k == "CloudBender::StackRef": if k == "CloudBender::StackRef":
try: try:
attributes.append(v['StackTags']['Artifact']) attributes.append(v["StackTags"]["Artifact"])
except KeyError: except KeyError:
pass pass
@ -91,11 +89,11 @@ def get_s3_url(url, *args):
bucket = None bucket = None
path = None path = None
m = re.match('^(s3://)?([^/]*)(/.*)?', url) m = re.match("^(s3://)?([^/]*)(/.*)?", url)
bucket = m[2] bucket = m[2]
if m[3]: if m[3]:
path = m[3].lstrip('/') path = m[3].lstrip("/")
path = os.path.join(path, *args) path = os.path.join(path, *args)
return(bucket, path) return (bucket, path)