AWS Step FunctionsとLambdaでディープラーニングの訓練を全自動化する

動機とやったことの概要

詳細

Lambdaに付与する権限

たぶん以下くらいの権限がLambda実行時に必要。

AWSLambdaAMIExecutionRole
AmazonS3FullAccess
AmazonEC2SpotFleetRole
AWSLambdaBasicExecutionRole
AmazonSNSFullAccess
EC2ReadOnly ("ec2:DescribeSpotInstanceRequests"リソースへのアクセスを追加)

Step Functionsの入力

{
    "exec_name": "pix2pix-20161231",
    "repository_url": "https://github.com/mattya/chainer-pix2pix.git",
    "repository_name": "chainer-pix2pix",
    "data_dir": "/home/ubuntu/data",
    "output_dir": "/home/ubuntu/result",
    "data_get_command": "/home/ubuntu/.pyenv/shims/aws s3 cp s3://pix2pixfacade/ /home/ubuntu/data --recursive",
    "exec_command": "/home/ubuntu/.pyenv/shims/python /home/ubuntu/chainer-pix2pix/train_facade.py -g 0 -e 100 -i /home/ubuntu/data --out /home/ubuntu/result --snapshot_interval 10000"
}
変数名 説明
exec_name この実行の名前。バケット名にもなるため、アンダースコアを使わずkebab-case推奨
repository_url git cloneする対象のリポジトリURL
repository_name git cloneしたあと取得できるリポジトリ
data_dir データを格納するディレクト
output_dir 訓練結果等を格納するディレクト
data_get_command データを取得するなど、訓練開始前に実施する
exec_command 訓練実施コマンド

やってることの中身

Step Functionの定義

{
  "Comment" : "Machine learning execution with spot instance",
  "StartAt" : "CreateS3Bucket",
  "States"  : {
    "CreateS3Bucket": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:create_s3_bucket",
      "Next"      : "RequestSpotInstance"
    },
    "RequestSpotInstance": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:request_spot_instance",
      "Next"      : "WaitBidding"
    },
    "WaitBidding": {
      "Type"      : "Wait",
      "Seconds"   : 30,
      "Next"      : "CheckBiddingResult"
    },
    "CheckBiddingResult": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:check_bidding_result",
      "Next": "ChoiceBiddingResult"
    },
    "ChoiceBiddingResult": {
      "Type" : "Choice",
      "Choices": [
        {
          "Variable": "$.request_result",
          "BooleanEquals": true,
          "Next": "NotifyRequestSuccess"
        },
        {
          "Variable": "$.request_result",
          "BooleanEquals": false,
          "Next": "NotifyRequestFailed"
        }
      ],
      "Default": "NotifyRequestFailed"
    },
    "NotifyRequestFailed": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "SpotRequestFailed"
    },
    "SpotRequestFailed": {
          "Type": "Fail",
          "Error": "SpotRequestError",
          "Cause": "Spot price bidding too low"
    },
    "NotifyRequestSuccess": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "WaitTaskComplete"
    },
    "WaitTaskComplete": {
      "Type"      : "Wait",
      "Seconds"   : 10,
      "Next"      : "CheckTaskCompleted"
    },
    "CheckTaskCompleted": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:check_task_completed",
      "Next": "ChoiceTaskCompleted"
    },
    "ChoiceTaskCompleted": {
      "Type" : "Choice",
      "Choices": [
        {
          "Variable": "$.task_completed",
          "BooleanEquals": true,
          "Next": "NotifyTaskCompleted"
        },
        {
          "Variable": "$.task_completed",
          "BooleanEquals": false,
          "Next": "WaitTaskComplete"
        }
      ],
      "Default": "WaitTaskComplete"
    },
    "NotifyTaskCompleted":{
      "Type": "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "WaitInstanceDelete"
    },
    "WaitInstanceDelete": {
      "Type"      : "Wait",
      "Seconds"   : 1800,
      "Next"      : "DeleteSpotInstance"
    },
    "DeleteSpotInstance": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:delete_ec2_instance",
      "End": true
    }
  }
}
  • 判断分岐以外は直列に流してるだけ。
  • 処理途中に生成されるID類はeventに追加しながら下流に流す
  • S3作成とスポットインスタンスリクエストはParallelにしても良いかも(面倒くさいのでやってない..)
  • 訓練完了から30分は削除せずに待つ。サーバに未練があればこの間に実行を停止する。

