Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip - Carry over task args in flytekit plugins decorators #2911

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
73b2b99
wip
fiedlerNr9 Oct 24, 2024
c9e2fea
wip
fiedlerNr9 Oct 24, 2024
569b83e
wip
fiedlerNr9 Oct 25, 2024
d7ccfd4
wip
fiedlerNr9 Oct 25, 2024
62af44b
wip
fiedlerNr9 Oct 25, 2024
59123ce
wip
fiedlerNr9 Oct 25, 2024
6537508
wip
fiedlerNr9 Oct 25, 2024
24061ed
wip
fiedlerNr9 Oct 25, 2024
6d08f10
wip
fiedlerNr9 Oct 25, 2024
7ba957c
wip
fiedlerNr9 Oct 25, 2024
c0e11e9
rename memray_profiling
fiedlerNr9 Oct 29, 2024
ce3ac68
finish readme
fiedlerNr9 Oct 29, 2024
a954b55
adjust memray_reporter_args type
fiedlerNr9 Oct 29, 2024
ea4d050
ruff check --fix
fiedlerNr9 Oct 29, 2024
ac27f86
ruff format
fiedlerNr9 Oct 29, 2024
681d75d
codespell
fiedlerNr9 Oct 29, 2024
ecb23d8
add flytekit-memray to pythonbuild workflows
fiedlerNr9 Nov 4, 2024
ae85925
allow memray.Tracker arguments in profiling
fiedlerNr9 Nov 4, 2024
a15a956
extend memray_profiling args description
fiedlerNr9 Nov 4, 2024
e05586b
spelling
fiedlerNr9 Nov 4, 2024
d2645b6
move tests
fiedlerNr9 Nov 4, 2024
1e3b8c5
move tests again :clown_face:
fiedlerNr9 Nov 4, 2024
9af3e8b
adjust README.md to not use PYMALLOC env variable
fiedlerNr9 Nov 4, 2024
a9efffd
Update plugins/flytekit-memray/flytekitplugins/memray/profiling.py
fiedlerNr9 Nov 6, 2024
0d951fc
Update plugins/flytekit-memray/flytekitplugins/memray/profiling.py
fiedlerNr9 Nov 6, 2024
a5007f3
Update plugins/flytekit-memray/flytekitplugins/memray/profiling.py
fiedlerNr9 Nov 6, 2024
3541f5e
add import sys
fiedlerNr9 Nov 6, 2024
fe64f84
wip - Carry over task args in flytekit plugins decorators
eapolinario Nov 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ jobs:
- flytekit-kf-mpi
- flytekit-kf-pytorch
- flytekit-kf-tensorflow
- flytekit-memray
- flytekit-mlflow
- flytekit-mmcloud
- flytekit-modin
Expand Down
6 changes: 5 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@

decorated_fn = decorate_function(fn)

carried_task_args = {}
if hasattr(fn, "_fk_task_args"):
carried_task_args = fn._fk_task_args

Check warning on line 361 in flytekit/core/task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/task.py#L361

Added line #L361 was not covered by tests

