/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.sgd;

import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/**
 * Abstract base class for models whose predictions are computed by a
 * {@link FeedForwardParameters} instance applied to a feature vector built from an
 * {@link Example}.
 *
 * @param <T> the output type predicted by this model
 */
public abstract class AbstractSGDModel<T extends Output<T>>
extends Model<T> {
    private static final long serialVersionUID = 1L;

    /** The parameters used to compute predictions in {@link #predictSingle(Example)}. */
    protected FeedForwardParameters modelParameters;

    /**
     * Whether a bias element is added when an example is converted into a feature vector.
     * The constructor always overwrites this value.
     */
    protected boolean addBias = true;

    /**
     * Constructs a model backed by the supplied parameters.
     *
     * @param name the model name, passed to {@link Model}
     * @param provenance the model provenance, passed to {@link Model}
     * @param featureIDMap the feature domain, passed to {@link Model}
     * @param outputIDInfo the output domain, passed to {@link Model}
     * @param weights the parameters used to make predictions; stored by reference, not copied
     * @param generatesProbabilities whether the model generates probabilities, passed to {@link Model}
     * @param addBias whether a bias element is added to each example's feature vector
     */
    protected AbstractSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FeedForwardParameters weights, boolean generatesProbabilities, boolean addBias) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.modelParameters = weights;
        this.addBias = addBias;
    }

    /**
     * Converts an example into a feature vector and runs it through {@link #modelParameters}.
     *
     * @param example the example to predict
     * @return the raw prediction vector and the number of active elements in the feature vector
     *         (the bias element is included when {@link #addBias} is true)
     * @throws IllegalArgumentException if the feature vector contains no active elements other
     *         than the bias
     */
    protected PredAndActive predictSingle(Example<T> example) {
        SGDVector features;
        // NOTE(review): this is a size-based heuristic. If the example has as many features as the
        // feature map, a dense vector is used; otherwise a sparse one.
        if (example.size() == this.featureIDMap.size()) {
            features = DenseVector.createDenseVector(example, this.featureIDMap, this.addBias);
        } else {
            features = SparseVector.createSparseVector(example, this.featureIDMap, this.addBias);
        }
        // With a bias, a vector whose only active element is the bias carries nothing from the
        // example. Without a bias, an empty vector carries nothing. Both cases are rejected.
        int minNumFeatures = this.addBias ? 1 : 0;
        int numActive = features.numActiveElements();
        if (numActive == minNumFeatures) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        return new PredAndActive(this.modelParameters.predict(features), numActive);
    }

    /**
     * Returns a copy of the model parameters, so callers cannot mutate this model's state
     * through the returned object.
     *
     * @return a copy of {@link #modelParameters}
     */
    public FeedForwardParameters getModelParameters() {
        return this.modelParameters.copy();
    }

    /**
     * Holder for the result of {@link #predictSingle(Example)}: the raw prediction and the
     * number of active elements in the feature vector used to produce it.
     */
    protected static final class PredAndActive {
        /** The prediction vector returned by the model parameters. */
        public final DenseVector prediction;
        /** The number of active elements in the input feature vector (including any bias). */
        public final int numActiveFeatures;

        PredAndActive(DenseVector prediction, int numActiveFeatures) {
            this.prediction = prediction;
            this.numActiveFeatures = numActiveFeatures;
        }
    }
}

