Traces of thoughts

View on GitHub

Metaflow Deep Dive (2) - Runtime

img.png

Previous Post(s):

Metaflow Deep Dive (1) - Static Analysis

In my last post, I covered Metaflow’s static graph analysis. Now let’s move forward with the workflow runtime.

Core of the Workflow Runtime

Metaflow implements its own process worker pool based on the *NIX polling mechanism. This documentation explains why Metaflow made such a choice.

Revisit the Entry Point

As previously mentioned, FlowSpec constructor starts the workflow with cli.main(self), which is defined below:

# metaflow/cli.py

def main(flow, args=None, handle_exceptions=True, entrypoint=None):
    """Top-level CLI entry point: hand the flow instance to click's `start`
    command and translate exceptions into exit codes.

    Parameters
    ----------
    flow : FlowSpec
        The user's workflow instance.
    args : list, optional
        Explicit CLI args; when None, click parses sys.argv itself.
    handle_exceptions : bool
        When True, pretty-print errors and sys.exit(1) instead of raising.
    entrypoint : list, optional
        [interpreter, script] pair later used to spawn worker processes;
        defaults to this process's own invocation.
    """
    import warnings

    # Silence warnings from user code / dependencies during CLI runs.
    warnings.filterwarnings("ignore")
    if entrypoint is None:
        # Default entry point: how to re-invoke this very script in a subprocess.
        entrypoint = [sys.executable, sys.argv[0]]

    # CliState carries the flow and other globals through click's ctx.obj.
    state = CliState(flow)
    state.entrypoint = entrypoint

    try:
        if args is None:
            # Let click read sys.argv; METAFLOW_-prefixed env vars map to options.
            start(auto_envvar_prefix="METAFLOW", obj=state)
        else:
            try:
                start(args=args, obj=state, auto_envvar_prefix="METAFLOW")
            except SystemExit as e:
                # With explicit args, return the exit code instead of exiting.
                return e.code
    except MetaflowException as x:
        if handle_exceptions:
            print_metaflow_exception(x)
            sys.exit(1)
        else:
            raise
    except Exception as x:
        if handle_exceptions:
            print_unknown_exception(x)
            sys.exit(1)
        else:
            raise
    finally:
        # Always shut down the sidecar subprocesses, even on failure.
        if hasattr(state, "monitor") and state.monitor is not None:
            state.monitor.terminate()
        if hasattr(state, "event_logger") and state.event_logger is not None:
            state.event_logger.terminate()
  1. The main process packs the workflow instance and entry point (interpreter binary and script paths) into a state object.
  2. Calls start(auto_envvar_prefix="METAFLOW", obj=state) to kick off the workflow.
  3. The finally block (added in recent versions) ensures the monitor and event logger sidecars are properly terminated even if an exception occurs.

Now let’s move to the start function.

start

This is when things get slightly more complex, as Metaflow is tightly coupled with the click library. It can be a bit tricky to get a clean view.

Decorators

First, it has a lot of decorators, most of which come from click.

# metaflow/cli.py

# `start` is assembled almost entirely from decorators: Metaflow adds
# decorator/config options, click builds a lazily-loaded command collection
# and injects the context object. (Excerpt: body and options elided.)
@decorators.add_decorator_options
@config_options
@click.command(
    cls=LazyPluginCommandCollection,
    sources=[cli],
    lazy_sources=plugins.get_plugin_cli_path(),
    invoke_without_command=True,
)
# Omitted all @click.option decorators.
@click.pass_context
def start(
        ctx,
        ...
):
    ...

Definition

Now moving on to the start function body.

# metaflow/cli.py

