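The `BaseEstimator` class normalizes and validates the format of input data; concrete models subclass it and implement `_predict`.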

Its class attributes are `X`, `y`, `y_required` and `fit_required`.

The `__setup_input()` method converts `X` (and `y`, when `y_required` is True) to `numpy.ndarray`.

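For models whose input has no `y` (for example, unsupervised ones), set `y_required = False`. The base class never does this automatically; a subclass overrides the class attribute.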

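`X` is processed as follows:

  1. If `X` is not a `numpy.ndarray`, convert it;

  2. If `X` is empty, raise a `ValueError`;

  3. If `X` is 1-dimensional, treat it as a single sample: the number of samples is 1 and the number of features is `X.shape[0]`;

  4. Otherwise, the number of samples is `X.shape[0]` and the number of features is the product of the lengths of the remaining dimensions.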
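The four steps can be sketched as a standalone function. `infer_shape` is a hypothetical helper introduced only for illustration (the class performs these steps inline in `__setup_input`); it returns `(n_samples, n_features)`.

```python
import numpy as np


def infer_shape(X):
    """Return (n_samples, n_features) following the four steps above.

    Hypothetical helper for illustration; BaseEstimator does this inline.
    """
    # Step 1: convert anything that is not already an ndarray
    if not isinstance(X, np.ndarray):
        X = np.array(X)
    # Step 2: reject empty input
    if X.size == 0:
        raise ValueError("Number of features must be > 0")
    # Step 3: a 1-D array is treated as a single sample
    if X.ndim == 1:
        return 1, X.shape[0]
    # Step 4: first axis = samples, remaining axes flattened into features
    return X.shape[0], int(np.prod(X.shape[1:]))


print(infer_shape([0.5, 1.2, 3.3]))         # (1, 3)
print(infer_shape(np.zeros((10, 28, 28))))  # (10, 784)
```

Flattening the trailing axes means, for example, that a batch of 10 images of 28×28 pixels is seen as 10 samples with 784 features each.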

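When `y_required` is True, `y` is processed as follows (otherwise it is stored unchanged):

  1. If `y` is required but was not provided, raise a `ValueError`;

  2. If `y` is not a `numpy.ndarray`, convert it;

  3. If `y` was provided but is empty, raise a `ValueError`.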
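Likewise, a minimal sketch of the `y` checks. `check_y` is a hypothetical helper, and its `y_required` argument stands in for the class attribute of the same name.

```python
import numpy as np


def check_y(y, y_required=True):
    """Validate targets following the three steps above (hypothetical helper)."""
    if y_required:
        # Step 1: y is required but missing
        if y is None:
            raise ValueError("Missed required argument y")
        # Step 2: convert y to an ndarray
        if not isinstance(y, np.ndarray):
            y = np.array(y)
        # Step 3: y was given but is empty
        if y.size == 0:
            raise ValueError("Number of target y must be > 0")
    # When y_required is False, y is passed through unchanged (possibly None)
    return y


print(check_y([0, 1, 1]))                # [0 1 1]
print(check_y(None, y_required=False))   # None
```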


```python
# coding: utf-8
import numpy as np


class BaseEstimator(object):
    # Class attributes; subclasses override y_required / fit_required as needed
    X = None
    y = None
    y_required = True
    fit_required = True

    def __setup_input(self, X, y=None):
        """Convert X (and y, if required) to numpy.ndarray and validate them."""
        if not isinstance(X, np.ndarray):
            X = np.array(X)

        if X.size == 0:
            raise ValueError('Number of features must be > 0')

        if X.ndim == 1:
            # A 1-D array is a single sample
            self.n_samples, self.n_features = 1, X.shape[0]
        else:
            # First axis = samples; remaining axes are flattened into features
            self.n_samples, self.n_features = X.shape[0], np.prod(X.shape[1:])

        self.X = X

        if self.y_required:
            if y is None:
                raise ValueError('Missed required argument y')

            if not isinstance(y, np.ndarray):
                y = np.array(y)

            if y.size == 0:
                raise ValueError('Number of target y must be > 0')

        self.y = y

    def fit(self, X, y=None):
        self.__setup_input(X, y)

    def predict(self, X=None):
        if not isinstance(X, np.ndarray):
            X = np.array(X)

        # self.X is only set by fit(), so it tells whether fit() has been called
        if self.X is not None or not self.fit_required:
            return self._predict(X)
        else:
            raise ValueError('You must call fit before predict')

    def _predict(self, X=None):
        # Subclasses implement the actual prediction
        raise NotImplementedError()
```
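`BaseEstimator` does no modelling itself: a subclass supplies `_predict` and, if it does not need targets, overrides `y_required`. The sketch below shows this with two hypothetical subclasses, `MeanRegressor` (always predicts the mean of the training `y`) and `Centerer` (subtracts the training column means, so it sets `y_required = False`). Both models are illustrative assumptions, not part of the original post; a corrected copy of the class is repeated so the block runs on its own.

```python
import numpy as np


class BaseEstimator(object):
    # Corrected copy of the class above so this block is self-contained
    X = None
    y = None
    y_required = True
    fit_required = True

    def __setup_input(self, X, y=None):
        if not isinstance(X, np.ndarray):
            X = np.array(X)
        if X.size == 0:
            raise ValueError("Number of features must be > 0")
        if X.ndim == 1:
            self.n_samples, self.n_features = 1, X.shape[0]
        else:
            self.n_samples, self.n_features = X.shape[0], np.prod(X.shape[1:])
        self.X = X
        if self.y_required:
            if y is None:
                raise ValueError("Missed required argument y")
            if not isinstance(y, np.ndarray):
                y = np.array(y)
            if y.size == 0:
                raise ValueError("Number of target y must be > 0")
        self.y = y

    def fit(self, X, y=None):
        self.__setup_input(X, y)

    def predict(self, X=None):
        if not isinstance(X, np.ndarray):
            X = np.array(X)
        if self.X is not None or not self.fit_required:
            return self._predict(X)
        raise ValueError("You must call fit before predict")

    def _predict(self, X=None):
        raise NotImplementedError()


class MeanRegressor(BaseEstimator):
    """Hypothetical supervised model: predicts the mean of the training y."""

    def _predict(self, X=None):
        # One prediction per input sample
        return np.full(X.shape[0], self.y.mean())


class Centerer(BaseEstimator):
    """Hypothetical unsupervised model: needs no y, so y_required = False."""
    y_required = False

    def _predict(self, X=None):
        # Subtract the per-feature means learned during fit
        return X - self.X.mean(axis=0)


reg = MeanRegressor()
reg.fit([[1.0], [2.0], [3.0]], [10.0, 20.0, 30.0])
print(reg.predict([[4.0], [5.0]]))   # [20. 20.]

cen = Centerer()
cen.fit([[1.0, 2.0], [3.0, 4.0]])    # no y needed
print(cen.predict([[2.0, 3.0]]))     # [[0. 0.]]
```

Calling `predict` on a fresh `MeanRegressor()` raises `ValueError` because `fit` has not set `self.X`, and calling `MeanRegressor().fit(X)` without `y` raises `ValueError` because `y_required` is still True.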

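Copyright notice: this is an original article by shq-lj, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Original link: https://www.cnblogs.com/shq-lj/p/11845048.html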