Skip to content

stepup.core.call

Driver function to facilitate writing scripts that adhere to StepUp’s call protocol.

See Call Protocol for more details.

driver(obj=None)

Implement the call protocol.

The most common usage is to call driver() from a script that defines a run() function, e.g.:

#!/usr/bin/env python3
from stepup.core.call import driver

def run(a: int, b: int) -> int:
    return a + b

if __name__ == "__main__":
    driver()

Parameters:

  • obj (Any, default: None) –

    When not provided, the namespace of the module from which driver is called will be searched for the name ‘run’ to implement the call protocol. When an object is given as a parameter, its attributes are searched instead.

Source code in stepup/core/call.py
def driver(obj: Any = None):
    """Implement the call protocol.

    The most common usage is to call `driver()` from a script that defines a `run()` function, e.g.:

    ```python
    #!/usr/bin/env python3
    from stepup.core.call import driver

    def run(a: int, b: int) -> int:
        return a + b

    if __name__ == "__main__":
        driver()
    ```

    The driver parses the command-line arguments of the calling script,
    loads the keyword arguments (from an inline JSON string, a `.json` file or a `.pickle` file),
    calls `run()` with the keyword arguments it accepts,
    amends the step with the paths returned by `_get_local_import_paths` as inputs,
    and finally writes a non-`None` return value to the output file (`.json` or `.pickle`).

    Parameters
    ----------
    obj
        When not provided, the namespace of the module from which `driver` is called
        will be searched for the name 'run' to implement the call protocol.
        When an object is given as a parameter, its attributes are searched instead.

    Raises
    ------
    ValueError
        When the calling module is not found in `sys.modules`,
        or when `run()` returns a value but no output path was given.
    AttributeError
        When no `run` attribute is found.
    NotImplementedError
        When the input or output file has an unsupported extension.
    """
    # Look up the caller's module globals rather than its locals:
    # at module level both are the same mapping, but f_globals also works
    # when driver() is called from inside a function (e.g. a main() function).
    # No frame object is kept in a local variable, which avoids reference cycles.
    caller_globals = inspect.currentframe().f_back.f_globals
    script_path = Path(caller_globals["__file__"]).relpath()
    if obj is None:
        # Get the calling module and use it as obj
        module_name = caller_globals["__name__"]
        obj = sys.modules.get(module_name)
        if obj is None:
            raise ValueError(
                f"The driver must be called from an imported module, got {module_name}"
            )
    args = parse_args(script_path)

    # Load the keyword arguments
    if args.json_inp is not None:
        kwargs = json.loads(args.json_inp)
    elif args.path_inp is None:
        kwargs = {}
    elif args.path_inp.suffix == ".json":
        # JSON text is UTF-8 by specification; do not rely on the locale encoding.
        with open(args.path_inp, encoding="utf-8") as fh:
            kwargs = json.load(fh)
    elif args.path_inp.suffix == ".pickle":
        # NOTE: unpickling can execute arbitrary code; only use trusted input files.
        with open(args.path_inp, "rb") as fh:
            kwargs = pickle.load(fh)
    else:
        raise NotImplementedError(f"Unsupported input file format: {args.path_inp.suffix}")

    # Call the run function
    run = getattr(obj, "run", None)
    if run is None:
        raise AttributeError("The module must define a 'run' function")
    # Only pass the keyword arguments accepted by run(), so scripts can ignore
    # arguments they do not need. A **kwargs parameter accepts any keyword,
    # in which case all arguments are passed on.
    parameters = inspect.signature(run).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        run_kwargs = dict(kwargs)
    else:
        run_kwargs = {k: v for k, v in kwargs.items() if k in parameters}
    result = run(**run_kwargs)

    # Use a local import because the API is only needed when the driver is called.
    from .api import amend

    # Amend inputs using imported modules.
    # This goes a bit against good practice, in the sense that amending should be done early.
    # It is acceptable here because the driver would fail anyway if the imports are not available.
    # By amending after calling the driver, we also pick up local imports, if any.
    out = []
    if not (result is None or args.path_out is None) and args.amend_out:
        out.append(args.path_out)
    amend(inp=_get_local_import_paths(script_path), out=out)

    # Save the result if not None
    if result is not None:
        if args.path_out is None:
            raise ValueError("The output path is mandatory when the run function returns a value.")
        if args.path_out.suffix == ".json":
            with open(args.path_out, "w", encoding="utf-8") as fh:
                json.dump(result, fh)
                fh.write("\n")
        elif args.path_out.suffix == ".pickle":
            with open(args.path_out, "wb") as fh:
                pickle.dump(result, fh)
        else:
            raise NotImplementedError(f"Unsupported output file format: {args.path_out.suffix}")