# Decorators are omitted.
def start(
        ctx,
        quiet=False,
        metadata=None,
        environment=None,
        force_rebuild_environments=False,
        datastore=None,
        datastore_root=None,
        decospecs=None,
        package_suffixes=None,
        pylint=None,
        event_logger=None,
        monitor=None,
        local_config_file=None,
        config=None,
        config_value=None,
        mode=None,
        **deco_options
):
    """Root command body: populate ctx.obj with the shared state (datastore,
    environment, sidecars, metadata provider) used by every subcommand.

    Note: quoted excerpt; the @click decorators are omitted.
    """
    # Verbosity switch: all subsequent echo() calls go through this binding.
    if quiet:
        echo = echo_dev_null
    else:
        echo = echo_always

    ctx.obj.version = metaflow_version.get_version()
    version = ctx.obj.version
    if use_r():
        # R flows report the Metaflow-R version instead.
        version = metaflow_r_version()

    echo("Metaflow %s" % version, fg="magenta", bold=True, nl=False)
    echo(" executing *%s*" % ctx.obj.flow.name, fg="magenta", nl=False)
    echo(" for *%s*" % resolve_identity(), fg="magenta")

    # Stash top-level CLI kwargs and common helpers on the context object.
    cli_args._set_top_kwargs(ctx.params)
    ctx.obj.echo = echo
    ctx.obj.echo_always = echo_always
    ctx.obj.is_quiet = quiet
    ctx.obj.logger = logger
    ctx.obj.pylint = pylint
    ctx.obj.check = functools.partial(_check, echo)
    ctx.obj.top_cli = cli
    ctx.obj.package_suffixes = package_suffixes.split(",")
    ctx.obj.spin_mode = mode == "spin"

    # First datastore implementation whose TYPE matches the --datastore
    # option (IndexError if none matches).
    ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0]

    if datastore_root is None:
        # Fall back to the root configured for this datastore type.
        datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config(
            ctx.obj.echo
        )
    if datastore_root is None:
        raise CommandException(
            "Could not find the location of the datastore -- did you correctly set the "
            "METAFLOW_DATASTORE_SYSROOT_%s environment variable?" % datastore.upper()
        )

    ctx.obj.datastore_impl.datastore_root = datastore_root
    FlowDataStore.default_storage_impl = ctx.obj.datastore_impl

    # Process config decorators and possibly mutate the flow class
    # NOTE(review): this local shadows the module-level `config_options`
    # decorator applied to this very function.
    config_options = config or config_value
    # ... config handling for resume, flow mutators, etc. ...
    new_cls = ctx.obj.flow._process_config_decorators(config_options)
    if new_cls:
        # Config decorators produced a new flow class: re-instantiate it
        # without re-entering the CLI.
        ctx.obj.flow = new_cls(use_cli=False)

    ctx.obj.graph = ctx.obj.flow._graph

    # First environment whose TYPE matches; MetaflowEnvironment is appended
    # as the fallback candidate.
    ctx.obj.environment = [
        e for e in ENVIRONMENTS + [MetaflowEnvironment] if e.TYPE == environment
    ][0](ctx.obj.flow)
    ctx.obj.environment.validate_environment(ctx.obj.logger, datastore)

    # Sidecars and metadata provider, selected by CLI option value.
    ctx.obj.event_logger = LOGGING_SIDECARS[event_logger](
        flow=ctx.obj.flow, env=ctx.obj.environment
    )
    ctx.obj.monitor = MONITOR_SIDECARS[monitor](
        flow=ctx.obj.flow, env=ctx.obj.environment
    )
    ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0](
        ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor
    )

    ctx.obj.flow_datastore = FlowDataStore(
        ctx.obj.flow.name,
        ctx.obj.environment,
        ctx.obj.metadata,
        ctx.obj.event_logger,
        ctx.obj.monitor,
    )

    # ... spin mode setup if applicable ...

    # Start the sidecar subprocesses (cli.main's finally block terminates them).
    ctx.obj.event_logger.start()
    ctx.obj.monitor.start()

    decorators._init(ctx.obj.flow)

    # It is important to initialize flow decorators early as some of the
    # things they provide may be used by some of the objects initialized after.
    decorators._init_flow_decorators(
        ctx.obj.flow,
        ctx.obj.graph,
        ctx.obj.environment,
        ctx.obj.flow_datastore,
        ctx.obj.metadata,
        ctx.obj.logger,
        echo,
        deco_options,
        ctx.obj.is_spin,
        ctx.obj.skip_decorators,
    )

    ctx.obj.tl_decospecs = list(decospecs or [])

    current._set_env(flow=ctx.obj.flow, is_running=False)
    parameters.set_parameter_context(
        ctx.obj.flow.name, ctx.obj.echo, ctx.obj.flow_datastore, ...
    )

    if ctx.invoked_subcommand not in ("run", "resume"):
        # run/resume are special cases because they can add more decorators with --with,
        # so they have to take care of themselves.
        # ... attach and init step decorators ...
        ctx.obj.package = None
    if ctx.invoked_subcommand is None:
        # Bare invocation with no subcommand: default to the `check` command.
        ctx.invoke(check)

Despite being chunky, the code essentially initializes the shared global state — datastore, environment, logging/monitoring sidecars, and metadata provider — on ctx.obj before dispatching to the invoked subcommand.

On run

This is where the main show begins. Note that in recent versions of Metaflow, the run command has been moved from cli.py into metaflow/cli_components/run_cmds.py for better code organization.

# metaflow/cli_components/run_cmds.py

# decorators are omitted for brevity.
def run(
        obj,
        tags=None,
        max_workers=None,
        max_num_splits=None,
        max_log_size=None,
        decospecs=None,
        run_id_file=None,
        runner_attribute_file=None,
        user_namespace=None,
        **kwargs
):
    """The `run` subcommand: validate the flow, build a NativeRuntime, and
    drive the whole workflow to completion.

    Note: quoted excerpt; the @click decorators are omitted.
    """
    if user_namespace is not None:
        # NOTE(review): `or None` only matters for an empty string, which
        # presumably resets to the default namespace — confirm upstream.
        namespace(user_namespace or None)
    # Gate-keeping: DAG validation, pylint, decorator attach/init, packaging.
    before_run(obj, tags, decospecs)

    runtime = NativeRuntime(
        obj.flow,
        obj.graph,
        obj.flow_datastore,
        obj.metadata,
        obj.environment,
        obj.package,
        obj.logger,
        obj.entrypoint,
        obj.event_logger,
        obj.monitor,
        max_workers=max_workers,
        max_num_splits=max_num_splits,
        # CLI value is in MiB; the runtime expects bytes.
        max_log_size=max_log_size * 1024 * 1024,
    )
    # Record the new run id for `--origin-run-id` style lookups and, when
    # requested, in a user-supplied file.
    write_latest_run_id(obj, runtime.run_id)
    write_file(run_id_file, runtime.run_id)

    # Bind Parameter (and config) values as attributes on the flow instance.
    obj.flow._set_constants(obj.graph, kwargs, obj.config_options)
    current._update_env({"run_id": runtime.run_id})

    runtime.print_workflow_info()
    # Persist parameters before the start step fires.
    runtime.persist_constants()
    if runner_attribute_file:
        # Write run metadata for the Runner API
        with open(runner_attribute_file, "w", encoding="utf-8") as f:
            json.dump({"run_id": runtime.run_id, "flow_name": obj.flow.name, ...}, f)
    with runtime.run_heartbeat():
        runtime.execute()

