AWS Step FunctionsとLambdaでディープラーニングの訓練を全自動化する
動機とやったことの概要
- スポットインスタンスで学習をして、無駄なくインスタンスを止めたい
- Step Functionsへの入力を変えるだけで様々な条件での学習を実行させたい
- 機械学習のコード自体にこのStep Functionsへの依存性は持たせなくて良い方針で作ったので、ディープラーニング以外のバッチ処理でも同じように使えるはず
詳細
Lambdaに付与する権限
たぶん以下くらいの権限がLambda実行時に必要。
AWSLambdaAMIExecutionRole AmazonS3FullAccess AmazonEC2SpotFleetRole AWSLambdaBasicExecutionRole AmazonSNSFullAccess EC2ReadOnly ("ec2:DescribeSpotInstanceRequests"リソースへのアクセスを追加)
Step Functionsの入力
{ "exec_name": "pix2pix-20161231", "repository_url": "https://github.com/mattya/chainer-pix2pix.git", "repository_name": "chainer-pix2pix", "data_dir": "/home/ubuntu/data", "output_dir": "/home/ubuntu/result", "data_get_command": "/home/ubuntu/.pyenv/shims/aws s3 cp s3://pix2pixfacade/ /home/ubuntu/data --recursive", "exec_command": "/home/ubuntu/.pyenv/shims/python /home/ubuntu/chainer-pix2pix/train_facade.py -g 0 -e 100 -i /home/ubuntu/data --out /home/ubuntu/result --snapshot_interval 10000" }
変数名 | 説明 |
---|---|
exec_name | この実行の名前。バケット名にもなるため、アンダースコアを使わずkebab-case推奨 |
repository_url | git cloneする対象のリポジトリURL |
repository_name | git cloneしたあと取得できるリポジトリ名 |
data_dir | データを格納するディレクトリ |
output_dir | 訓練結果等を格納するディレクトリ |
data_get_command | データを取得するなど、訓練開始前に実施する |
exec_command | 訓練実施コマンド |
やってることの中身
Step Functionの定義
{ "Comment" : "Machine learning execution with spot instance", "StartAt" : "CreateS3Bucket", "States" : { "CreateS3Bucket": { "Type" : "Task", "Resource" : "arn:aws:lambda:ap-northeast-1:999999999999:function:create_s3_bucket", "Next" : "RequestSpotInstance" }, "RequestSpotInstance": { "Type" : "Task", "Resource" : "arn:aws:lambda:ap-northeast-1:999999999999:function:request_spot_instance", "Next" : "WaitBidding" }, "WaitBidding": { "Type" : "Wait", "Seconds" : 30, "Next" : "CheckBiddingResult" }, "CheckBiddingResult": { "Type" : "Task", "Resource" : "arn:aws:lambda:ap-northeast-1:999999999999:function:check_bidding_result", "Next": "ChoiceBiddingResult" }, "ChoiceBiddingResult": { "Type" : "Choice", "Choices": [ { "Variable": "$.request_result", "BooleanEquals": true, "Next": "NotifyRequestSuccess" }, { "Variable": "$.request_result", "BooleanEquals": false, "Next": "NotifyRequestFailed" } ], "Default": "NotifyRequestFailed" }, "NotifyRequestFailed": { "Type" : "Task", "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message", "Next": "SpotRequestFailed" }, "SpotRequestFailed": { "Type": "Fail", "Error": "SpotRequestError", "Cause": "Spot price bidding too low" }, "NotifyRequestSuccess": { "Type" : "Task", "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message", "Next": "WaitTaskComplete" }, "WaitTaskComplete": { "Type" : "Wait", "Seconds" : 10, "Next" : "CheckTaskCompleted" }, "CheckTaskCompleted": { "Type" : "Task", "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:check_task_completed", "Next": "ChoiceTaskCompleted" }, "ChoiceTaskCompleted": { "Type" : "Choice", "Choices": [ { "Variable": "$.task_completed", "BooleanEquals": true, "Next": "NotifyTaskCompleted" }, { "Variable": "$.task_completed", "BooleanEquals": false, "Next": "WaitTaskComplete" } ], "Default": "WaitTaskComplete" }, "NotifyTaskCompleted":{ "Type": "Task", "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message", "Next": "WaitInstanceDelete" }, "WaitInstanceDelete": { "Type" : "Wait", "Seconds" : 1800, "Next" : "DeleteSpotInstance" }, "DeleteSpotInstance": { "Type": "Task", "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:delete_ec2_instance", "End": true } } }
- 判断分岐以外は直列に流してるだけ。
- 処理途中に生成されるID類はeventに追加しながら下流に流す
- S3作成とスポットインスタンスリクエストはParallelにしても良いかも(面倒くさいのでやってない..)
- 訓練完了から30分は削除せずに待つ。サーバに未練があればこの間に実行を停止する。
S3バケット作成
#!/usr/bin/env python # -*- coding: utf-8 -*- import boto3 import json import os def lookup(s3, bucket_name): try: s3.meta.client.head_bucket(Bucket=bucket_name) except botocore.exceptions.ClientError as e: error_code = int(e.response['Error']['Code']) if error_code == 404: return False return True def create_bucket(bucket_name): s3 = boto3.resource('s3') response = '' if not lookup(s3, bucket_name): response = s3.create_bucket(Bucket=bucket_name) return response def lambda_handler(event, context): response = create_bucket(event['exec_name']) return event
- eventからexec_nameを取り出してバケット名に
- その名前のバケットがなければ作る
スポットインスタンスのリクエスト
#!/usr/bin/env python # -*- coding: utf-8 -*- import boto3 import json import logging import base64 import os SPOT_PRICE = '0.8' REGION = 'ap-northeast-1' AMI_ID = 'ami-9999999f' KEY_NAME = 'your_keyname' INSTANCE_TYPE = 'g2.2xlarge' SECURITY_GRUOP_ID = ['sg-9999999'] def request_spot_instance(user_data): ec2_client = boto3.client('ec2', region_name = REGION ) response = ec2_client.request_spot_instances( SpotPrice = SPOT_PRICE, Type = 'one-time', LaunchSpecification = { 'ImageId': AMI_ID, 'KeyName': KEY_NAME, 'InstanceType': INSTANCE_TYPE, 'UserData': user_data, 'Placement':{}, 'SecurityGroupIds': SECURITY_GRUOP_ID } ) return response def lambda_handler(event, context): REPOSITORY_URL = event["repository_url"] REPOSITORY_NAME = event["repository_name"] BUCKET_NAME = event["exec_name"] shell='''#!/bin/sh sudo -s ubuntu cd /home/ubuntu sudo -u ubuntu mkdir /home/ubuntu/.aws sudo -u ubuntu mkdir /home/ubuntu/completed sudo -u ubuntu git clone {5} sudo -u ubuntu mkdir {0} sudo -u ubuntu mkdir {1} sudo -u ubuntu echo "[default]" >> /home/ubuntu/.aws/credentials sudo -u ubuntu echo "aws_access_key_id={2}" >> /home/ubuntu/.aws/credentials sudo -u ubuntu echo "aws_secret_access_key={3}" >> /home/ubuntu/.aws/credentials sudo -u ubuntu echo "*/5 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync {1} s3://{4} > /dev/null 2>&1" >> mycron sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp {1}/log s3://{4} > /dev/null 2>&1" >> mycron sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp /home/ubuntu/trace.log s3://{4} > /dev/null 2>&1" >> mycron sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync /home/ubuntu/completed s3://{4} > /dev/null 2>&1" >> mycron sudo -u ubuntu /usr/bin/crontab mycron sudo -u ubuntu /bin/rm /home/ubuntu/mycron PATH="/usr/local/cuda/bin:$PATH" LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH" sudo -u ubuntu cd /home/ubuntu/{6} sudo -u ubuntu touch trace.log sudo -u ubuntu echo `pwd` >> trace.log 2>&1 sudo -u ubuntu echo `which python` >> trace.log 2>&1 sudo -u ubuntu echo 'repository_name: {6}' >> trace.log 2>&1 sudo -u ubuntu echo 'dataget_command: {7}' >> trace.log 2>&1 sudo -u ubuntu echo 'exec_command: {8}' >> trace.log 2>&1 sudo -u ubuntu {7} > /dev/null 2>> trace.log sudo -u ubuntu echo `ls /home/ubuntu/data | wc` >> trace.log PATH="/usr/local/cuda/bin:$PATH" LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH" sudo -u ubuntu -i {8} >> trace.log 2>&1 sudo -u ubuntu touch /home/ubuntu/completed/completed.log ''' shell_code = shell.format( event["data_dir"], event["output_dir"], os.environ.get('S3_ACCESS_KEY_ID'), os.environ.get('S3_SECRET_ACCESS_KEY'), event["exec_name"], event["repository_url"], event["repository_name"], event["data_get_command"], event["exec_command"] ) user_data = base64.encodestring(shell_code.encode('utf-8')).decode('ascii') response = request_spot_instance(user_data) event["spot_instance_request_id"] = response["SpotInstanceRequests"][0]["SpotInstanceRequestId"] return event
- インスタンスタイプや入札価格は定数にして、StepFunction実行時の入力(event)からは引かないようにしている(eventはコードの実行条件のみにし、環境調達条件はLambda側に持たせるポリシーのつもり)
- AMIは、chainer、CUDA等はインストール完了いているものがある前提
- インスタンスをリクエストしたあとuser_dataをシェルスクリプトにして流し込んでる
- 大体の汚い処理はここのシェルスクリプトに凝縮されている
- S3へのupload系タスクはcronに登録
- その後、パスを通して訓練の開始
- S3_ACCESS_KEY_ID / S3_SECRET_ACCESS_KEYはIAMのwrite権限のある鍵をLambda Functionの環境変数に登録しておく。
- 実行時のログはtrace.logに出力 > これもS3に随時Up
- 実行完了後に、completed.logを出力。これがS3のバケットに入ると、StepFunctions側でタスク完了とみなされる
入札結果確認
def check_bidding_result(spot_instance_request_id): ec2_client = boto3.client('ec2', region_name = REGION ) response = ec2_client.describe_spot_instance_requests( SpotInstanceRequestIds = [spot_instance_request_id] ) return response def lambda_handler(event, context): response = check_bidding_result(event["spot_instance_request_id"]) event["request_result"] = (response['SpotInstanceRequests'][0]['Status']['Code']==u'fulfilled') if event["request_result"]: event["instance_id"] = response['SpotInstanceRequests'][0]['InstanceId'] return event
- スポットインスタンスリクエスト時に取得した'SpotInstanceRequests'から、入札の結果を確認する
通知
#!/usr/bin/env python # -*- coding: utf-8 -*- import boto3 import json import os TOPIC_ARN = 'arn:aws:sns:ap-northeast-1:9999999999:training_end_notification_mail' # Mail REGION = 'ap-northeast-1' def send_sms_message(event, context): sns = boto3.client('sns', region_name = REGION ) message = '' subject = '' if "completed" in event: subject = 'Training ended' message = '''task completed! result: https://console.aws.amazon.com/s3/home?bucket={0} ----- {1} '''.format(event["exec_name"], event) else: if event["request_result"]: subject = 'request fulfilled' message = ''' Spot Request Fulfilled! {0} '''.format(event["exec_name"]) else: subject = 'request failed' message = ''' Spot Request Fails! {0} '''.format(event["exec_name"]) response = sns.publish( TopicArn = TOPIC_ARN, Subject = subject, Message = message ) return response def lambda_handler(event, context): response = send_sms_message(event, context) return event
その他のLambda
- あとは特別なことはしていないリポジトリをご参照ください
リポジトリ
改良案とか、◯◯をXXでやらないなんて有りえない!とかあればお気軽に @mizti までコメントください (AWS今までちゃんと触ってこなかった勢なので話せれば嬉しいです)