
stemflow.utils.wrapper


Utilities for wrapping models, for example by monkey patching methods they lack.

model_wrapper(model)

Attach a predict_proba method to models that do not already have one.

Parameters:

  • model (BaseEstimator) –

    Input model

Returns:

  • BaseEstimator

    Wrapped model that has a predict_proba method (BaseEstimator)
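The fallback that gets attached (_monkey_patched_predict_proba) is not shown on this page. Purely as a hedged illustration, assuming a binary classifier whose predict returns 0/1 labels, one way to fake a predict_proba is to one-hot encode the hard predictions into an (n_samples, 2) array of pseudo-probabilities. The names hard_label_predict_proba and ThresholdModel are hypothetical; this is a sketch, not stemflow's implementation.

```python
import numpy as np


def hard_label_predict_proba(model, X) -> np.ndarray:
    """Hypothetical fallback: turn hard 0/1 predictions into a two-column
    pseudo-probability matrix [P(class 0), P(class 1)].

    Assumes `model.predict(X)` returns binary labels in {0, 1}.
    """
    labels = np.asarray(model.predict(X)).ravel()
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("expected binary labels in {0, 1}")
    # Row i becomes [1, 0] for label 0 and [0, 1] for label 1.
    return np.eye(2)[labels.astype(int)]


class ThresholdModel:
    """Toy stand-in for an estimator that only implements `predict`."""

    def predict(self, X):
        # Predict 1 when the first feature is positive, else 0.
        return (np.asarray(X)[:, 0] > 0).astype(int)


proba = hard_label_predict_proba(ThresholdModel(), np.array([[-1.0], [2.0], [0.5]]))
print(proba)
```

Because every row is exactly 0 or 1, downstream code that thresholds predict_proba output still works, but the values carry no calibrated uncertainty.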

Source code in stemflow/utils/wrapper.py
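import warnings

from sklearn.base import BaseEstimator


def model_wrapper(model: BaseEstimator) -> BaseEstimator:
    """Attach a predict_proba method to models that do not already have one.

    Args:
        model (BaseEstimator):
            Input model

    Returns:
        Wrapped model that has a `predict_proba` method (BaseEstimator)

    """
    # Models that already expose predict_proba are returned unchanged.
    if "predict_proba" in dir(model):
        return model
    else:
        warnings.warn("predict_proba function not in base_model. Monkey patching one.")

        # Monkey patch: assign the fallback as an attribute of this instance.
        # _monkey_patched_predict_proba is a helper defined elsewhere in this module (not shown here).
        model.predict_proba = _monkey_patched_predict_proba
        return model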
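One Python subtlety in the patching step: assigning a plain function to an instance attribute does not bind it, so calling model.predict_proba(X) does not pass model as the first argument. If the fallback expects the model as its first parameter (an assumption, since its signature is not shown here), wrapping it with types.MethodType binds it to that instance. The sketch below mirrors the structure of model_wrapper on a toy estimator; wrap_with_fallback, _fallback_predict_proba and OnlyPredict are hypothetical names, not part of stemflow.

```python
import warnings
from types import MethodType

import numpy as np


def _fallback_predict_proba(model, X) -> np.ndarray:
    # Hypothetical fallback: one-hot encode hard 0/1 predictions.
    labels = np.asarray(model.predict(X)).astype(int).ravel()
    return np.eye(2)[labels]


def wrap_with_fallback(model):
    """Sketch of the model_wrapper pattern with the fallback bound to the instance."""
    if "predict_proba" in dir(model):
        return model  # already has predict_proba; leave untouched
    warnings.warn("predict_proba function not in base_model. Monkey patching one.")
    # MethodType binds the function so `model` is passed as its first argument.
    model.predict_proba = MethodType(_fallback_predict_proba, model)
    return model


class OnlyPredict:
    """Toy estimator exposing `predict` but no `predict_proba`."""

    def predict(self, X):
        return (np.asarray(X)[:, 0] > 0).astype(int)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    wrapped = wrap_with_fallback(OnlyPredict())

print(len(caught))
print(wrapped.predict_proba(np.array([[-1.0], [3.0]])))
```

Binding per instance, rather than patching the class, keeps the change local to the model that was passed in; other instances of the same class are unaffected.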