What happens here is:

  1. before_run(...) performs all the gate-keeping checks such as DAG validation and pylint. It also handles attaching decorators, initializing step decorators, and creating the code package.
  2. obj.flow._set_constants(obj.graph, kwargs, obj.config_options) sets the Parameter values as attributes of the workflow instance. It now also receives graph and config_options to handle config parameters.
  3. A NativeRuntime instance is created using all the global states carried over by ctx.obj.
    1. runtime.print_workflow_info() prints a summary of the workflow’s configuration.
    2. runtime.persist_constants() persists workflow parameters before the start step fires off.
    3. runtime.execute() starts the workflow execution, now wrapped in runtime.run_heartbeat() context manager that manages the heartbeat lifecycle.
  4. The runner_attribute_file option supports the Runner API (introduced in v2.12.0), which enables running flows programmatically from notebooks and Python scripts.

Workflow Runtime

As mentioned earlier, NativeRuntime is the core component of a workflow to orchestrate the execution of the tasks.

At its heart, a runtime runs an event loop that implements the producer–consumer pattern, where the producer is the workflow topology and the consumer is the pool of worker processes. The runtime is the pump that drives the workflow execution.

NativeRuntime constructor

# metaflow/runtime.py

class NativeRuntime(object):
    def __init__(
            self,
            flow,
            graph,
            flow_datastore,
            metadata,
            environment,
            package,
            logger,
            entrypoint,
            event_logger,
            monitor,
            run_id=None,
            clone_run_id=None,
            clone_only=False,
            reentrant=False,
            steps_to_rerun=None,
            max_workers=MAX_WORKERS,
            max_num_splits=MAX_NUM_SPLITS,
            max_log_size=MAX_LOG_SIZE,
            resume_identifier=None,
            skip_decorator_hooks=False,
    ):
        """Set up the local, process-based workflow orchestrator.

        Note: quoted excerpt from metaflow/runtime.py.
        """

        # Reuse the caller-supplied run id (registering it) or mint a new one.
        if run_id is None:
            self._run_id = metadata.new_run_id()
        else:
            self._run_id = run_id
            metadata.register_run_id(run_id)

        self._flow = flow
        self._graph = graph
        self._flow_datastore = flow_datastore
        self._metadata = metadata
        self._environment = environment
        self._logger = logger
        self._max_workers = max_workers
        self._active_tasks = dict()  # Per-step tracking: {step_name: [running, done]}
        self._active_tasks[0] = 0    # Total active workers count
        self._unprocessed_steps = set([n.name for n in self._graph])
        self._max_num_splits = max_num_splits
        self._max_log_size = max_log_size
        self._params_task = None
        self._entrypoint = entrypoint
        self.event_logger = event_logger
        self._monitor = monitor
        self._resume_identifier = resume_identifier

        self._clone_run_id = clone_run_id
        self._clone_only = clone_only
        self._cloned_tasks = []
        self._ran_or_scheduled_task_index = set()
        self._reentrant = reentrant
        self._skip_decorator_hooks = skip_decorator_hooks

        # When resuming, propagate steps_to_rerun to all downstream steps
        # NOTE(review): the falsy default is a dict ({}), not a set. It never
        # crashes because `.add` is only reachable when a non-empty set was
        # passed in, but `set()` would be the honest default.
        # Propagation is transitive in one pass assuming sorted_nodes yields
        # topological order — TODO confirm.
        self._steps_to_rerun = steps_to_rerun or {}
        for step_name in self._graph.sorted_nodes:
            if step_name in self._steps_to_rerun:
                out_funcs = self._graph[step_name].out_funcs or []
                for next_step in out_funcs:
                    self._steps_to_rerun.add(next_step)

        self._origin_ds_set = None
        if clone_run_id:
            # Resume logic: clone successful tasks from clone_run_id,
            # re-run unsuccessful or not-run steps.
            logger(
                "Gathering required information to resume run "
                "(this may take a bit of time)..."
            )
            # Prefetch the origin run's task datastores for cloning.
            self._origin_ds_set = TaskDataStoreSet(
                flow_datastore,
                clone_run_id,
                prefetch_data_artifacts=PREFETCH_DATA_ARTIFACTS,
            )
        self._run_queue = []
        # *NIX poll-based watcher for worker-process file descriptors.
        self._poll = procpoll.make_poll()
        self._workers = {}  # fd -> subprocess mapping
        self._finished = {}
        self._is_cloned = {}
        self._control_num_splits = {}  # control_task -> num_splits mapping

        if not self._skip_decorator_hooks:
            # Give every step decorator a chance to initialize for this run.
            for step in flow:
                for deco in step.decorators:
                    deco.runtime_init(flow, graph, package, self._run_id)

