[GSoC] Project 5: Integrate JAX with Kubeflow Training Operator to Support JAX Distributed Processes #2145

sandipanpanda opened this issue Jun 13, 2024 · 2 comments



sandipanpanda commented Jun 13, 2024

Introduction

During the Google Summer of Code (GSoC) 2024, I had the incredible opportunity to contribute to the Kubeflow open-source project by working on the integration of JAX with the Kubeflow Training Operator. The goal of this project was to provide a seamless and efficient way to run distributed computations on CPU using the JAX framework on Kubernetes. Throughout the summer, I worked to build out this feature by extending the Training Operator.

Project Overview

JAX, a powerful ML framework developed by Google, is highly valued for its flexibility and performance in large-scale distributed computations, especially with its native support for automatic differentiation and hardware accelerators like GPUs and TPUs. The Kubeflow Training Operator is a popular Kubernetes component that allows users to run distributed ML training jobs across various frameworks (such as TensorFlow, PyTorch, and XGBoost). However, until now, it lacked direct support for JAX.
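
Concretely, a distributed JAX job consists of N processes that each call `jax.distributed.initialize()` with a shared coordinator address, the total number of processes, and their own process id. As a minimal sketch, here is how a worker pod might assemble those arguments from environment variables injected by a JAXJob-style controller (the variable names here are illustrative, not the operator's actual contract):

```python
import os

def jax_init_kwargs(env=os.environ):
    """Build keyword arguments for jax.distributed.initialize() from
    environment variables a JAXJob-style controller could inject into
    each worker pod. The variable names are illustrative only."""
    return {
        "coordinator_address": f'{env["COORDINATOR_ADDRESS"]}:{env["COORDINATOR_PORT"]}',
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }

# Values the controller might set on worker 0 of a 2-worker job:
example_env = {
    "COORDINATOR_ADDRESS": "jaxjob-demo-worker-0",
    "COORDINATOR_PORT": "6666",
    "NUM_PROCESSES": "2",
    "PROCESS_ID": "0",
}
print(jax_init_kwargs(example_env))
```

The operator's job, then, is to render this coordination information consistently into every worker pod so that all processes rendezvous at the same coordinator.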

Objectives

  • Create a Custom Resource for JAX (JAXJob):
    We needed to introduce a new Kubernetes Custom Resource Definition (CRD) for JAX, called JAXJob, allowing users to define distributed JAX training jobs on Kubernetes clusters. This was crucial for integrating JAX into the Training Operator.

  • Update the Training Operator Controller:
    The Training Operator controller had to be updated to support the new JAXJob resource, handling the creation, scheduling, and management of distributed JAX training jobs on Kubernetes.

  • Enhance the Training Operator Python SDK:
    We aimed to extend the Training Operator Python SDK with easy-to-use APIs that let data scientists and ML practitioners define and launch JAXJobs on Kubernetes, simplifying the process of running distributed JAX jobs.
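
To make the first objective concrete, here is a sketch of what a JAXJob manifest might look like. The field names follow the pattern of the other Training Operator CRDs (such as PyTorchJob); treat the exact schema and the image name as illustrative rather than the final API:

```yaml
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  name: jaxjob-simple
  namespace: kubeflow
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax
              # Illustrative image; substitute your own training image.
              image: docker.io/kubeflow/jaxjob-simple:latest
```

Applying a manifest like this would ask the Training Operator to create and manage the two worker pods as a single distributed JAX job.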

Key Contributions

Progress and Achievements

By the end of the project, the following milestones were successfully achieved:

Future Work

  • Extending support for hardware accelerators such as GPUs and Cloud TPUs would benefit the community, as would additional documentation and tutorials to onboard new users to the JAX + Kubeflow ecosystem.
  • Adding support for distributed JAX training in the Training V2 API.

Lessons Learned

Throughout this project, I gained valuable insights into distributed systems, Kubernetes resource management, and the inner workings of machine learning frameworks like JAX. Some key takeaways include:

  • Kubernetes Deep Dive: I deepened my understanding of Kubernetes, particularly Custom Resource Definitions (CRDs) and controllers, which are the backbone of extending Kubernetes functionality.
  • Collaboration in Open Source: Working in a collaborative environment with experienced mentors was one of the highlights of this project. Their feedback and guidance helped me improve not only my technical skills but also my ability to communicate and collaborate effectively.
  • Distributed Training at Scale: This project gave me a deeper appreciation for the complexities of distributed training and the importance of tools like Kubernetes in managing large-scale machine learning workloads.

Conclusion

Integrating JAX with the Kubeflow Training Operator has been a challenging but rewarding experience. The project successfully enables distributed training for JAX on Kubernetes, providing an easy-to-use interface for data scientists and machine learning engineers.

I am grateful to my mentors — @tenzen-y, @andreyvelich, @terrytangyuan, and @shravan-achar — for their support and guidance throughout the summer.

I look forward to seeing how this feature evolves and benefits the Kubeflow community in the future.

@sandipanpanda sandipanpanda changed the title Tracking Issue: Integrate JAX in Kubeflow Training Operator [GSOC] Tracking Issue: Integrate JAX in Kubeflow Training Operator Jul 10, 2024
@sandipanpanda sandipanpanda changed the title [GSOC] Tracking Issue: Integrate JAX in Kubeflow Training Operator [GSoC] Project 5: Integrate JAX with Kubeflow Training Operator to Support JAX Distributed Processes Sep 19, 2024
@andreyvelich

@sandipanpanda Are we ready to close it?
Adding the blog post about this work to the Kubeflow website would also be nice: https://github.com/kubeflow/blog/pulls

@sandipanpanda

Yes, I am getting the draft ready and will open a PR to add the blog post about this work shortly.