task_instance = TaskPlugins.find_pythontask_plugin(type(task_config))(
task_config,
decorated_fn,
Expand All @@ -369,7 +373,7 @@
node_dependency_hints=node_dependency_hints,
task_resolver=task_resolver,
disable_deck=disable_deck,
enable_deck=enable_deck,
enable_deck=enable_deck if enable_deck is not None else carried_task_args.get("enable_deck"),
deck_fields=deck_fields,
docs=docs,
pod_template=pod_template,
Expand Down
54 changes: 54 additions & 0 deletions plugins/flytekit-memray/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Memray Profiling Plugin

Memray tracks and reports memory allocations, both in python code and in compiled extension modules.
This Memray Profiling plugin enables memory tracking on the Flyte task level and renders a memgraph profiling graph on Flyte Deck.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-memray
```

Example
```python
from flytekit import workflow, task, ImageSpec
from flytekitplugins.memray import memray_profiling
import time


image = ImageSpec(
name="memray_demo",
packages=["flytekitplugins_memray"],
registry="<your_cr_registry>",
)


def generate_data(n: int):
leak_list = []
for _ in range(n): # Arbitrary large number for demonstration
large_data = " " * 10**6 # 1 MB string
leak_list.append(large_data) # Keeps appending without releasing
time.sleep(0.1) # Slow down the loop to observe memory changes


@task(container_image=image, enable_deck=True)
@memray_profiling(memray_html_reporter="table")
def memory_usage(n: int) -> str:
generate_data(n=n)

return "Well"


@task(container_image=image, enable_deck=True)
@memray_profiling(trace_python_allocators=True, memray_reporter_args=["--leaks"])
def memory_leakage(n: int) -> str:
generate_data(n=n)

return "Well"


@workflow
def wf(n: int = 500):
memory_usage(n=n)
memory_leakage(n=n)
```
15 changes: 15 additions & 0 deletions plugins/flytekit-memray/flytekitplugins/memray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
.. currentmodule:: flytekitplugins.wandb

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

wandb_init
"""

from .profiling import memray_profiling

__all__ = ["memray_profiling"]
112 changes: 112 additions & 0 deletions plugins/flytekit-memray/flytekitplugins/memray/profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import sys
import time
from typing import Callable, List, Optional

import memray
from flytekit import Deck
from flytekit.core.utils import ClassDecorator


class memray_profiling(ClassDecorator):
def __init__(
self,
task_function: Optional[Callable] = None,
native_traces: bool = False,
trace_python_allocators: bool = False,
follow_fork: bool = False,
memory_interval_ms: int = 10,
memray_html_reporter: str = "flamegraph",
memray_reporter_args: Optional[List[str]] = None,
):
"""Memray profiling plugin.
Args:
task_function (function, optional): The user function to be decorated. Defaults to None.
native_traces (bool): Whether or not to capture native stack frames, in addition to Python stack frames (see [Native tracking](https://bloomberg.github.io/memray/run.html#native-tracking))
trace_python_allocators (bool): Whether or not to trace Python allocators as independent allocations. (see [Python allocators](https://bloomberg.github.io/memray/python_allocators.html#python-allocators))
follow_fork (bool): Whether or not to continue tracking in a subprocess that is forked from the tracked process (see [Tracking across forks](https://bloomberg.github.io/memray/run.html#tracking-across-forks))
memory_interval_ms (int): How many milliseconds to wait between sending periodic resident set size updates.
By default, every 10 milliseconds a record is written that contains the current timestamp and the total number of bytes of virtual memory allocated by the process.
These records are used to create the graph of memory usage over time that appears at the top of the flame graph, for instance.
This parameter lets you adjust the frequency between updates, though you shouldn't need to change it.
memray_html_reporter (str): The name of the memray reporter which generates an html report.
Today there is only 'flamegraph' & 'table'.
memray_reporter_args (List[str], optional): A list of arguments to pass to the reporter commands.
See the [flamegraph](https://bloomberg.github.io/memray/flamegraph.html#reference)
and [table](https://bloomberg.github.io/memray/table.html#cli-reference) docs for details on supported arguments.
"""

if memray_html_reporter not in ["flamegraph", "table"]:
raise ValueError(f"{memray_html_reporter} is not a supported html reporter.")

if memray_reporter_args is not None and not all(
isinstance(arg, str) and "--" in arg for arg in memray_reporter_args
):
raise ValueError(
f"unrecognized arguments for {memray_html_reporter} reporter. Please check https://bloomberg.github.io/memray/{memray_html_reporter}.html"
)

carried_task_args = {}
if hasattr(task_function, "_fk_task_args"):
carried_task_args = task_function._fk_task_args
self._fk_task_args = {**carried_task_args, **{"enable_deck": True}}
self.native_traces = native_traces
self.trace_python_allocators = trace_python_allocators
self.follow_fork = follow_fork
self.memory_interval_ms = memory_interval_ms
self.dir_name = "memray_bin"
self.memray_html_reporter = memray_html_reporter
self.memray_reporter_args = memray_reporter_args if memray_reporter_args else []

super().__init__(
task_function,
native_traces=native_traces,
trace_python_allocators=trace_python_allocators,
follow_fork=follow_fork,
memory_interval_ms=memory_interval_ms,
memray_html_reporter=memray_html_reporter,
memray_reporter_args=memray_reporter_args,
)

def execute(self, *args, **kwargs):
if not os.path.exists(self.dir_name):
os.makedirs(self.dir_name)

bin_filepath = os.path.join(
self.dir_name,
f"{self.task_function.__name__}.{time.strftime('%Y%m%d%H%M%S')}.bin",
)

with memray.Tracker(
bin_filepath,
native_traces=self.native_traces,
trace_python_allocators=self.trace_python_allocators,
follow_fork=self.follow_fork,
memory_interval_ms=self.memory_interval_ms,
):
output = self.task_function(*args, **kwargs)

self.generate_flytedeck_html(reporter=self.memray_html_reporter, bin_filepath=bin_filepath)

return output

def generate_flytedeck_html(self, reporter, bin_filepath):
html_filepath = bin_filepath.replace(
self.task_function.__name__, f"{reporter}.{self.task_function.__name__}"
).replace(".bin", ".html")

memray_reporter_args_str = " ".join(self.memray_reporter_args)

if (
os.system(
f"{sys.executable} -m memray {reporter} -o {html_filepath} {memray_reporter_args_str} {bin_filepath}"
)
== 0
):
with open(html_filepath, "r", encoding="utf-8") as file:
html_content = file.read()

Deck(f"Memray {reporter.capitalize()}", html_content)

def get_extra_config(self):
return {}
37 changes: 37 additions & 0 deletions plugins/flytekit-memray/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from setuptools import setup

PLUGIN_NAME = "memray"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.12.0", "memray"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package enables memory profiling for tasks with memray",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.8",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
40 changes: 40 additions & 0 deletions plugins/flytekit-memray/tests/test_memray_profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from unittest.mock import Mock, patch
import pytest
from flytekit import task, current_context
from flytekitplugins.memray import memray_profiling


# Notice how we no longer set enable_deck=True
@task
@memray_profiling
def heavy_compute(i: int) -> int:
return i + 1


def test_local_exec():
heavy_compute(i=7)
assert (
len(current_context().decks) == 6
) # memray flamegraph, timeline, input, and output, source code, dependencies


def test_errors():
reporter = "summary"
with pytest.raises(
ValueError, match=f"{reporter} is not a supported html reporter."
):
memray_profiling(memray_html_reporter=reporter)

reporter = "flamegraph"
with pytest.raises(
ValueError,
match=f"unrecognized arguments for {reporter} reporter. Please check https://bloomberg.github.io/memray/{reporter}.html",
):
memray_profiling(memray_reporter_args=["--leaks", "trash"])

reporter = "flamegraph"
with pytest.raises(
ValueError,
match=f"unrecognized arguments for {reporter} reporter. Please check https://bloomberg.github.io/memray/{reporter}.html",
):
memray_profiling(memray_reporter_args=[0, 1, 2])
Loading