Key differences from earlier versions:

  1. A unique run_id is generated for the new run. As for a local run, it simply uses the current timestamp.
  2. _active_tasks is now a dictionary tracking per-step running/completed counts (previously just a simple counter _num_active_workers). _unprocessed_steps tracks steps that haven’t started yet.
  3. New resume-related parameters: clone_only (clone without re-executing), reentrant (parallel resume support), steps_to_rerun (specific steps to force re-run and all their downstream steps), resume_identifier.
  4. skip_decorator_hooks allows skipping decorator lifecycle hooks (used by the spin mode).
  5. A ProcPoll instance is created. It is used to track (by polling at a fixed interval) the worker processes.

Dissecting NativeRuntime.execute

# metaflow/runtime.py

class NativeRuntime(object):
    def execute(self):
        """Main scheduling loop: pump tasks from the run queue into worker
        processes until the flow completes or fails.

        Note: quoted excerpt; some bodies are elided with `...`.
        """
        if len(self._cloned_tasks) > 0:
            # Resume: process pre-cloned tasks first
            self._run_queue = []
            self._active_tasks[0] = 0
        else:
            # Seed the queue with the `start` step, chained to the parameters
            # task when one exists.
            if self._params_task:
                self._queue_push("start", {"input_paths": [self._params_task.path]})
            else:
                self._queue_push("start", {})

        progress_tstamp = time.time()
        with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file:
            # Write flow config values to a temp file for worker processes
            config_value = dump_config_values(self._flow)
            if config_value:
                json.dump(config_value, config_file)
                config_file.flush()
                self._config_file_name = config_file.name
            else:
                self._config_file_name = None
            try:
                # main scheduling loop
                exception = None
                while self._run_queue or self._active_tasks[0] > 0 or self._cloned_tasks:
                    if self._cloned_tasks:
                        # Process cloned tasks from resume
                        finished_tasks = ...  # dedup and process cloned tasks
                        self._cloned_tasks = []
                    else:
                        # 1. are any of the current workers finished?
                        finished_tasks = list(self._poll_workers())
                    # 2. push new tasks triggered by the finished tasks to the queue
                    self._queue_tasks(finished_tasks)
                    # 3. if there are available worker slots, pop and start tasks from the queue.
                    self._launch_workers()

                    if time.time() - progress_tstamp > PROGRESS_INTERVAL:
                        progress_tstamp = time.time()
                        # Detailed per-step progress reporting
                        tasks_print = ", ".join(
                            ["%s (%d running; %d done)" % (k, v[0], v[1])
                             for k, v in self._active_tasks.items()
                             if k != 0 and v[0] > 0]
                        )
                        # ... log active tasks, queued tasks, unprocessed steps ...

            except KeyboardInterrupt as ex:
                self._logger("Workflow interrupted. ...", system_msg=True, bad=True)
                self._killall()
                exception = ex
                raise
            except Exception as ex:
                self._logger("Workflow failed.", system_msg=True, bad=True)
                self._killall()
                exception = ex
                raise
            finally:
                # Decorator teardown always runs; `exception` is None on success.
                if not self._skip_decorator_hooks:
                    for step in self._flow:
                        for deco in step.decorators:
                            deco.runtime_finished(exception)
                self._run_exit_hooks()

        # assert that end was executed and it was successful
        if ("end", (), ()) in self._finished:
            # NOTE(review): _run_url is set elsewhere, outside this excerpt.
            if self._run_url:
                self._logger("Done! See the run in the UI at %s" % self._run_url, system_msg=True)
            else:
                self._logger("Done!", system_msg=True)
        elif self._clone_only:
            self._logger("Clone-only resume complete...", system_msg=True)
        else:
            raise MetaflowInternalError("The *end* step was not successful by the end of flow.")

Summary:

  1. The event loop is not thread-safe and should be run by a single thread.
  2. Flow config values are serialized to a temporary file that worker processes can read via --local-config-file.
  3. self._queue_push("start", ..) enqueues the seeding start task to the queue.
  4. The while loop is the main event loop. It keeps running as long as any of three conditions holds: the queue is non-empty, there are active workers, or there are cloned tasks left to process (for resume). The loop:
    1. Processes cloned tasks from resume if present, otherwise polls workers for finished tasks.
    2. self._queue_tasks(finished_tasks) pushes the child steps of the finished tasks to the queue.
    3. self._launch_workers(): for each step remaining in the queue, dequeues it and creates a new worker process (if the worker pool capacity is not reached) to run.
    4. Progress reporting is now more detailed, showing per-step running/completed counts and listing unprocessed steps.
    5. Should there be an unexpected exception, the loop terminates all in-flight workers and exits.
  5. The finished check now uses a 3-tuple ("end", (), ()) (step name, foreach stack, iteration stack) instead of ("end", ()), reflecting the richer task identity tracking.
  6. _run_exit_hooks() runs any registered exit hooks before cleanup.