S3バケット作成

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import os

def lookup(s3, bucket_name):
  try:
    s3.meta.client.head_bucket(Bucket=bucket_name)
  except botocore.exceptions.ClientError as e:
    error_code = int(e.response['Error']['Code'])

    if error_code == 404:
      return False

    return True

def create_bucket(bucket_name):
    s3 = boto3.resource('s3')
    response = ''
    if not lookup(s3, bucket_name):
       response = s3.create_bucket(Bucket=bucket_name)

    return response

def lambda_handler(event, context):
    response = create_bucket(event['exec_name'])
    return event
  • eventからexec_nameを取り出してバケット名に
  • その名前のバケットがなければ作る

スポットインスタンスのリクエスト

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import logging
import base64
import os

SPOT_PRICE = '0.8'
REGION = 'ap-northeast-1'
AMI_ID = 'ami-9999999f'
KEY_NAME = 'your_keyname'
INSTANCE_TYPE = 'g2.2xlarge'
SECURITY_GRUOP_ID = ['sg-9999999']

def request_spot_instance(user_data):
    ec2_client = boto3.client('ec2',
        region_name = REGION
    )
    response = ec2_client.request_spot_instances(
        SpotPrice = SPOT_PRICE,
        Type = 'one-time',
        LaunchSpecification = {
            'ImageId': AMI_ID,
            'KeyName': KEY_NAME,
            'InstanceType': INSTANCE_TYPE,
            'UserData': user_data,
            'Placement':{},
            'SecurityGroupIds': SECURITY_GRUOP_ID
        }
    )
    return response

