Implement the call protocol.
The most common usage is to call driver() from a script that defines a run() function, e.g.:
#!/usr/bin/env python3
from stepup.core.call import driver

def run(a: int, b: int) -> int:
    return a + b

if __name__ == "__main__":
    driver()
Parameters:
- obj (Any, default: None) –
When not provided, the namespace of the module that calls driver() is searched for the name 'run' to implement the call protocol.
When an object is given as a parameter, its attributes are searched instead.
Source code in stepup/core/call.py
def driver(obj: Any = None):
    """Implement call protocol.

    The most common usage is to call `driver()` from a script that defines a `run()` function, e.g.:

    ```python
    #!/usr/bin/env python3
    from stepup.core.call import driver

    def run(a: int, b: int) -> int:
        return a + b

    if __name__ == "__main__":
        driver()
    ```

    The keyword arguments for `run()` come from the arguments returned by `parse_args`:
    either an inline JSON string, or an input file with a `.json` or `.pickle` suffix.
    Only the keyword arguments whose names appear in the signature of `run()` are passed on.
    When `run()` returns a value other than `None`, it is written to the output path,
    which must have a `.json` or `.pickle` suffix.

    Parameters
    ----------
    obj
        When not provided, the namespace of the module that calls `driver`
        will be searched for the name 'run' to implement the call protocol.
        When an object is given as a parameter, its attributes are searched instead.

    Raises
    ------
    ValueError
        When the calling module is not found in `sys.modules`,
        or when `run()` returns a value but no output path is given.
    AttributeError
        When `obj` has no `run` attribute.
    NotImplementedError
        When the input or output file has an unsupported suffix.
    """
    # Look up the caller's globals (not locals), so driver() also works when it is
    # called from inside a function (e.g. a main()) instead of at module level.
    # At module level, f_locals and f_globals are the same dict, so behavior is unchanged there.
    frame = inspect.currentframe().f_back
    caller_globals = frame.f_globals
    # Drop the frame reference early to avoid keeping a reference cycle alive.
    del frame
    script_path = Path(caller_globals["__file__"]).relpath()
    if obj is None:
        # Get the calling module and use it as obj
        module_name = caller_globals["__name__"]
        obj = sys.modules.get(module_name)
        if obj is None:
            raise ValueError(
                f"The driver must be called from an imported module, got {module_name}"
            )
    args = parse_args(script_path)

    # Load the keyword arguments
    if args.json_inp is not None:
        kwargs = json.loads(args.json_inp)
    elif args.path_inp is None:
        kwargs = {}
    elif args.path_inp.suffix == ".json":
        # JSON is UTF-8 by specification; do not depend on the platform's default encoding.
        with open(args.path_inp, encoding="utf-8") as fh:
            kwargs = json.load(fh)
    elif args.path_inp.suffix == ".pickle":
        # NOTE: unpickling can execute arbitrary code; only use trusted input files.
        with open(args.path_inp, "rb") as fh:
            kwargs = pickle.load(fh)
    else:
        raise NotImplementedError(f"Unsupported input file format: {args.path_inp.suffix}")

    # Call the run function
    run = getattr(obj, "run", None)
    if run is None:
        raise AttributeError("The module must define a 'run' function")
    # Filter kwargs to only include those accepted by the run function
    run_signature = inspect.signature(run)
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in run_signature.parameters}
    result = run(**filtered_kwargs)

    # Use a local import because the API is only needed when the driver is called.
    from .api import amend

    # Amend inputs using imported modules.
    # This goes a bit against good practice, in the sense that amending should be done early.
    # It is acceptable here because the driver would fail anyway if the imports are not available.
    # By amending after calling the driver, we also pick up local imports, if any.
    out = []
    if result is not None and args.path_out is not None and args.amend_out:
        out.append(args.path_out)
    amend(inp=_get_local_import_paths(script_path), out=out)

    # Save the result if not None
    if result is not None:
        if args.path_out is None:
            raise ValueError("The output path is mandatory when the run function returns a value.")
        if args.path_out.suffix == ".json":
            # Write UTF-8 explicitly, consistent with how JSON input is read.
            with open(args.path_out, "w", encoding="utf-8") as fh:
                json.dump(result, fh)
                fh.write("\n")
        elif args.path_out.suffix == ".pickle":
            with open(args.path_out, "wb") as fh:
                pickle.dump(result, fh)
        else:
            raise NotImplementedError(f"Unsupported output file format: {args.path_out.suffix}")
|