Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
282 changes: 282 additions & 0 deletions data-tool/flows/auth/auth_affiliation_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
import math
import os
from typing import Dict, List

from prefect import flow
from prefect.context import get_run_context
from prefect.futures import wait
from prefect.states import Failed
from prefect.task_runners import ConcurrentTaskRunner
from sqlalchemy import text

from common.extract_tracking_service import ExtractTrackingService, ProcessingStatuses
from common.init_utils import colin_extract_init, get_config
from common.query_utils import convert_result_set_to_dict

from .auth_models import AuthCreatePlan, AuthDeletePlan, AuthSelectionMode
from .auth_queries import (
get_auth_business_profiles_query,
get_auth_reservable_corps_query,
get_auth_reservable_count_query,
)
from .auth_tasks import get_auth_token, parse_accounts_csv, perform_auth_create_for_corp, perform_auth_delete_for_corp

# Identifier for this flow: passed as ``flow_name`` to the reservation queries
# and to ExtractTrackingService, and echoed in the final status messages.
FLOW_NAME = 'auth-affiliation-flow'


def _get_max_workers() -> int:
    """Return the task-runner worker count configured via ``AUTH_MAX_WORKERS``.

    Returns:
        The integer value of the ``AUTH_MAX_WORKERS`` environment variable, or
        50 when the variable is unset, is not a valid integer, or is not
        strictly positive.
    """
    default_workers = 50
    try:
        workers = int(os.getenv('AUTH_MAX_WORKERS', str(default_workers)))
    except ValueError:
        # int() on a string only fails with ValueError; fall back instead of
        # raising because this is evaluated while the flow decorator is built.
        return default_workers
    return workers if workers > 0 else default_workers


def _parse_selection_mode(config) -> AuthSelectionMode:
    """Resolve ``config.AUTH_SELECTION_MODE`` to an ``AuthSelectionMode`` member.

    A missing or falsy setting is treated as ``'MIGRATION_FILTER'``. The value
    is stripped and upper-cased before the enum lookup.

    Raises:
        ValueError: if the normalised value cannot be converted to an
            ``AuthSelectionMode``; the original error is chained as the cause.
    """
    default_mode = 'MIGRATION_FILTER'
    configured = getattr(config, 'AUTH_SELECTION_MODE', default_mode)
    if not configured:
        configured = default_mode
    mode_name = configured.strip().upper()

    try:
        mode = AuthSelectionMode(mode_name)
    except Exception as exc:
        raise ValueError(f'Unknown AUTH_SELECTION_MODE: {mode_name}') from exc
    return mode


def _fetch_profiles(colin_engine, corp_nums: List[str], suffix: str) -> Dict[str, dict]:
    """Fetch business profile rows for ``corp_nums`` and index them by identifier.

    Args:
        colin_engine: engine used to open the connection the query runs on.
        corp_nums: corp numbers to look up; an empty value short-circuits.
        suffix: value forwarded to the profile query builder (falsy -> '').

    Returns:
        A dict mapping each row's ``'identifier'`` value to the row itself,
        or an empty dict when ``corp_nums`` is empty.
    """
    # Nothing to look up: skip building and running the query entirely.
    if not corp_nums:
        return {}

    name_suffix = suffix if suffix else ''
    query_sql = get_auth_business_profiles_query(corp_nums, name_suffix)
    with colin_engine.connect() as connection:
        result_set = connection.execute(text(query_sql))
        profile_rows = convert_result_set_to_dict(result_set)

    profiles_by_identifier = {}
    for row in profile_rows:
        profiles_by_identifier[row['identifier']] = row
    return profiles_by_identifier


