[go: nahoru, domu]

Skip to content

Commit

Permalink
Added aws-sso-util roles, improved assignments, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
benkehoe committed Nov 18, 2020
1 parent a4a707e commit e15a03d
Show file tree
Hide file tree
Showing 14 changed files with 498 additions and 185 deletions.
4 changes: 2 additions & 2 deletions src/aws_sso_util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '2.11.0'

from .sso import get_boto3_session, login
from .assignments import get_assignments
from .sso import get_boto3_session, login, list_available_roles
from .assignments import list_assignments
2 changes: 1 addition & 1 deletion src/aws_sso_util/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_accounts_for_ou(session, ou, recursive, refresh=False, cache=None):

paginator = client.get_paginator('list_accounts_for_parent')
for response in paginator.paginate(ParentId=ou):
ou_accounts.extend(data['Id'] for data in response['Accounts'])
ou_accounts.extend(response['Accounts'])

cache[ou_accounts_key] = ou_accounts

Expand Down
144 changes: 106 additions & 38 deletions src/aws_sso_util/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import numbers
import collections
import logging
from collections.abc import Iterable
import itertools

import aws_error_utils

from .api_utils import Ids
from .api_utils import Ids, get_accounts_for_ou

LOGGER = logging.getLogger(__name__)

Expand All @@ -32,9 +34,36 @@ def _filter(filter_cache, key, func, args):
filter_cache[key] = func(*args)
return filter_cache[key]

def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))

def _is_principal_tuple(principal):
try:
return all([
len(principal) == 2,
isinstance(principal[0], str),
principal[0] in ["GROUP", "USER"],
isinstance(principal[1], str),
])
except:
return False

def _process_principal(principal):
    """Normalize a principal specification into a list of ``(type, id)`` tuples.

    Accepted forms:
      * falsy value (``None``, ``""``, empty list) -> ``None`` (no principal)
      * a string id -> ``[("UNKNOWN", id)]`` (type not specified)
      * a ``("GROUP"|"USER", id)`` pair -> ``[(type, id)]``
      * any other iterable -> concatenation of normalizing each item

    Empty items nested inside an iterable are skipped, so ``["u1", ""]``
    yields ``[("UNKNOWN", "u1")]`` instead of raising ``TypeError``.

    Returns:
        ``None`` for falsy input, otherwise a list (possibly empty if every
        nested item was empty).
    """
    if not principal:
        return None
    if isinstance(principal, str):
        return [("UNKNOWN", principal)]
    if _is_principal_tuple(principal):
        return [tuple(principal)]
    processed = []
    for item in principal:
        normalized = _process_principal(item)
        # A nested empty item normalizes to None; skip it instead of letting
        # it break the concatenation.
        if normalized:
            processed.extend(normalized)
    return processed

def _process_permission_set(ids, permission_set):
if not permission_set:
return None
if not isinstance(permission_set, str) and isinstance(permission_set, Iterable):
return _flatten(_process_permission_set(ids, ps) for ps in permission_set)

if permission_set.startswith("arn"):
permission_set_arn = permission_set
elif permission_set.startswith("ssoins-") or permission_set.startswith("ins-"):
Expand All @@ -43,7 +72,18 @@ def _process_permission_set(ids, permission_set):
permission_set_arn = f"arn:aws:sso:::permissionSet/{ids.instance_id}/{permission_set}"
else:
raise TypeError(f"Invalid permission set id {permission_set}")
return permission_set_arn
return [permission_set_arn]

def _is_target_tuple(target):
try:
return all([
len(target) == 2,
isinstance(target[0], str),
target[0] in ["AWS_OU", "AWS_ACCOUNT"],
isinstance(target[1], str),
])
except:
return False

def _process_target(target):
if not target:
Expand All @@ -53,27 +93,46 @@ def _process_target(target):
if isinstance(target, str):
if re.match(r"^\d+$", target):
target = target.rjust(12, '0')
return "AWS_ACCOUNT", target
return [("AWS_ACCOUNT", target)]
elif re.match(r"^r-[a-z0-9]{4,32}$", target) or re.match(r"^ou-[a-z0-9]{4,32}-[a-z0-9]{8,32}$", target):
return [("AWS_OU", target)]
else:
return "AWS_OU", target
else:
raise TypeError(f"Invalid target {target}")
elif _is_target_tuple(target):
target_type, target_id = target
if target_type not in ["AWS_ACCOUNT", "AWS_OU"]:
raise TypeError(f"Invalid target type {target_type}")
return target_type, target_id
return [(target_type, target_id)]
else:
value = _flatten(_process_target(t) for t in target)
return value

