import { Layer } from './layer';
import {
    GlorotUniform,
    Initializer,
    InitializerType,
    Orthogonal,
    Zeros,
} from './tf-initializers';
import {
    ActivationType,
    RNNCell,
    TfDataFormat,
    TfFillMode,
    TfFlipMode,
    TfInterpolation,
    TfMergeMode,
    TfPadding,
} from './tf-types';

// Any constructor whose instances are of type T; used below as the bound for
// the base-class parameter of every mixin function.
type Constructor<T> = new (...args: any[]) => T;

/**
 * Builds a class by applying a list of mixins to an empty base class.
 *
 * Mixins are applied left to right: the first mixin receives an empty
 * anonymous class and each later mixin receives the previous result.
 *
 * @param mixins - Functions that take a constructor and return an extended
 *     constructor.
 * @returns The composed class, or the empty base class when `mixins` is empty.
 */
export function composeInitializer(
    mixins: ReadonlyArray<(base: any) => any>
): any {
    // Explicit `any` preserves the previous (implicitly-any) contract while
    // letting this function compile under `noImplicitAny`.
    return mixins.reduce<any>(
        (accumulatedType, mixin) => mixin(accumulatedType),
        class {}
    );
}

/**
 * Builds a layer class by applying a list of mixins on top of `Layer`.
 *
 * Mixins are applied left to right: the first mixin receives `Layer` and each
 * later mixin receives the previous result.
 *
 * @param mixins - Functions that take a constructor and return an extended
 *     constructor.
 * @returns The composed class, or `Layer` itself when `mixins` is empty.
 */
export function composeLayer(mixins: ReadonlyArray<(base: any) => any>): any {
    // Explicit `any` preserves the previous (implicitly-any) contract while
    // letting this function compile under `noImplicitAny`.
    return mixins.reduce<any>(
        (accumulatedType, mixin) => mixin(accumulatedType),
        Layer
    );
}