Alright, that’s it! Now let’s move on to the rest of the code.

Notes on _poll_workers()

# metaflow/runtime.py

class NativeRuntime(object):
    def _poll_workers(self):
        """Generator: poll worker fds and yield the tasks that finished
        successfully; retry or raise TaskFailed for the ones that did not.
        """
        if self._workers:
            for event in self._poll.poll(POLL_TIMEOUT):
                worker = self._workers.get(event.fd)
                if worker:
                    if event.can_read:
                        # Drain one log line from the worker's pipe.
                        worker.read_logline(event.fd)
                    if event.is_terminated:
                        returncode = worker.terminate()

                        # Unregister every fd this worker owned.
                        for fd in worker.fds():
                            self._poll.remove(fd)
                            del self._workers[fd]
                        # Bookkeeping: per-step [running, done] counters plus
                        # the global active count under key 0. The step's
                        # entry is presumably created at launch time — not
                        # shown in this excerpt.
                        step_counts = self._active_tasks[worker.task.step]
                        step_counts[0] -= 1
                        step_counts[1] += 1
                        self._active_tasks[0] -= 1

                        task = worker.task
                        if returncode:
                            # worker did not finish successfully
                            if worker.cleaned or returncode == METAFLOW_EXIT_DISALLOW_RETRY:
                                self._logger(
                                    "This failed task will not be retried.",
                                    system_msg=True,
                                )
                            else:
                                # Retry budget covers user-code and platform retries.
                                if task.retries < task.user_code_retries + task.error_retries:
                                    self._retry_worker(worker)
                                else:
                                    raise TaskFailed(task)
                        else:
                            # worker finished successfully
                            yield task
  1. Keeps polling the worker pool for I/O events, with each poll call waiting at most POLL_TIMEOUT.
  2. Once an event emerges, it retrieves the associated worker (by fd).
  3. If that event is in terminated state (task finished or failed), the worker is removed from the pool. The per-step task counts in _active_tasks are updated (decrement running, increment done) along with the global active count.
  4. Depending on the return code, the task is either returned (finished) or retried (failed).

Task

Task represents a single step in the workflow. It packs all data that is needed to be plugged into a worker.

Worker

Worker is the execution unit of the workflow. It is responsible for executing a single task (step) in a dedicated process.

Constructor

# metaflow/runtime.py

class Worker(object):
    def __init__(
        self,
        task,
        max_logs_size,
        config_file_name,
        orig_flow_datastore=None,
        spin_pathspec=None,
        artifacts_module=None,
        persist=True,
        skip_decorators=False,
    ):
        """Spawn a subprocess to execute `task` and wire up log capture.

        Note that the process is launched eagerly, inside the constructor.
        """
        self.task = task
        self._config_file_name = config_file_name
        self._orig_flow_datastore = orig_flow_datastore
        self._spin_pathspec = spin_pathspec
        self._artifacts_module = artifacts_module
        self._skip_decorators = skip_decorators
        self._persist = persist
        # Side effect: the worker process starts here.
        self._proc = self._launch()

        if task.retries > task.user_code_retries:
            # Beyond the user-code retry budget: this attempt is a fallback
            # handler for the failure (per the log message below).
            self.task.log(
                "Task fallback is starting to handle the failure.",
                system_msg=True,
                pid=self._proc.pid,
            )
        elif not task.is_cloned:
            suffix = " (retry)." if task.retries else "."
            self.task.log(
                "Task is starting" + suffix, system_msg=True, pid=self._proc.pid
            )

        # Capped in-memory buffers; output beyond max_logs_size is truncated.
        self._stdout = TruncatedBuffer("stdout", max_logs_size)
        self._stderr = TruncatedBuffer("stderr", max_logs_size)

        # fd -> (pipe, buffer): lets the runtime route poll events to buffers.
        self._logs = {
            self._proc.stderr.fileno(): (self._proc.stderr, self._stderr),
            self._proc.stdout.fileno(): (self._proc.stdout, self._stdout),
        }

        self._encoding = sys.stdout.encoding or "UTF-8"
        self.killed = False   # Forcibly killed by the master process via SIGKILL
        self.cleaned = False  # Shutting down, queried by the runtime for state

The constructor now accepts several additional parameters to support new features:

Once a worker is instantiated, an underlying process is spawned to run the task.

Process Creation

# metaflow/runtime.py

class Worker(object):
    def _launch(self):
        """Build the CLI invocation for the task's `step` subcommand and
        spawn the worker process.

        Returns the subprocess.Popen handle with piped stdin/stdout/stderr.
        """
        args = CLIArgs(
            self.task,
            orig_flow_datastore=self._orig_flow_datastore,
            spin_pathspec=self._spin_pathspec,
            artifacts_module=self._artifacts_module,
            persist=self._persist,
            skip_decorators=self._skip_decorators,
        )
        # Child inherits our environment plus the task-specific additions below.
        env = dict(os.environ)

        if self.task.clone_run_id:
            args.command_options["clone-run-id"] = self.task.clone_run_id

        if self.task.is_cloned and self.task.clone_origin:
            # Cloned tasks do no real work: silence both sidecars.
            args.command_options["clone-only"] = self.task.clone_origin
            args.top_level_options["event-logger"] = "nullSidecarLogger"
            args.top_level_options["monitor"] = "nullSidecarMonitor"
        else:
            # decorators may modify the CLIArgs object in-place
            for deco in self.task.decos:
                deco.runtime_step_cli(
                    args,
                    self.task.retries,
                    self.task.user_code_retries,
                    self.task.ubf_context,
                )

        if self._config_file_name:
            args.top_level_options["local-config-file"] = self._config_file_name
        env.update(args.get_env())
        # Unbuffered child output so the parent sees log lines as they happen.
        env["PYTHONUNBUFFERED"] = "x"
        # Propagate OpenTelemetry tracing context to the child.
        tracing.inject_tracing_vars(env)
        cmdline = args.get_args()
        debug.subcommand_exec(cmdline)
        return subprocess.Popen(
            cmdline,
            env=env,
            bufsize=1,
            stdin=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )

This unveils how a worker process is spawned:

  1. args = CLIArgs(self.task, ...) constructs the command line arguments from the task specification. It now also receives spin mode and artifact persistence parameters. Notably:
    1. The entry point of the arguments is [python_interpreter_path, main_script_path].
    2. It adds “step” to the CLI args as a subcommand, so that the step function will be called (by click framework) in the worker process to run that specific step. (The step command has been moved to metaflow/cli_components/step_cmd.py.)
  2. For cloned tasks, both the event logger and monitor sidecars are disabled (not just the monitor as before).
  3. The config file path is passed via --local-config-file top-level option.
  4. tracing.inject_tracing_vars(env) injects OpenTelemetry tracing environment variables for distributed tracing.
  5. subprocess.Popen(..) spawns a process object with the CLI arguments and env vars.

step command

This is where a step code is retrieved and executed. Like the run command, it has been moved to its own module in recent versions.

# metaflow/cli_components/step_cmd.py

@click.command(help="Internal command to execute a single task.", hidden=True)
@click.argument("step-name")
# Omitted @click.option decorators.
@click.pass_context
def step(
        ctx,
        step_name,
        # Omitted other kwargs including input_paths_filename, num_parallel, etc.
):
    """Worker-side entry point: run (or clone) one attempt of one step.

    NOTE(review): names such as ubf_context, opt_namespace, opt_tag,
    input_paths, clone_only, run_id, task_id and retry_count come from the
    @click.option kwargs omitted from this excerpt.
    """
    if ubf_context == "none":
        ubf_context = None
    if opt_namespace is not None:
        namespace(opt_namespace)

    func = None
    try:
        func = getattr(ctx.obj.flow, step_name)
    except:
        # NOTE(review): bare except is overly broad (it also swallows
        # KeyboardInterrupt); AttributeError is the expected failure here.
        raise CommandException("Step *%s* doesn't exist." % step_name)
    if not func.is_step:
        raise CommandException("Function *%s* is not a step." % step_name)
    echo("Executing a step, *%s*" % step_name, fg="magenta", bold=False)

    # Strip click's "opt_" prefix so kwargs match the step command's options.
    step_kwargs = ctx.params
    step_kwargs.pop("step_name", None)
    step_kwargs = dict(
        [(k[4:], v) if k.startswith("opt_") else (k, v) for k, v in step_kwargs.items()]
    )
    cli_args._set_step_kwargs(step_kwargs)

    ctx.obj.metadata.add_sticky_tags(tags=opt_tag)
    # Support reading input paths from a file (for large foreach splits)
    if not input_paths and input_paths_filename:
        with open(input_paths_filename, mode="r", encoding="utf-8") as f:
            input_paths = f.read().strip(" \n\"'")
    paths = decompress_list(input_paths) if input_paths else []

    task = MetaflowTask(
        ctx.obj.flow,
        ctx.obj.flow_datastore,
        ctx.obj.metadata,
        ctx.obj.environment,
        ctx.obj.echo,
        ctx.obj.event_logger,
        ctx.obj.monitor,
        ubf_context,
    )
    if clone_only:
        # Resume fast-path: copy the origin task's results, do not execute.
        task.clone_only(step_name, run_id, task_id, clone_only, retry_count)
    else:
        task.run_step(
            step_name,
            run_id,
            task_id,
            clone_run_id,
            paths,
            split_index,
            retry_count,
            max_user_code_retries,
        )

    echo("Success", fg="green", bold=True, indent=True)

Key takeaways:

  1. The step command is now in metaflow/cli_components/step_cmd.py and marked hidden=True since it’s an internal command not meant for direct user invocation.
  2. func = getattr(ctx.obj.flow, step_name) fetches the step function from the flow object. The old decorators._attach_decorators_to_step call has been removed — decorator attachment is now handled earlier in the lifecycle.
  3. Input paths can now be read from a file (input_paths_filename) instead of being passed on the command line, which is important for large foreach splits that could exceed command-line length limits.
  4. A MetaflowTask object is created to execute the step.

MetaflowTask

A MetaflowTask prepares a Flow instance for execution of a single step.

MetaflowTask.run_step

This is the place where a step function is eventually executed. The method is chunky, so I’ll divide it into smaller pieces.