def _get_account_iterator(context: _Context):
def _get_account_iterator(target, context: _Context):
def target_iterator():
value = (*context.target, "UNKNOWN")
value = (*target, "UNKNOWN")
if not _filter(context.filter_cache, value[1], context.target_filter, value):
LOGGER.debug(f"Single account is filtered: {value}")
LOGGER.debug(f"Account is filtered: {value}")
else:
LOGGER.debug(f"Visiting single account: {value}")
yield value
return target_iterator

def _get_ou_iterator(context: _Context):
raise NotImplementedError
def _get_ou_iterator(target, context: _Context):
    """Return a zero-argument generator factory over the accounts in one OU.

    Args:
        target: An ``("AWS_OU", ou_id)`` pair; only ``target[1]`` is used.
        context: Iteration context. ``session``, ``ou_recursive``,
            ``filter_cache`` and ``target_filter`` are read.

    The factory calls ``get_accounts_for_ou`` and yields
    ``("AWS_ACCOUNT", account_id, account_name)`` for each returned account.
    Accounts rejected by ``context.target_filter`` are skipped, using the
    same ``_filter`` check (keyed by account id) that is applied to
    explicitly listed accounts.
    """
    def target_iterator():
        ou_id = target[1]
        LOGGER.debug(f"Visiting OU: {ou_id}")
        accounts = get_accounts_for_ou(context.session, ou_id, recursive=context.ou_recursive)
        for account in accounts:
            value = ("AWS_ACCOUNT", account["Id"], account["Name"])
            # Without this check, accounts reached through an OU would bypass
            # the target filter that single accounts are subject to.
            if not _filter(context.filter_cache, value[1], context.target_filter, value):
                LOGGER.debug(f"Account is filtered: {value}")
                continue
            LOGGER.debug(f"Visiting account in OU {ou_id}: {value}")
            yield value
    return target_iterator

def _get_single_target_iterator(target, context: _Context):
    """Pick the iterator factory matching one normalized target.

    Args:
        target: A ``(target_type, target_id)`` pair; ``target_type`` selects
            which factory builder is used.
        context: Passed through unchanged to the chosen builder.

    Returns:
        The result of ``_get_account_iterator`` for ``"AWS_ACCOUNT"`` targets
        or of ``_get_ou_iterator`` for ``"AWS_OU"`` targets.

    Raises:
        TypeError: If ``target_type`` is any other value.
    """
    kind = target[0]
    if kind not in ("AWS_ACCOUNT", "AWS_OU"):
        raise TypeError(f"Invalid target type {kind}")
    build = _get_account_iterator if kind == "AWS_ACCOUNT" else _get_ou_iterator
    return build(target, context)

def _get_all_accounts_iterator(context: _Context):
def target_iterator():
Expand All @@ -97,21 +156,16 @@ def target_iterator():

def _get_target_iterator(context: _Context):
if context.target:
target_type = context.target[0]
if target_type == "AWS_ACCOUNT":
LOGGER.debug(f"Iterating for single account")
return _get_account_iterator(context)
elif target_type == "AWS_OU":
LOGGER.debug(f"Iterating for single OU")
return _get_ou_iterator(context)
else:
raise TypeError(f"Invalid target type {target_type}")
iterables = [_get_single_target_iterator(t, context) for t in context.target]
def target_iterator():
return itertools.chain(*[it() for it in iterables])
return target_iterator
else:
LOGGER.debug(f"Iterating for all accounts")
return _get_all_accounts_iterator(context)

def _get_single_permission_set_iterator(context: _Context):
permission_set_arn = context.permission_set
def _get_single_permission_set_iterator(permission_set, context: _Context):
permission_set_arn = permission_set
permission_set_id = permission_set_arn.split("/")[-1]

def permission_set_iterator(target_type, target_id, target_name):
Expand Down Expand Up @@ -170,8 +224,10 @@ def permission_set_iterator(target_type, target_id, target_name):