/**
 * Mixin adding a `data_format` field that defaults to `'channels_last'`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithDataFormat<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        data_format: TfDataFormat;

        constructor(...args: any[]) {
            super(...args);
            this.data_format = 'channels_last';
        }
    };
}

/**
 * Mixin adding a `keepdims` flag that defaults to `false`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithKeepDims<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        keepdims: boolean;

        constructor(...args: any[]) {
            super(...args);
            this.keepdims = false;
        }
    };
}

/**
 * Mixin adding a `padding` mode field that defaults to `'valid'`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPadding<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        padding: TfPadding;

        constructor(...args: any[]) {
            super(...args);
            this.padding = 'valid';
        }
    };
}

/**
 * Mixin adding a numeric `padding` field that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPadding1D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        padding: number;

        constructor(...args: any[]) {
            super(...args);
            this.padding = 1;
        }
    };
}

/**
 * Mixin adding a two-element `padding` array that defaults to `[1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPadding2D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        padding: number[];

        constructor(...args: any[]) {
            super(...args);
            this.padding = [1, 1];
        }
    };
}

/**
 * Mixin adding a three-element `padding` array that defaults to `[1, 1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPadding3D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        padding: number[];

        constructor(...args: any[]) {
            super(...args);
            this.padding = [1, 1, 1];
        }
    };
}

/**
 * Mixin adding a numeric `strides` field that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithStrides1D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        strides: number;

        constructor(...args: any[]) {
            super(...args);
            this.strides = 1;
        }
    };
}

/**
 * Mixin adding a two-element `strides` array that defaults to `[1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithStrides2D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        strides: number[];

        constructor(...args: any[]) {
            super(...args);
            this.strides = [1, 1];
        }
    };
}

/**
 * Mixin adding a three-element `strides` array that defaults to `[1, 1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithStrides3D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        strides: number[];

        constructor(...args: any[]) {
            super(...args);
            this.strides = [1, 1, 1];
        }
    };
}

/**
 * Mixin adding a `filters` count that defaults to `16`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithFilters<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        filters: number;

        constructor(...args: any[]) {
            super(...args);
            this.filters = 16;
        }
    };
}

/**
 * Mixin adding a `kernel_size` field that defaults to `3`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithKernelSize<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        kernel_size: number;

        constructor(...args: any[]) {
            super(...args);
            this.kernel_size = 3;
        }
    };
}

/**
 * Mixin adding the depthwise configuration fields:
 * `depth_multiplier` (default `1`), `depthwise_initializer` (Glorot-uniform
 * descriptor), `depthwise_regularizer` and `depthwise_constraint` (both
 * `null`).
 *
 * NOTE(review): the imported `GlorotUniform` value is stored as-is in
 * `config`; if it is a mutable object it is shared by every instance —
 * confirm against `./tf-initializers`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithDepthwise<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        depth_multiplier = 1;
        depthwise_initializer: Initializer = {
            class_name: InitializerType.GLOROTUNIFORM,
            config: GlorotUniform,
        };
        depthwise_regularizer = null;
        depthwise_constraint = null;
    };
}

/**
 * Mixin adding the pointwise configuration fields:
 * `pointwise_initializer` (Glorot-uniform descriptor),
 * `pointwise_regularizer` and `pointwise_constraint` (both `null`).
 *
 * NOTE(review): the imported `GlorotUniform` value is stored as-is in
 * `config`; if it is a mutable object it is shared by every instance —
 * confirm against `./tf-initializers`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithPointwise<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        pointwise_initializer: Initializer = {
            class_name: InitializerType.GLOROTUNIFORM,
            config: GlorotUniform,
        };
        pointwise_regularizer = null;
        pointwise_constraint = null;
    };
}

/**
 * Mixin adding an `implementation` field that defaults to `1`.
 *
 * The field selects which TensorFlow implementation is used; it is an
 * admittedly inelegant knob.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithImplementation<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        implementation: number;

        constructor(...args: any[]) {
            super(...args);
            this.implementation = 1;
        }
    };
}

/**
 * Mixin adding a numeric `pool_size` field that defaults to `2`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPoolSize1D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        pool_size: number;

        constructor(...args: any[]) {
            super(...args);
            this.pool_size = 2;
        }
    };
}

/**
 * Mixin adding a two-element `pool_size` array that defaults to `[2, 2]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPoolSize2D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        pool_size: number[];

        constructor(...args: any[]) {
            super(...args);
            this.pool_size = [2, 2];
        }
    };
}

/**
 * Mixin adding a three-element `pool_size` array that defaults to
 * `[2, 2, 2]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithPoolSize3D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        pool_size: number[];

        constructor(...args: any[]) {
            super(...args);
            this.pool_size = [2, 2, 2];
        }
    };
}

/**
 * Mixin declaring a numeric `rate` field.
 *
 * No default value is assigned here; the field is left for the consumer to
 * set.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` declaring the extra field.
 */
export function WithRate<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        rate: number;
    };
}

/**
 * Mixin adding a flip `mode` field that defaults to
 * `'horizontal_and_vertical'`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithFlipMode<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        mode: TfFlipMode;

        constructor(...args: any[]) {
            super(...args);
            this.mode = 'horizontal_and_vertical';
        }
    };
}

/**
 * Mixin declaring a numeric `seed` field.
 *
 * No default value is assigned here; the field is left for the consumer to
 * set.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` declaring the extra field.
 */
export function WithSeed<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        seed: number;
    };
}