# metaflow/task.py
def run_step(
        self,
        step_name,
        run_id,
        task_id,
        origin_run_id,
        input_paths,
        split_index,
        retry_count,
        max_user_code_retries,
        whitelist_decorators=None,
        persist=True,
):
    if run_id and task_id:
        self.metadata.register_run_id(run_id)
        self.metadata.register_task_id(run_id, step_name, task_id, retry_count)
    else:
        raise MetaflowInternalError(
            "task.run_step needs a valid run_id and task_id"
        )

    if retry_count >= MAX_ATTEMPTS:
        raise MetaflowInternalError(
            "Too many task attempts (%d)! MAX_ATTEMPTS exceeded." % retry_count
        )

    metadata_tags = ["attempt_id:{0}".format(retry_count)]
    metadata = [
        MetaDatum(field="attempt", value=str(retry_count), type="attempt", tags=metadata_tags),
        MetaDatum(field="origin-run-id", value=str(origin_run_id), type="origin-run-id", tags=metadata_tags),
        MetaDatum(field="ds-type", value=self.flow_datastore.TYPE, type="ds-type", tags=metadata_tags),
        MetaDatum(field="ds-root", value=self.flow_datastore.datastore_root, type="ds-root", tags=metadata_tags),
    ]
    # OpenTelemetry trace ID for distributed tracing
    trace_id = get_trace_id()
    if trace_id:
        metadata.append(
            MetaDatum(field="otel-trace-id", value=trace_id, type="trace-id", tags=metadata_tags)
        )

    step_func = getattr(self.flow, step_name)
    decorators = step_func.decorators
    if self.orig_flow_datastore:
        # Spin mode: optionally whitelist specific decorators
        decorators = (
            []
            if not whitelist_decorators
            else [deco for deco in decorators if deco.name in whitelist_decorators]
        )

    node = self.flow._graph[step_name]
    join_type = None
    if node.type == "join":
        join_type = self.flow._graph[node.split_parents[-1]].type

New in this version:

# metaflow/task.py

    # 1. initialize output datastore
    output = self.flow_datastore.get_task_datastore(
        run_id, step_name, task_id, attempt=retry_count, mode="w", persist=persist
    )
    output.init_task()

    if input_paths:
        # 2. initialize input datastores
        inputs = self._init_data(run_id, join_type, input_paths)

        # 3. initialize foreach state
        self._init_foreach(step_name, join_type, inputs, split_index)

        # 4. initialize iteration state (new: supports split-switch recursive steps)
        is_recursive_step = (
            node.type == "split-switch" and step_name in node.out_funcs
        )
        self._init_iteration(step_name, inputs, is_recursive_step)

        # 5. collect foreach stack metadata for tracking
        # ... foreach_stack_formatted, foreach_execution_path ...

    self.metadata.register_metadata(run_id, step_name, task_id, metadata)

    # 6. initialize the current singleton
    current._set_env(
        flow=self.flow,
        run_id=run_id,
        step_name=step_name,
        task_id=task_id,
        retry_count=retry_count,
        origin_run_id=origin_run_id,
        namespace=resolve_identity(),
        username=get_username(),
        metadata_str=self.metadata.metadata_str(),
        is_running=True,
        tags=self.metadata.sticky_tags,
    )

Key changes from earlier versions:

