# personalize_cdk/personalize_pipeline_stack.py
# PersonalizePipelineStack: Step Functions pipeline that trains and deploys
# an Amazon Personalize model (dataset import -> solution -> campaign).

from aws_cdk import (
Stack,
Duration,
aws_iam as iam,
aws_lambda as lambda_,
aws_stepfunctions as sfn,
aws_stepfunctions_tasks as tasks,
aws_logs as logs,
)
from constructs import Construct

class PersonalizePipelineStack(Stack):
    """Step Functions pipeline that trains and deploys an Amazon Personalize model.

    Flow: start a dataset import job, poll until it is ACTIVE, create a
    solution + solution version, poll until training is ACTIVE, then deploy
    a campaign. Each step is a small inline-code Lambda; polling is modeled
    as Wait -> Check -> Choice loops inside the state machine.
    """

    def __init__(self, scope: Construct, construct_id: str,
                 data_bucket, personalize_role, dataset_arn,
                 dataset_group_arn, **kwargs) -> None:
        """Wire up the Lambdas, IAM role, and state machine.

        Args:
            scope: Parent CDK construct.
            construct_id: Logical ID of this stack.
            data_bucket: S3 bucket construct holding ``data/interactions.csv``.
            personalize_role: IAM role Personalize assumes to read the bucket.
            dataset_arn: ARN of the interactions dataset to import into.
            dataset_group_arn: ARN of the dataset group to train in.
        """
        super().__init__(scope, construct_id, **kwargs)

        # ===================
        # Lambda Execution Role
        # ===================
        lambda_role = iam.Role(
            self, "LambdaRole",
            assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"),
            managed_policies=[
                iam.ManagedPolicy.from_aws_managed_policy_name(
                    "service-role/AWSLambdaBasicExecutionRole"
                )
            ]
        )

        # Personalize control-plane calls. Personalize resource ARNs are
        # created at run time, so the resource cannot be narrowed further here.
        lambda_role.add_to_policy(
            iam.PolicyStatement(
                effect=iam.Effect.ALLOW,
                actions=["personalize:*"],
                resources=["*"]
            )
        )

        # iam:PassRole is restricted to the single role Personalize needs for
        # the dataset import job. (Granting PassRole on "*" would let these
        # Lambdas pass any role in the account.)
        lambda_role.add_to_policy(
            iam.PolicyStatement(
                effect=iam.Effect.ALLOW,
                actions=["iam:PassRole"],
                resources=[personalize_role.role_arn]
            )
        )

        # ===================
        # Lambda Functions
        # ===================

        # Create Dataset Import Job
        create_import_job_fn = lambda_.Function(
            self, "CreateImportJobFunction",
            runtime=lambda_.Runtime.PYTHON_3_12,
            handler="index.handler",
            code=lambda_.Code.from_inline(self._get_import_job_code()),
            role=lambda_role,
            timeout=Duration.minutes(5),
            environment={
                "DATASET_ARN": dataset_arn,
                "ROLE_ARN": personalize_role.role_arn,
                "BUCKET_NAME": data_bucket.bucket_name
            }
        )

        # Check Import Job Status
        check_import_status_fn = lambda_.Function(
            self, "CheckImportStatusFunction",
            runtime=lambda_.Runtime.PYTHON_3_12,
            handler="index.handler",
            code=lambda_.Code.from_inline(self._get_check_import_status_code()),
            role=lambda_role,
            timeout=Duration.minutes(1)
        )

        # Create Solution (and kick off a solution version / training run)
        create_solution_fn = lambda_.Function(
            self, "CreateSolutionFunction",
            runtime=lambda_.Runtime.PYTHON_3_12,
            handler="index.handler",
            code=lambda_.Code.from_inline(self._get_create_solution_code()),
            role=lambda_role,
            timeout=Duration.minutes(5),
            environment={
                "DATASET_GROUP_ARN": dataset_group_arn
            }
        )

        # Check Solution Version (training) Status
        check_solution_status_fn = lambda_.Function(
            self, "CheckSolutionStatusFunction",
            runtime=lambda_.Runtime.PYTHON_3_12,
            handler="index.handler",
            code=lambda_.Code.from_inline(self._get_check_solution_status_code()),
            role=lambda_role,
            timeout=Duration.minutes(1)
        )

        # Create Campaign (deploy the trained solution version)
        create_campaign_fn = lambda_.Function(
            self, "CreateCampaignFunction",
            runtime=lambda_.Runtime.PYTHON_3_12,
            handler="index.handler",
            code=lambda_.Code.from_inline(self._get_create_campaign_code()),
            role=lambda_role,
            timeout=Duration.minutes(5)
        )

        # ===================
        # Step Functions Tasks
        # ===================
        # output_path="$.Payload" unwraps the Lambda-invoke envelope so the
        # next state receives the handler's return dict directly.

        create_import_job_task = tasks.LambdaInvoke(
            self, "CreateImportJob",
            lambda_function=create_import_job_fn,
            output_path="$.Payload"
        )

        check_import_status_task = tasks.LambdaInvoke(
            self, "CheckImportStatus",
            lambda_function=check_import_status_fn,
            output_path="$.Payload"
        )

        create_solution_task = tasks.LambdaInvoke(
            self, "CreateSolution",
            lambda_function=create_solution_fn,
            output_path="$.Payload"
        )

        check_solution_status_task = tasks.LambdaInvoke(
            self, "CheckSolutionStatus",
            lambda_function=check_solution_status_fn,
            output_path="$.Payload"
        )

        create_campaign_task = tasks.LambdaInvoke(
            self, "CreateCampaign",
            lambda_function=create_campaign_fn,
            output_path="$.Payload"
        )

        # Polling intervals: import jobs finish faster than model training.
        wait_for_import = sfn.Wait(
            self, "WaitForImport",
            time=sfn.WaitTime.duration(Duration.minutes(2))
        )

        wait_for_solution = sfn.Wait(
            self, "WaitForSolution",
            time=sfn.WaitTime.duration(Duration.minutes(5))
        )

        # Choice states
        is_import_complete = sfn.Choice(self, "IsImportComplete")
        is_solution_complete = sfn.Choice(self, "IsSolutionComplete")

        # Success/Fail states
        success_state = sfn.Succeed(self, "PipelineSucceeded")
        fail_state = sfn.Fail(self, "PipelineFailed", error="PipelineError")

        # ===================
        # Define Workflow
        # ===================
        definition = (
            create_import_job_task
            .next(wait_for_import)
            .next(check_import_status_task)
            .next(
                is_import_complete
                .when(
                    sfn.Condition.string_equals("$.status", "ACTIVE"),
                    create_solution_task
                )
                .when(
                    sfn.Condition.string_equals("$.status", "CREATE FAILED"),
                    fail_state
                )
                # Still CREATE PENDING / IN_PROGRESS: loop back and poll again.
                .otherwise(wait_for_import)
            )
        )

        # Training branch: poll the solution version, then deploy the campaign.
        create_solution_task.next(wait_for_solution).next(
            check_solution_status_task
        ).next(
            is_solution_complete
            .when(
                sfn.Condition.string_equals("$.status", "ACTIVE"),
                create_campaign_task.next(success_state)
            )
            .when(
                sfn.Condition.string_equals("$.status", "CREATE FAILED"),
                fail_state
            )
            .otherwise(wait_for_solution)
        )

        # ===================
        # State Machine
        # ===================
        # Exposed as an attribute so consumers (alarms, outputs, schedules)
        # can reference the pipeline.
        self.state_machine = sfn.StateMachine(
            self, "PersonalizePipeline",
            state_machine_name="personalize-training-pipeline",
            definition=definition,
            timeout=Duration.hours(3),
            logs=sfn.LogOptions(
                destination=logs.LogGroup(self, "PipelineLogs"),
                level=sfn.LogLevel.ALL
            )
        )

    # ===================
    # Lambda Code (Inline)
    # ===================
    # Each helper returns Python source that is deployed verbatim via
    # Code.from_inline, so the strings must be valid, correctly indented
    # module code on their own.

    def _get_import_job_code(self):
        """Inline source for CreateImportJobFunction.

        The handler starts a dataset import job reading
        ``s3://$BUCKET_NAME/data/interactions.csv`` and returns its ARN.
        """
        return '''
import boto3
import os
import time

def handler(event, context):
    personalize = boto3.client("personalize")

    # Timestamp suffix keeps job names unique across pipeline runs.
    job_name = f"import-job-{int(time.time())}"

    response = personalize.create_dataset_import_job(
        jobName=job_name,
        datasetArn=os.environ["DATASET_ARN"],
        dataSource={
            "dataLocation": f"s3://{os.environ['BUCKET_NAME']}/data/interactions.csv"
        },
        roleArn=os.environ["ROLE_ARN"]
    )

    return {
        "importJobArn": response["datasetImportJobArn"],
        "status": "CREATE PENDING"
    }
'''

    def _get_check_import_status_code(self):
        """Inline source for CheckImportStatusFunction.

        The handler re-describes the import job named in the event and
        passes its ARN through alongside the current status.
        """
        return '''
import boto3

def handler(event, context):
    personalize = boto3.client("personalize")

    response = personalize.describe_dataset_import_job(
        datasetImportJobArn=event["importJobArn"]
    )

    return {
        "importJobArn": event["importJobArn"],
        "status": response["datasetImportJob"]["status"]
    }
'''

    def _get_create_solution_code(self):
        """Inline source for CreateSolutionFunction.

        The handler creates a solution with the user-personalization recipe
        and immediately starts a solution version (the training run).
        """
        return '''
import boto3
import os
import time

def handler(event, context):
    personalize = boto3.client("personalize")

    solution_name = f"solution-{int(time.time())}"

    # Create Solution
    solution_response = personalize.create_solution(
        name=solution_name,
        datasetGroupArn=os.environ["DATASET_GROUP_ARN"],
        recipeArn="arn:aws:personalize:::recipe/aws-user-personalization"
    )

    # Create Solution Version (starts training)
    version_response = personalize.create_solution_version(
        solutionArn=solution_response["solutionArn"]
    )

    return {
        "solutionArn": solution_response["solutionArn"],
        "solutionVersionArn": version_response["solutionVersionArn"],
        "status": "CREATE PENDING"
    }
'''

    def _get_check_solution_status_code(self):
        """Inline source for CheckSolutionStatusFunction.

        The handler reports the training status of the solution version,
        passing both ARNs through for the next state.
        """
        return '''
import boto3

def handler(event, context):
    personalize = boto3.client("personalize")

    response = personalize.describe_solution_version(
        solutionVersionArn=event["solutionVersionArn"]
    )

    return {
        "solutionArn": event["solutionArn"],
        "solutionVersionArn": event["solutionVersionArn"],
        "status": response["solutionVersion"]["status"]
    }
'''

    def _get_create_campaign_code(self):
        """Inline source for CreateCampaignFunction.

        The handler deploys the trained solution version behind a campaign
        with a minimum provisioned throughput of 1 TPS.
        """
        return '''
import boto3
import time

def handler(event, context):
    personalize = boto3.client("personalize")

    campaign_name = f"campaign-{int(time.time())}"

    response = personalize.create_campaign(
        name=campaign_name,
        solutionVersionArn=event["solutionVersionArn"],
        minProvisionedTPS=1
    )

    return {
        "campaignArn": response["campaignArn"],
        "solutionVersionArn": event["solutionVersionArn"],
        "status": "CREATED"
    }
'''