Skip to content

Commit

Permalink
[xgboost] Adds xgboost aarch64 support (#2659)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 17, 2023
1 parent 1257746 commit 8f841e6
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
79 changes: 79 additions & 0 deletions .github/workflows/native_s3_xgboost.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
name: Native S3 XGBoost

on:
workflow_dispatch:
xgb_version:
description: 'xgboost version'
required: false

jobs:
create-aarch64-runner:
if: github.repository == 'deepjavalibrary/djl'
runs-on: [ self-hosted, scheduler ]
steps:
- name: Create new Graviton instance
id: create_aarch64
run: |
cd /home/ubuntu/djl_benchmark_script/scripts
token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \
https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \
--fail \
| jq '.token' | tr -d '"' )
./start_instance.sh action_graviton $token djl
outputs:
aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }}

build-xgboost-jni-aarch64:
runs-on: [ self-hosted, aarch64 ]
container: centos:7
timeout-minutes: 30
needs: create-aarch64-runner
steps:
- uses: actions/checkout@v3
- name: Install Environment
run: |
yum -y update
yum -y install centos-release-scl-rh epel-release
yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel
curl -L -o cmake.tar.gz https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz
tar xvfz cmake.tar.gz
ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake
cmake --version
pip3 install awscli --upgrade
- name: Set up JDK 11
uses: actions/setup-java@v3
with:
distribution: 'corretto'
java-version: 11
- name: Release JNI prep
run: |
XGBOOST_VERSION=${{ github.event.inputs.xgb_version }}
XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')}
git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION"
export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin
cd xgboost/jvm-packages
python3 create_jni.py
cd ../..
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v1-node16
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: us-east-2
- name: Copy files to S3 with the AWS CLI
run: |
XGBOOST_VERSION=${{ github.event.inputs.xgb_version }}
XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')}
aws s3 cp xgboost/lib/libxgboost4j.so s3://djl-ai/publish/xgboost/${XGBOOST_VERSION}/jnilib/linux/aarch64/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/xgboost/${XGBOOST_VERSION}/jnilib*"
stop-runners:
if: ${{ github.repository == 'deepjavalibrary/djl' && always() }}
runs-on: [ self-hosted, scheduler ]
needs: [ create-aarch64-runner, build-xgboost-jni-aarch64 ]
steps:
- name: Stop all instances
run: |
cd /home/ubuntu/djl_benchmark_script/scripts
instance_id=${{ needs.create-aarch64-runner.outputs.aarch64_instance_id }}
./stop_instance.sh $instance_id
18 changes: 18 additions & 0 deletions engines/ml/xgboost/build.gradle
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import java.util.zip.GZIPInputStream

group "ai.djl.ml.xgboost"
boolean isGpu = project.hasProperty("gpu")
def XGB_FLAVOR = isGpu ? "-gpu" : ""
Expand Down Expand Up @@ -34,6 +36,22 @@ dependencies {
}
}

compileJava.dependsOn(processResources)

processResources {
def jnilibDir = "${project.buildDir}/classes/java/main/lib/linux/aarch64/"
outputs.dir file(jnilibDir)
doLast {
def url = "https://publish.djl.ai/xgboost/${xgboost_version}/jnilib/linux/aarch64/libxgboost4j.so"
def file = new File("${jnilibDir}/libxgboost4j.so")
if (!file.exists()) {
project.logger.lifecycle("Downloading ${url}")
def downloadPath = new URL(url)
downloadPath.withInputStream { i -> file.withOutputStream { it << i } }
}
}
}

jar {
from {
(configurations.compileClasspath - configurations.exclusion).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public class XgbModelTest {

@BeforeClass
public void downloadXGBoostModel() throws IOException {
TestRequirements.notArm();
TestRequirements.notWindows();

Path modelDir = Paths.get("build/model");
Expand Down

0 comments on commit 8f841e6

Please sign in to comment.