/**
 * Mixin adding an `interpolation` field that defaults to `'bilinear'`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithInterpolation<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        interpolation: TfInterpolation;

        constructor(...args: any[]) {
            super(...args);
            this.interpolation = 'bilinear';
        }
    };
}

/**
 * Mixin adding `height` and `width` fields, both defaulting to `0`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithHeightWidth<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        height: number;
        width: number;

        constructor(...args: any[]) {
            super(...args);
            this.height = 0;
            this.width = 0;
        }
    };
}

/**
 * Mixin adding a two-element `factor` array that defaults to `[1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithFactor<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        factor: number[];

        constructor(...args: any[]) {
            super(...args);
            this.factor = [1, 1];
        }
    };
}

/**
 * Mixin adding a `merge_mode` field that defaults to `'concat'`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithMergeMode<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        merge_mode: TfMergeMode;

        constructor(...args: any[]) {
            super(...args);
            this.merge_mode = 'concat';
        }
    };
}

/**
 * Mixin adding `height_factor` and `width_factor` arrays, each defaulting to
 * `[1, 1]`.
 *
 * Each instance receives its own freshly created arrays.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithHeightWidthFactor<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        height_factor: number[];
        width_factor: number[];

        constructor(...args: any[]) {
            super(...args);
            this.height_factor = [1, 1];
            this.width_factor = [1, 1];
        }
    };
}

/**
 * Mixin adding the kernel configuration fields: `kernel_regularizer` and
 * `kernel_constraint` (both `null`) and `kernel_initializer` (Glorot-uniform
 * descriptor).
 *
 * NOTE(review): the imported `GlorotUniform` value is stored as-is in
 * `config`; if it is a mutable object it is shared by every instance —
 * confirm against `./tf-initializers`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithKernel<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        kernel_regularizer = null;
        kernel_constraint = null;
        kernel_initializer: Initializer = {
            class_name: InitializerType.GLOROTUNIFORM,
            config: GlorotUniform,
        };
    };
}
/**
 * Mixin adding the bias configuration fields: `bias_regularizer` and
 * `bias_constraint` (both `null`) and `bias_initializer` (zeros descriptor).
 *
 * NOTE(review): the imported `Zeros` value is stored as-is in `config`; if it
 * is a mutable object it is shared by every instance — confirm against
 * `./tf-initializers`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithBias<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        bias_regularizer = null;
        bias_constraint = null;
        bias_initializer: Initializer = {
            class_name: InitializerType.ZEROS,
            config: Zeros,
        };
    };
}

/**
 * Mixin adding the recurrent configuration fields: `recurrent_constraint` and
 * `recurrent_regularizer` (both `null`) and `recurrent_initializer`
 * (orthogonal descriptor).
 *
 * NOTE(review): the imported `Orthogonal` value is stored as-is in `config`;
 * if it is a mutable object it is shared by every instance — confirm against
 * `./tf-initializers`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithRecurrent<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        recurrent_constraint = null;
        recurrent_regularizer = null;
        recurrent_initializer: Initializer = {
            class_name: InitializerType.ORTHOGONAL,
            config: Orthogonal,
        };
    };
}

/**
 * Mixin adding a `noise_shape` field that defaults to `null`.
 *
 * The declared type includes `null` so that the default is valid under
 * `strictNullChecks`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithNoiseShape<T extends Constructor<any>>(superClass: T) {
    return class extends superClass {
        noise_shape: number[] | null = null;
    };
}

/**
 * Mixin adding an `activation` field that defaults to
 * `ActivationType.LINEAR`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithActivationLinear<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        activation: ActivationType;

        constructor(...args: any[]) {
            super(...args);
            this.activation = ActivationType.LINEAR;
        }
    };
}
/**
 * Mixin adding an `activation` field that defaults to `ActivationType.TANH`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithActivationTanh<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        activation: ActivationType;

        constructor(...args: any[]) {
            super(...args);
            this.activation = ActivationType.TANH;
        }
    };
}

/**
 * Mixin adding a numeric `dilation_rate` field that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithDilationRate1D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        dilation_rate: number;

        constructor(...args: any[]) {
            super(...args);
            this.dilation_rate = 1;
        }
    };
}
/**
 * Mixin adding a two-element `dilation_rate` array that defaults to `[1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithDilationRate2D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        dilation_rate: number[];

        constructor(...args: any[]) {
            super(...args);
            this.dilation_rate = [1, 1];
        }
    };
}
/**
 * Mixin adding a three-element `dilation_rate` array that defaults to
 * `[1, 1, 1]`.
 *
 * Each instance receives its own freshly created array.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithDilationRate3D<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        dilation_rate: number[];

        constructor(...args: any[]) {
            super(...args);
            this.dilation_rate = [1, 1, 1];
        }
    };
}
/**
 * Mixin adding an `output_padding` field that defaults to `null`.
 *
 * The declared type includes `null` so that the default is valid under
 * `strictNullChecks`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithOutputPadding<T extends Constructor<any>>(superClass: T) {
    return class extends superClass {
        output_padding: number | null = null;
    };
}
/**
 * Mixin adding a `use_bias` flag that defaults to `true`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithUseBias<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        use_bias: boolean;

        constructor(...args: any[]) {
            super(...args);
            this.use_bias = true;
        }
    };
}
/**
 * Mixin adding an `axis` field that defaults to `-1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithAxis<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        axis: number;

        constructor(...args: any[]) {
            super(...args);
            this.axis = -1;
        }
    };
}
/**
 * Mixin adding a `use_scale` flag that defaults to `true`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithUseScale<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        use_scale: boolean;

        constructor(...args: any[]) {
            super(...args);
            this.use_scale = true;
        }
    };
}
/**
 * Mixin adding a `mean` field that defaults to `0`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithMean<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        mean: number;

        constructor(...args: any[]) {
            super(...args);
            this.mean = 0.0;
        }
    };
}
/**
 * Mixin adding a `gain` field that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithGain<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        gain: number;

        constructor(...args: any[]) {
            super(...args);
            this.gain = 1.0;
        }
    };
}
/**
 * Mixin adding a `stddev` field that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithStddev<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        stddev: number;

        constructor(...args: any[]) {
            super(...args);
            this.stddev = 1.0;
        }
    };
}
/**
 * Mixin adding a `sparse` flag that defaults to `false`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithSparse<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        sparse: boolean;

        constructor(...args: any[]) {
            super(...args);
            this.sparse = false;
        }
    };
}

/**
 * Mixin adding an `activity_regularizer` field that defaults to `null`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithActivityRegularizer<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        activity_regularizer = null;
    };
}

/**
 * Mixin adding a `units` count that defaults to `64`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithUnits<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        units: number;

        constructor(...args: any[]) {
            super(...args);
            this.units = 64;
        }
    };
}

/**
 * Mixin adding a `dropout` field that defaults to `0`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithDropoutRate<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        dropout: number;

        constructor(...args: any[]) {
            super(...args);
            this.dropout = 0.0;
        }
    };
}

/**
 * Mixin adding a `recurrent_dropout` field that defaults to `0`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithRecurrentDropout<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        recurrent_dropout: number;

        constructor(...args: any[]) {
            super(...args);
            this.recurrent_dropout = 0.0;
        }
    };
}

/**
 * Mixin adding `fill_mode` (default `'reflect'`) and `fill_value`
 * (default `0`) fields.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithFillMode<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        fill_mode: TfFillMode;
        fill_value: number;

        constructor(...args: any[]) {
            super(...args);
            this.fill_mode = 'reflect';
            this.fill_value = 0.0;
        }
    };
}

/**
 * Mixin adding a `groups` count that defaults to `1`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithGroups<TBase extends Constructor<any>>(
    superClass: TBase
) {
    return class extends superClass {
        groups: number;

        constructor(...args: any[]) {
            super(...args);
            this.groups = 1;
        }
    };
}

/**
 * Mixin adding the recurrent-layer fields.
 *
 * `cell` is declared without a default. The boolean flags
 * `return_sequences`, `return_state`, `go_backwards`, `stateful`, `unroll`,
 * `time_major` and `zero_output_for_mask` all default to `false`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra fields.
 */
export function WithRNN<TBase extends Constructor<any>>(superClass: TBase) {
    return class extends superClass {
        cell: RNNCell | RNNCell[];
        return_sequences: boolean;
        return_state: boolean;
        go_backwards: boolean;
        stateful: boolean;
        unroll: boolean;
        time_major: boolean;
        zero_output_for_mask: boolean;

        constructor(...args: any[]) {
            super(...args);
            // Assigned in declaration order so the own-property order of
            // instances is unchanged.
            this.return_sequences = false;
            this.return_state = false;
            this.go_backwards = false;
            this.stateful = false;
            this.unroll = false;
            this.time_major = false;
            this.zero_output_for_mask = false;
        }
    };
}

/**
 * Mixin adding a `layer` field, holding a `Layer`, that defaults to `null`.
 *
 * The declared type includes `null` so that the default is valid under
 * `strictNullChecks`.
 *
 * @param superClass - Base constructor to extend.
 * @returns An anonymous subclass of `superClass` carrying the extra field.
 */
export function WithWrapper<T extends Constructor<any>>(superClass: T) {
    return class extends superClass {
        layer: Layer | null = null;
    };
}
