Source code for fiddle._src.building

# coding=utf-8
# Copyright 2022 The Fiddle-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements Fiddle's build() function."""

import contextlib
import functools
import logging
import threading
from typing import Any, Callable, Dict, Sequence, TypeVar, Union, overload

from fiddle._src import config as config_lib
from fiddle._src import daglish
from fiddle._src import partial
from fiddle._src import reraised_exception

T = TypeVar('T')


class _BuildGuardState(threading.local):

  def __init__(self):
    super().__init__()
    self.in_build = False


_state = _BuildGuardState()


@contextlib.contextmanager
def _in_build():
  """A context manager to ensure fdl.build is not called recursively."""
  if _state.in_build:
    raise ValueError(
        'It is forbidden to call `fdl.build` inside another `fdl.build` call.')
  _state.in_build = True
  try:
    yield
  finally:
    _state.in_build = False


def _format_arg(arg: Any) -> str:
  """Returns repr(arg), returning a constant string if repr() fails."""
  try:
    return repr(arg)
  except Exception:  # pylint: disable=broad-except
    return f'<ERROR FORMATTING {type(arg)} ARGUMENT>'


def _make_message(
    current_path: daglish.Path,
    buildable: config_lib.Buildable,
    args: Sequence[Any],
    kwargs: Dict[str, Any],
) -> str:
  """Returns Fiddle-related debugging information for an exception."""
  path_str = '<root>' + daglish.path_str(current_path)
  fn_or_cls = config_lib.get_callable(buildable)
  try:
    fn_or_cls_name = fn_or_cls.__qualname__
  except AttributeError:
    fn_or_cls_name = str(fn_or_cls)  # callable instances, etc.
  args_str = ', '.join(f'{_format_arg(value)}' for value in args)
  kwargs_str = ', '.join(
      f'{name}={_format_arg(value)}' for name, value in kwargs.items()
  )

  tag_information = ''
  bound_args = buildable.__signature_info__.signature.bind_partial(
      *args, **kwargs
  )
  bound_args.apply_defaults()
  unset_arg_tags = []
  for param in buildable.__signature_info__.parameters:
    if param in bound_args.arguments:
      continue  # User supplied it, all good.
    tags = buildable.__argument_tags__.get(param, None)
    if tags:
      tag_str = ' '.join(sorted(str(tag) for tag in tags))
      unset_arg_tags.append(f' - {param}: {tag_str}')
  if unset_arg_tags:
    tag_details = '\n'.join(unset_arg_tags)
    tag_information = f'\nTags for unset arguments:\n{tag_details}'

  return (
      '\n\nFiddle context: failed to construct or call '
      f'{fn_or_cls_name} at {path_str} '
      f'with positional arguments: ({args_str}), '
      f'keyword arguments: ({kwargs_str}){tag_information}.'
  )


def call_buildable(
    buildable: config_lib.Buildable,
    arguments: Dict[Union[str, int], Any],
    *,
    current_path: daglish.Path,
) -> Any:
  """Prepare positional arguments and actually build the buildable."""
  args, kwargs = buildable.__signature_info__.transform_to_args_kwargs(
      arguments
  )
  make_message = functools.partial(
      _make_message, current_path, buildable, args, kwargs
  )
  with reraised_exception.try_with_lazy_message(make_message):
    return buildable.__build__(*args, **kwargs)


# Define typing overload for `build(Partial[T])`
@overload
def build(buildable: partial.Partial[T]) -> Callable[..., T]:
  ...


# Define typing overload for `build(Config[T])`
@overload
def build(buildable: config_lib.Config[T]) -> T:
  ...


# Define typing overload for nested structures.
@overload
def build(buildable: Any) -> Any:
  ...


# This is a free function instead of a method on the `Buildable` object in order
# to avoid potential naming collisions (e.g., if a function or class has a
# parameter named `build`).
[docs] def build(buildable): """Builds ``buildable``, recursively building nested ``Buildable`` instances. This is the core function for turning a ``Buildable`` into a usable object. It recursively walks through ``buildable``'s parameters, building any nested ``Config`` instances. Depending on the specific ``Buildable`` type passed (``Config`` or ``Partial``), the result is either the result of calling ``config.__fn_or_cls__`` with the configured parameters, or a partial function or class with those parameters bound. If the same ``Buildable`` instance is seen multiple times during traversal of the configuration tree, ``build`` is called only once (for the first instance encountered), and the result is reused for subsequent copies of the instance. This is achieved via the ``memo`` dictionary (similar to ``deepcopy``). This has the effect that for configured class instances, each separate config instance is in one-to-one correspondence with an actual instance of the configured class after calling ``build`` (shared config instances <=> shared class instances). Args: buildable: A ``Buildable`` instance to build, or a nested structure of ``Buildable`` objects. Returns: The built version of ``buildable``. """ is_built = False def _build(value: Any, state: daglish.State) -> Any: """Inner method / implementation of build().""" nonlocal is_built if isinstance(value, config_lib.Buildable): sub_traversal = state.flattened_map_children(value) metadata: config_lib.BuildableTraverserMetadata = sub_traversal.metadata arguments = metadata.arguments(sub_traversal.values) is_built = True return call_buildable(value, arguments, current_path=state.current_path) else: return state.map_children(value) with _in_build(): result = daglish.MemoizedTraversal.run(_build, buildable) if not is_built: logging.warning( 'No Buildables found in value passed to `fdl.build()`: ' '%s with type %s.', str(buildable), type(buildable), ) return result