@flow(
    name='Auth-Affiliation-Flow',
    log_prints=True,
    persist_result=False,
    task_runner=ConcurrentTaskRunner(max_workers=_get_max_workers())
)
def auth_affiliation_flow():
    """
    Create OR delete affiliations (mutually exclusive for this run).

    - Create mode: uses AuthCreatePlan(create_affiliations=True)
    - Delete mode: uses AuthDeletePlan(delete_affiliations=True)

    Selection excludes any corp already tracked in auth_processing for (corp_num, FLOW_NAME, environment).

    Steps:
        1. Read config and build the create or delete plan.
        2. Count reservable corps; stop early when there are none.
        3. Reserve up to ``AUTH_BATCHES * AUTH_BATCH_SIZE`` corps for this flow run.
        4. Claim the reserved corps batch by batch, submit one task per corp,
           and record each corp's outcome through the tracking service.

    Returns:
        ``None`` when there is nothing to process or every corp completed;
        a ``Failed`` state when the auth token cannot be obtained or when at
        least one corp failed.

    Raises:
        ValueError: if the selection mode is unknown, if both or neither of
            AUTH_CREATE_AFFILIATIONS / AUTH_DELETE_AFFILIATIONS are set, or if
            AUTH_BATCHES / AUTH_BATCH_SIZE is missing or not positive.
    """
    config = get_config()
    colin_engine = colin_extract_init(config)
    selection_mode = _parse_selection_mode(config)

    do_create = bool(getattr(config, 'AUTH_CREATE_AFFILIATIONS', False))
    do_delete = bool(getattr(config, 'AUTH_DELETE_AFFILIATIONS', False))

    if do_create and do_delete:
        raise ValueError('Invalid config: cannot both AUTH_CREATE_AFFILIATIONS and AUTH_DELETE_AFFILIATIONS in one run')
    if not do_create and not do_delete:
        raise ValueError('Nothing to do: set either AUTH_CREATE_AFFILIATIONS or AUTH_DELETE_AFFILIATIONS')

    create_plan = None
    delete_plan = None
    if do_create:
        create_plan = AuthCreatePlan(
            create_entity=bool(getattr(config, 'AUTH_CREATE_ENTITY', True)),
            upsert_contact=False,
            create_affiliations=True,
            send_unaffiliated_invite=False,
            fail_if_missing_email=False,
            dry_run=bool(getattr(config, 'AUTH_DRY_RUN', False)),
        )
        plan_desc = create_plan
    else:
        delete_plan = AuthDeletePlan(
            delete_affiliations=True,
            delete_entity=False,
            delete_invites=False,
            dry_run=bool(getattr(config, 'AUTH_DRY_RUN', False)),
        )
        plan_desc = delete_plan

    # entity_action recorded on error paths: the entity step only counts as
    # FAILED when this run was actually going to create entities.
    entity_action_on_error = (
        'FAILED' if (do_create and create_plan and create_plan.create_entity) else 'NOT_RUN'
    )

    # Count reservable
    count_sql = get_auth_reservable_count_query(
        flow_name=FLOW_NAME,
        config=config,
        selection_mode=selection_mode
    )
    with colin_engine.connect() as conn:
        total_reservable = int(conn.execute(text(count_sql)).scalar() or 0)

    if total_reservable <= 0:
        print('No reservable corps found for this run.')
        return

    # `or 0` makes an explicit None fail with the intended ValueError instead
    # of a TypeError from comparing None with an int.
    if (getattr(config, 'AUTH_BATCHES', 0) or 0) <= 0:
        raise ValueError('AUTH_BATCHES must be explicitly set to a positive integer')
    if (getattr(config, 'AUTH_BATCH_SIZE', 0) or 0) <= 0:
        raise ValueError('AUTH_BATCH_SIZE must be explicitly set to a positive integer')

    batch_size = config.AUTH_BATCH_SIZE
    max_corps = min(total_reservable, config.AUTH_BATCHES * config.AUTH_BATCH_SIZE)

    flow_run_id = get_run_context().flow_run.id

    tracking = ExtractTrackingService(
        config.DATA_LOAD_ENV,
        colin_engine,
        FLOW_NAME,
        table_name='auth_processing',
        statement_timeout_ms=getattr(config, 'RESERVE_STATEMENT_TIMEOUT_MS', None)
    )

    extra_insert_cols = ['account_ids']

    base_query = get_auth_reservable_corps_query(
        flow_name=FLOW_NAME,
        config=config,
        batch_size=max_corps,
        selection_mode=selection_mode,
        include_account_ids=True,
        include_contact_email=False
    )

    reserved = tracking.reserve_for_flow(
        base_query=base_query,
        flow_run_id=flow_run_id,
        extra_insert_cols=extra_insert_cols,
        fallback_account_ids=config.AFFILIATE_ENTITY_ACCOUNT_IDS_CSV
    )

    if reserved <= 0:
        print('No corps reserved (cohort may be exhausted or already reserved).')
        return

    batches = min(math.ceil(reserved / batch_size), config.AUTH_BATCHES)

    print(f'👷 Auth affiliation mode: {"CREATE" if do_create else "DELETE"}')
    print(f'👷 Plan: {plan_desc}')
    print(f'👷 Reservable={total_reservable}, Reserved={reserved}, Batches={batches}, BatchSize={batch_size}')
    print(f'👷 SelectionMode={selection_mode.value}')

    total_failed = 0
    total_completed = 0

    for round_no in range(1, batches + 1):
        claimed = tracking.claim_batch(
            flow_run_id,
            batch_size,
            extra_return_cols=extra_insert_cols,
            as_dict=True
        )
        if not claimed:
            print('No more corps available to claim')
            break

        corp_nums = [r['corp_num'] for r in claimed]
        corp_accounts = {r['corp_num']: (r.get('account_ids') or None) for r in claimed}

        # Profiles are only needed to create; delete mode skips the lookup.
        profiles = _fetch_profiles(colin_engine, corp_nums, getattr(config, 'CORP_NAME_SUFFIX', '') or '') if do_create else {}

        try:
            token = get_auth_token(config)
        except Exception as e:
            # Without a token nothing in this batch can run: mark every claimed
            # corp as failed and end the flow in a Failed state.
            err = f'Failed to obtain auth token: {repr(e)}'
            print(f'❌ {err}')
            for corp_num in corp_nums:
                tracking.update_corp_status(
                    flow_run_id,
                    corp_num,
                    ProcessingStatuses.FAILED,
                    error=err,
                    entity_action=entity_action_on_error,
                    contact_action='NOT_RUN',
                    affiliation_action='FAILED',
                    invite_action='NOT_RUN',
                    action_detail='token_error'
                )
            return Failed(message=err)

        # Keep each future paired with its corp so a task that raises can still
        # be attributed to the right corp in the tracking table.
        submitted = []
        for corp_num in corp_nums:
            accounts = parse_accounts_csv(corp_accounts.get(corp_num))

            if do_create:
                profile = profiles.get(corp_num)
                if not profile:
                    total_failed += 1
                    tracking.update_corp_status(
                        flow_run_id,
                        corp_num,
                        ProcessingStatuses.FAILED,
                        error='Missing business profile for corp in COLIN extract',
                        entity_action=entity_action_on_error,
                        contact_action='NOT_RUN',
                        affiliation_action='FAILED',
                        invite_action='NOT_RUN',
                        action_detail='profile_missing'
                    )
                    continue

                future = perform_auth_create_for_corp.submit(
                    config,
                    corp_num,
                    profile,
                    accounts,
                    create_plan,
                    token
                )
            else:
                future = perform_auth_delete_for_corp.submit(
                    config,
                    corp_num,
                    accounts,
                    delete_plan,
                    token
                )
            submitted.append((corp_num, future))

        wait([future for _, future in submitted])

        for corp_num, future in submitted:
            try:
                res = future.result()
            except Exception as e:
                # result() re-raises when the task itself failed. Record the
                # failure for this corp and keep going, so one bad task does
                # not abort the flow and leave the rest of the batch without
                # a final status.
                err = f'Auth task raised: {repr(e)}'
                print(f'❌ {corp_num}: {err}')
                total_failed += 1
                tracking.update_corp_status(
                    flow_run_id,
                    corp_num,
                    ProcessingStatuses.FAILED,
                    error=err,
                    entity_action=entity_action_on_error,
                    contact_action='NOT_RUN',
                    affiliation_action='FAILED',
                    invite_action='NOT_RUN',
                    action_detail='task_exception'
                )
                continue

            actions = [
                res.get('entity_action'),
                res.get('contact_action'),
                res.get('affiliation_action'),
                res.get('invite_action'),
            ]
            # A corp fails as soon as any single step reports FAILED.
            failed = any(a == 'FAILED' for a in actions if a)
            status = ProcessingStatuses.FAILED if failed else ProcessingStatuses.COMPLETED

            tracking.update_corp_status(
                flow_run_id,
                res.get('corp_num', corp_num),
                status,
                error=res.get('error'),
                entity_action=res.get('entity_action'),
                contact_action=res.get('contact_action'),
                affiliation_action=res.get('affiliation_action'),
                invite_action=res.get('invite_action'),
                action_detail=res.get('action_detail')
            )

            if status == ProcessingStatuses.FAILED:
                total_failed += 1
            else:
                total_completed += 1

        print(f'🌟 Complete round {round_no}/{batches}. Completed={total_completed}, Failed={total_failed}')

    if total_failed > 0:
        return Failed(message=f'{total_failed} corps failed in {FLOW_NAME}.')

    print(f'🌰 {FLOW_NAME} complete. Completed={total_completed}, Failed={total_failed}')


# Run the flow when this module is executed as the entry point.
# NOTE(review): the module uses relative imports (e.g. `.auth_models`), so this
# presumably needs to be launched as a package module (`python -m ...`) - confirm.
if __name__ == '__main__':
    auth_affiliation_flow()
Loading