# metaflow/task.py

    # 7. run task
    output.save_metadata(
        {
            "task_begin": {
                "code_package_metadata": os.environ.get("METAFLOW_CODE_METADATA", ""),
                "code_package_sha": os.environ.get("METAFLOW_CODE_SHA"),
                "code_package_ds": os.environ.get("METAFLOW_CODE_DS"),
                "code_package_url": os.environ.get("METAFLOW_CODE_URL"),
                "retry_count": retry_count,
            }
        }
    )
    start = time.time()
    self.metadata.start_task_heartbeat(self.flow.name, run_id, step_name, task_id)
    with self.monitor.measure("metaflow.task.duration"):
        try:
            with self.monitor.count("metaflow.task.start"):
                _system_logger.log_event(
                    level="info", module="metaflow.task", name="start",
                    payload={**task_payload, "msg": "Task started"},
                )

            self.flow._current_step = step_name
            self.flow._success = False
            self.flow._task_ok = None
            self.flow._exception = None

            if join_type:
                # Join step:
                if join_type != "foreach":
                    split_node = self.flow._graph[node.split_parents[-1]]
                    expected_inputs = len(split_node.out_funcs)
                    if len(inputs) != expected_inputs:
                        raise MetaflowDataMissing(
                            "Join *%s* expected %d inputs but only %d inputs "
                            "were found" % (step_name, expected_inputs, len(inputs))
                        )
                input_obj = Inputs(self._clone_flow(inp) for inp in inputs)
                self.flow._set_datastore(output)
                current._update_env(
                    {
                        "parameter_names": self._init_parameters(inputs[0], passdown=True),
                        "graph_info": self.flow._graph_info,
                    }
                )
            else:
                # Linear step:
                if len(inputs) > 1:
                    raise MetaflowInternalError(
                        "Step *%s* is not a join step but it gets multiple inputs."
                        % step_name
                    )
                self.flow._set_datastore(inputs[0])
                if input_paths:
                    current._update_env(
                        {
                            "parameter_names": self._init_parameters(inputs[0], passdown=False),
                            "graph_info": self.flow._graph_info,
                        }
                    )

            for deco in decorators:
                deco.task_pre_step(
                    step_name, output, self.metadata, run_id, task_id,
                    self.flow, self.flow._graph, retry_count,
                    max_user_code_retries, self.ubf_context, inputs,
                )

            # decorators can decorate or replace the step function entirely
            orig_step_func = step_func
            for deco in decorators:
                step_func = deco.task_decorate(
                    step_func, self.flow, self.flow._graph,
                    retry_count, max_user_code_retries, self.ubf_context,
                )

            if join_type:
                self._exec_step_function(step_func, orig_step_func, input_obj)
            else:
                self._exec_step_function(step_func, orig_step_func)

            for deco in decorators:
                deco.task_post_step(
                    step_name, self.flow, self.flow._graph,
                    retry_count, max_user_code_retries,
                )

            self.flow._task_ok = True
            self.flow._success = True

        except Exception as ex:
            with self.monitor.count("metaflow.task.exception"):
                _system_logger.log_event(
                    level="error", module="metaflow.task", name="exception",
                    payload={**task_payload, "msg": traceback.format_exc()},
                )

            exception_handled = False
            for deco in decorators:
                res = deco.task_exception(
                    ex, step_name, self.flow, self.flow._graph,
                    retry_count, max_user_code_retries,
                )
                exception_handled = bool(res) or exception_handled

            if exception_handled:
                self.flow._task_ok = True
            else:
                self.flow._task_ok = False
                self.flow._exception = MetaflowExceptionWrapper(ex)
                print("%s failed:" % self.flow, file=sys.stderr)
                raise

        finally:
            # ... finalize control task, register metadata, persist flow state ...
            output.save_metadata({"task_end": {}})
            output.persist(self.flow)
            output.done()

            for deco in decorators:
                deco.task_finished(
                    step_name, self.flow, self.flow._graph,
                    self.flow._task_ok, retry_count, max_user_code_retries,
                )
            self.metadata.stop_heartbeat()

Key differences from earlier versions:

  1. In try block:
    1. The task is now wrapped in self.monitor.measure("metaflow.task.duration") and self.monitor.count("metaflow.task.start") context managers for structured metrics collection.
    2. Logging uses _system_logger.log_event instead of the old logger.log(msg) pattern.
    3. Join step input validation now checks against split_node.out_funcs count instead of node.in_funcs, which is more accurate for non-foreach joins.
    4. current._update_env now includes graph_info for DAG information APIs.
    5. orig_step_func is preserved and passed to _exec_step_function alongside the decorated version — this enables the runtime to access the original function even after decorator wrapping.
  2. In case of an exception, structured logging with _system_logger replaces the old dict-based logging. Metrics are tracked via self.monitor.count("metaflow.task.exception").
  3. In the finally block, the task-end metadata and the current flow state are persisted, and each decorator runs its task_finished hook.

Persisting the Flow State

A sneak peek at the code that persists the flow state. All instance attributes of the flow are persisted as artifacts, except for internal ones: dunder attributes, anything listed in _EPHEMERAL, class-level properties (such as Parameters), and methods/functions:

# metaflow/datastore/task_datastore.py

class TaskDataStore(object):
    @only_if_not_done
    @require_mode("w")
    def persist(self, flow):
        """
        Persist any new artifacts that were produced when running flow

        NOTE: This is a DESTRUCTIVE operation that deletes artifacts from
        the given flow to conserve memory. Don't rely on artifact attributes
        of the flow object after calling this function.

        Parameters
        ----------
        flow : FlowSpec
            Flow to persist
        """

        # Carry over any objects/info already tracked by the flow's
        # previous datastore.
        if flow._datastore:
            self._objects.update(flow._datastore._objects)
            self._info.update(flow._datastore._info)

        def _is_candidate(attr_name):
            # Dunder attributes and runtime-internal ephemerals are never
            # persisted; neither are properties defined on the class
            # (Parameters or other class-level computed attributes).
            if attr_name.startswith("__") or attr_name in flow._EPHEMERAL:
                return False
            if isinstance(getattr(flow.__class__, attr_name, None), property):
                return False
            return True

        # Collect the artifacts up front (rather than inside the generator
        # below) so we can pass an accurate len_hint to save_artifacts.
        pending = []
        for attr_name in dir(flow):
            if not _is_candidate(attr_name):
                continue
            value = getattr(flow, attr_name)
            # Bound/plain functions and Parameter descriptors are code,
            # not data — they are not artifacts.
            if isinstance(value, (MethodType, FunctionType, Parameter)):
                continue
            pending.append((attr_name, value))

        def _drain():
            # Consume `pending` destructively so that an original artifact
            # and its encoded form are never both held in memory. Attributes
            # named 'name' or starting with '_' are left on the flow object
            # because the Metaflow runtime still relies on them.
            while pending:
                attr_name, value = pending.pop()
                if attr_name != "name" and not attr_name.startswith("_"):
                    # NOTE: Destructive mutation of the flow object.
                    delattr(flow, attr_name)
                yield attr_name, value

        self.save_artifacts(_drain(), len_hint=len(pending))