def _get_permission_set_iterator(context: _Context):
if context.permission_set:
LOGGER.debug("Iterating for a single permission set")
return _get_single_permission_set_iterator(context)
iterables = [_get_single_permission_set_iterator(ps, context) for ps in context.permission_set]
def permission_set_iterator(target_type, target_id, target_name):
return itertools.chain(*[it(target_type, target_id, target_name) for it in iterables])
return permission_set_iterator
else:
LOGGER.debug("Iterating for all permission sets")
return _get_all_permission_sets_iterator(context)
Expand Down Expand Up @@ -199,13 +255,17 @@ def principal_iterator(
for assignment in response["AccountAssignments"]:
principal_type = assignment["PrincipalType"]
principal_id = assignment["PrincipalId"]
LOGGER.debug(f"Visiting principal {principal_type}:{principal_id}")

if context.principal:
if (context.principal[0] != principal_type or context.principal[1] != principal_id):
LOGGER.debug(f"Principal {principal_type}:{principal_id} does not match single principal")
continue
for principal in context.principal:
type_matches = (principal[0] == "UNKNOWN" or principal[0] != principal_type)
if type_matches and principal[1] == principal_id:
LOGGER.debug(f"Found principal {principal_type}:{principal_id}")
break
else:
LOGGER.debug(f"Found single principal {principal_type}:{principal_id}")
LOGGER.debug(f"Principal {principal_type}:{principal_id} does not match principals")
continue

principal_key = (principal_type, principal_id)
if not context.get_principal_names:
Expand Down Expand Up @@ -238,7 +298,7 @@ def principal_iterator(

if not _filter(context.filter_cache, principal_key, context.principal_filter, (principal_type, principal_id, principal_name)):
if context.principal:
LOGGER.debug(f"Single principal is filtered: {principal_type}:{principal_id}")
LOGGER.debug(f"Principal is filtered: {principal_type}:{principal_id}")
else:
LOGGER.debug(f"Principal is filtered: {principal_type}:{principal_id}")
continue
Expand All @@ -258,7 +318,7 @@ def principal_iterator(
"target_name",
])

def get_assignments(
def list_assignments(
session,
instance_arn=None,
identity_store_id=None,
Expand All @@ -271,9 +331,15 @@ def get_assignments(
get_principal_names=True,
get_permission_set_names=True,
ou_recursive=False):
"""Iterate over AWS SSO assignments.
Args:
session:
"""
ids = Ids(lambda: session, instance_arn, identity_store_id)
ids.suppress_print = True
return _get_assignments(
return _list_assignments(
session,
ids,
principal=principal,
Expand All @@ -287,7 +353,7 @@ def get_assignments(
ou_recursive=ou_recursive,
)

def _get_assignments(
def _list_assignments(
session,
ids,
principal=None,
Expand All @@ -301,6 +367,7 @@ def _get_assignments(
ou_recursive=False):
ids.suppress_print = True

principal = _process_principal(principal)
permission_set = _process_permission_set(ids, permission_set)
target = _process_target(target)

Expand Down Expand Up @@ -354,16 +421,17 @@ def _get_assignments(
import json

logging.basicConfig(level=logging.INFO)
LOGGER.setLevel(logging.DEBUG)

if len(sys.argv) > 1:
kwargs = json.loads(sys.argv[1])
else:
kwargs = {}
kwargs = {}
for v in sys.argv[1:]:
if hasattr(logging, v):
LOGGER.setLevel(getattr(logging, v))
else:
kwargs = json.loads(sys.argv[1])

try:
session = boto3.Session()
for value in get_assignments(session, **kwargs):
for value in list_assignments(session, **kwargs):
print(",".join(value))
except KeyboardInterrupt:
pass
5 changes: 3 additions & 2 deletions src/aws_sso_util/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class OpenBrowserHandler(object):
def __init__(self, outfile=None, open_browser=None, message=None):
if not outfile:
if outfile is None:
outfile = sys.stderr
self._outfile = outfile

Expand All @@ -54,7 +54,8 @@ def __call__(self, userCode, verificationUri,
userCode=userCode
)

print(message, file=self._outfile)
if self._outfile:
print(message, file=self._outfile)

disable_browser = os.environ.get('AWS_SSO_DISABLE_BROWSER', '').lower() in ['1', 'true']
if self._open_browser and not disable_browser:
Expand Down
3 changes: 2 additions & 1 deletion src/aws_sso_util/cfn/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ def process_template(template,

resource_collection_dict = {}
max_stack_resources = 0
ou_fetcher = lambda ou, recursive: [a["Id"] for a in api_utils.get_accounts_for_ou(session, ou, recursive, cache=ou_accounts_cache)]
for resource_name, config in configs.items():
resource_collection = resources.get_resources_from_config(
config,
ou_fetcher=lambda ou, recursive: api_utils.get_accounts_for_ou(session, ou, recursive, cache=ou_accounts_cache))
ou_fetcher=ou_fetcher)

max_stack_resources += templates.get_max_number_of_child_stacks(resource_collection.num_resources, max_resources_per_template=max_resources_per_template)

Expand Down
2 changes: 2 additions & 0 deletions src/aws_sso_util/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .logout import logout
from .lookup import lookup
from .populate_profiles import populate_profiles
from .roles import roles

@click.group(name="aws-sso-util")
@click.version_option(version=__version__, message='%(version)s')
Expand All @@ -39,6 +40,7 @@ def configure():

cli.add_command(login)
cli.add_command(logout)
cli.add_command(roles)

cli.add_command(lookup)
cli.add_command(assignments)
Expand Down
Loading

0 comments on commit e15a03d

Please sign in to comment.