def lambda_handler(event, context):
    REPOSITORY_URL  = event["repository_url"]
    REPOSITORY_NAME = event["repository_name"]
    BUCKET_NAME = event["exec_name"]

    shell='''#!/bin/sh
    sudo -s ubuntu
    cd /home/ubuntu
    sudo -u ubuntu mkdir /home/ubuntu/.aws
    sudo -u ubuntu mkdir /home/ubuntu/completed
    sudo -u ubuntu git clone {5}
    sudo -u ubuntu mkdir {0}
    sudo -u ubuntu mkdir {1}

    sudo -u ubuntu echo "[default]" >> /home/ubuntu/.aws/credentials
    sudo -u ubuntu echo "aws_access_key_id={2}" >> /home/ubuntu/.aws/credentials
    sudo -u ubuntu echo "aws_secret_access_key={3}" >> /home/ubuntu/.aws/credentials

    sudo -u ubuntu echo "*/5 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync {1} s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp {1}/log s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp /home/ubuntu/trace.log s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync /home/ubuntu/completed s3://{4} > /dev/null 2>&1" >> mycron

    sudo -u ubuntu /usr/bin/crontab mycron
    sudo -u ubuntu /bin/rm /home/ubuntu/mycron

    PATH="/usr/local/cuda/bin:$PATH"
    LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"

    sudo -u ubuntu cd /home/ubuntu/{6}

    sudo -u ubuntu touch trace.log
    sudo -u ubuntu echo `pwd` >> trace.log  2>&1
    sudo -u ubuntu echo `which python` >> trace.log  2>&1
    sudo -u ubuntu echo 'repository_name: {6}' >> trace.log 2>&1
    sudo -u ubuntu echo 'dataget_command: {7}' >> trace.log 2>&1
    sudo -u ubuntu echo 'exec_command: {8}' >> trace.log 2>&1
    sudo -u ubuntu {7}  > /dev/null 2>> trace.log
    sudo -u ubuntu echo `ls /home/ubuntu/data | wc` >> trace.log

    PATH="/usr/local/cuda/bin:$PATH"
    LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
    sudo -u ubuntu -i {8}  >> trace.log 2>&1
    sudo -u ubuntu touch /home/ubuntu/completed/completed.log
    '''

    shell_code = shell.format(
        event["data_dir"],
        event["output_dir"],
        os.environ.get('S3_ACCESS_KEY_ID'),
        os.environ.get('S3_SECRET_ACCESS_KEY'),
        event["exec_name"],
        event["repository_url"],
        event["repository_name"],
        event["data_get_command"],
        event["exec_command"]
        )
    user_data = base64.encodestring(shell_code.encode('utf-8')).decode('ascii')
    response = request_spot_instance(user_data)
    event["spot_instance_request_id"] = response["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
    return event
  • インスタンスタイプや入札価格は定数にして、StepFunction実行時の入力(event)からは引かないようにしている(eventはコードの実行条件のみにし、環境調達条件はLambda側に持たせるポリシーのつもり)
  • AMIは、chainer、CUDA等はインストール完了いているものがある前提
  • インスタンスをリクエストしたあとuser_dataをシェルスクリプトにして流し込んでる
  • 大体の汚い処理はここのシェルスクリプトに凝縮されている
    • S3へのupload系タスクはcronに登録
    • その後、パスを通して訓練の開始
  • S3_ACCESS_KEY_ID / S3_SECRET_ACCESS_KEYはIAMのwrite権限のある鍵をLambda Functionの環境変数に登録しておく。
  • 実行時のログはtrace.logに出力 > これもS3に随時Up
  • 実行完了後に、completed.logを出力。これがS3のバケットに入ると、StepFunctions側でタスク完了とみなされる

入札結果確認

def check_bidding_result(spot_instance_request_id):
    ec2_client = boto3.client('ec2',
        region_name = REGION
    )
    response = ec2_client.describe_spot_instance_requests(
      SpotInstanceRequestIds = [spot_instance_request_id]
    )
    return response

def lambda_handler(event, context):
    response = check_bidding_result(event["spot_instance_request_id"])
    event["request_result"] = (response['SpotInstanceRequests'][0]['Status']['Code']==u'fulfilled')

    if event["request_result"]:
        event["instance_id"] = response['SpotInstanceRequests'][0]['InstanceId']

    return event
  • スポットインスタンスリクエスト時に取得した'SpotInstanceRequests'から、入札の結果を確認する

通知

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import os

TOPIC_ARN = 'arn:aws:sns:ap-northeast-1:9999999999:training_end_notification_mail' # Mail
REGION = 'ap-northeast-1'


def send_sms_message(event, context):
    sns = boto3.client('sns',
        region_name = REGION
    )

    message = ''
    subject = ''
    if "completed" in event:
        subject = 'Training ended'
        message = '''task completed!
        result: https://console.aws.amazon.com/s3/home?bucket={0}

        -----
        {1}
        '''.format(event["exec_name"], event)
    else:
        if event["request_result"]:
            subject = 'request fulfilled'
            message = '''
            Spot Request Fulfilled! {0}
            '''.format(event["exec_name"])
        else:
            subject = 'request failed'
            message = '''
            Spot Request Fails! {0}
            '''.format(event["exec_name"])

        response = sns.publish(
            TopicArn = TOPIC_ARN,
            Subject = subject,
            Message = message
        )

    return response

def lambda_handler(event, context):
    response = send_sms_message(event, context)
    return event
  • 通知の宛先、通知手段は、事前にSNS側に登録し、Topic ARNを発行しておく
  • 作成されたインスタンスのIDはeventに追加して下流に流す

その他のLambda

  • あとは特別なことはしていないリポジトリをご参照ください

リポジトリ

github.com

改良案とか、◯◯をXXでやらないなんて有りえない!とかあればお気軽に @mizti までコメントください (AWS今までちゃんと触ってこなかった勢なので話せれば嬉しいです)

f:id:mizti:20170101205025p:plain