/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RFactorVector;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RRaw;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVector;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * {@link ModelConverter} specialization whose wrapped object is an {@link RGenericVector}
 * (referred to as the "booster" throughout this class).
 *
 * <p>The booster is expected to carry a {@code raw} element, from which a {@link Learner}
 * is loaded, and an {@code fmap} decoration, from which a {@link FeatureMap} is loaded.
 * Both are loaded on first use and cached in instance fields.</p>
 */
public class XGBoostConverter
extends ModelConverter<RGenericVector> {
    /** Cached learner; populated on first call to {@link #ensureLearner()}. */
    private Learner learner = null;
    /** Cached feature map; populated on first call to {@link #ensureFeatureMap()}. */
    private FeatureMap featureMap = null;
    /** Value of the {@code "compact"} option (default {@code TRUE}); forwarded to the learner in {@link #encodeModel(Schema)}. */
    private boolean compact = this.getOption("compact", Boolean.TRUE);

    /**
     * @param booster the R object to convert; passed unchanged to the superclass constructor.
     */
    public XGBoostConverter(RGenericVector booster) {
        super(booster);
    }

    /**
     * Registers the label and the features on the given encoder.
     *
     * <p>Optional booster elements are honoured when present: {@code feature_names} is used to
     * validate the size of the feature map, and {@code schema} may supply a {@code missing}
     * value marker, a {@code response_name} and {@code response_levels}. When no response name
     * is supplied, the target field is named {@code "_target"}.</p>
     *
     * @param encoder the encoder that receives the label and the features.
     * @throws IllegalArgumentException if {@code feature_names} is present and its size differs
     *         from the number of feature map entries.
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector booster = (RGenericVector)this.getObject();
        RStringVector featureNames = booster.getStringElement("feature_names", false);
        RGenericVector schema = booster.getGenericElement("schema", false);

        FeatureMap featureMap = this.ensureFeatureMap();
        if (featureNames != null) {
            XGBoostConverter.checkFeatureMap(featureMap, featureNames);
        }
        if (schema != null) {
            RVector<?> missing = schema.getVectorElement("missing", false);
            if (missing != null) {
                featureMap.addMissingValue(ValueUtil.asString(missing.asScalar()));
            }
        }

        Learner learner = this.ensureLearner();
        ObjFunction obj = learner.obj();

        // Defaults that apply when the schema does not describe the response.
        FieldName targetField = FieldName.create("_target");
        List<String> targetCategories = null;
        if (schema != null) {
            RStringVector responseName = schema.getStringElement("response_name", false);
            RStringVector responseLevels = schema.getStringElement("response_levels", false);
            if (responseName != null) {
                targetField = FieldName.create((String)responseName.asScalar());
            }
            if (responseLevels != null) {
                targetCategories = responseLevels.getValues();
            }
        }

        Label label = obj.encodeLabel(targetField, targetCategories, (PMMLEncoder)encoder);
        encoder.setLabel(label);

        List<Feature> features = featureMap.encodeFeatures((PMMLEncoder)encoder);
        for (Feature feature : features) {
            encoder.addFeature(feature);
        }
    }

    /**
     * Encodes the learner as a {@link MiningModel}.
     *
     * <p>Two options are passed to the learner: {@code "compact"} (from the converter option of
     * the same name) and {@code "ntree_limit"} (from the optional booster element
     * {@code ntreelimit}; {@code null} when that element is absent).</p>
     *
     * @param schema the schema to encode against; it is first translated with
     *        {@link Learner#toXGBoostSchema(Schema)}.
     * @return the mining model produced by the learner.
     */
    public MiningModel encodeModel(Schema schema) {
        RGenericVector booster = (RGenericVector)this.getObject();
        RNumberVector<?> ntreeLimit = booster.getNumericElement("ntreelimit", false);
        Learner learner = this.ensureLearner();

        // Values are heterogeneous (Boolean and Integer), hence Object.
        Map<String, Object> options = new LinkedHashMap<String, Object>();
        options.put("compact", Boolean.valueOf(this.compact));
        options.put("ntree_limit", ntreeLimit != null ? ValueUtil.asInteger((Number)ntreeLimit.asScalar()) : null);

        Schema xgbSchema = learner.toXGBoostSchema(schema);
        return learner.encodeMiningModel(options, xgbSchema);
    }

    /**
     * Builds verification input data from a data frame whose columns line up, by position,
     * with the feature map entries.
     *
     * <p>{@code FLOAT} and {@code INTEGER} columns are passed through under the entry name.
     * {@code BINARY_INDICATOR} columns that share an entry name are folded into a single factor
     * column: for every row where the indicator equals {@code 1}, the factor value is set to the
     * entry value. Rows for which no indicator is set remain {@code null}.</p>
     *
     * @param dataFrame the data frame; must have exactly one column per feature map entry.
     * @return the verification data produced by {@code encodeVerificationData}.
     * @throws IllegalArgumentException if the column count differs from the entry count, or if
     *         an entry has an unsupported type.
     */
    @Override
    protected Map<VerificationField, List<?>> encodeActiveValues(RGenericVector dataFrame) {
        FeatureMap featureMap = this.ensureFeatureMap();
        XGBoostConverter.checkFeatureMap(featureMap, dataFrame);

        List<FeatureMap.Entry> entries = featureMap.getEntries();
        // Insertion order is significant: it determines the order of the output columns.
        Map<FieldName, RVector<?>> data = new LinkedHashMap<FieldName, RVector<?>>();

        for (int i = 0; i < dataFrame.size(); ++i) {
            FeatureMap.Entry entry = entries.get(i);
            RVector<?> column = dataFrame.getVectorValue(i);
            FieldName name = FieldName.create(entry.getName());
            String value = entry.getValue();
            FeatureMap.Entry.Type type = entry.getType();

            switch (type) {
                case BINARY_INDICATOR: {
                    // All indicator columns of one field accumulate into the same factor column.
                    RFactorVector factorColumn = (RFactorVector)data.get(name);
                    if (factorColumn == null) {
                        factorColumn = XGBoostConverter.createEmptyFactorColumn(column.size());
                        data.put(name, factorColumn);
                    }
                    List<String> factorValues = factorColumn.getFactorValues();
                    List<?> mask = column.getValues();
                    for (int row = 0; row < mask.size(); ++row) {
                        Number rowMask = (Number)mask.get(row);
                        if (rowMask != null && rowMask.doubleValue() == 1.0) {
                            factorValues.set(row, value);
                        }
                    }
                    break;
                }
                case FLOAT:
                case INTEGER: {
                    data.put(name, column);
                    break;
                }
                default: {
                    throw new IllegalArgumentException(String.valueOf(type));
                }
            }
        }

        List<RVector<?>> columns = new ArrayList<RVector<?>>(data.values());
        List<FieldName> names = new ArrayList<FieldName>(data.keySet());
        return XGBoostConverter.encodeVerificationData(columns, names);
    }

    /**
     * Creates a factor vector whose {@link RFactorVector#getFactorValues()} returns a mutable
     * list of {@code rowCount} {@code null} elements.
     *
     * <p>The backing list is set up with a field initializer and an instance initializer,
     * because an anonymous class cannot declare its own constructor.</p>
     */
    private static RFactorVector createEmptyFactorColumn(final int rowCount) {
        return new RFactorVector(null, null){
            private final List<String> factorValues = new ArrayList<String>(rowCount);
            {
                for (int row = 0; row < rowCount; ++row) {
                    this.factorValues.add(null);
                }
            }

            @Override
            public List<String> getFactorValues() {
                return this.factorValues;
            }
        };
    }

    /** Returns the cached feature map, loading it on first use. */
    private FeatureMap ensureFeatureMap() {
        if (this.featureMap == null) {
            this.featureMap = this.loadFeatureMap();
        }
        return this.featureMap;
    }

    /** Returns the cached learner, loading it on first use. */
    private Learner ensureLearner() {
        if (this.learner == null) {
            this.learner = this.loadLearner();
        }
        return this.learner;
    }

    /**
     * Loads the feature map from the booster's {@code fmap} decoration.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} raised while loading.
     */
    private FeatureMap loadFeatureMap() {
        RGenericVector booster = (RGenericVector)this.getObject();
        RVector<?> fmap = DecorationUtil.getVectorElement(booster, "fmap");
        try {
            return XGBoostConverter.loadFeatureMap(fmap);
        }
        catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }
    }

    /**
     * Loads the learner from the booster's {@code raw} element.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} raised while loading.
     */
    private Learner loadLearner() {
        RGenericVector booster = (RGenericVector)this.getObject();
        RRaw raw = (RRaw)booster.getElement("raw");
        try {
            return XGBoostConverter.loadLearner(raw);
        }
        catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }
    }

    /**
     * Verifies that the vector has as many elements as the feature map has entries.
     *
     * @throws IllegalArgumentException if the sizes differ.
     */
    private static void checkFeatureMap(FeatureMap featureMap, RVector<?> vector) {
        List<FeatureMap.Entry> entries = featureMap.getEntries();
        if (vector.size() != entries.size()) {
            throw new IllegalArgumentException("Invalid 'fmap' element. Expected " + vector.size() + " features, got " + entries.size() + " features");
        }
    }

    /**
     * Dispatches on the runtime type of {@code fmap}: a string vector is handled by
     * {@link #loadFeatureMap(RStringVector)}, a generic vector by
     * {@link #loadFeatureMap(RGenericVector)}.
     *
     * @throws IllegalArgumentException if {@code fmap} is of neither type (including {@code null}).
     */
    private static FeatureMap loadFeatureMap(RVector<?> fmap) throws IOException {
        if (fmap instanceof RStringVector) {
            return XGBoostConverter.loadFeatureMap((RStringVector)fmap);
        }
        if (fmap instanceof RGenericVector) {
            return XGBoostConverter.loadFeatureMap((RGenericVector)fmap);
        }
        String fmapType = fmap != null ? fmap.getClass().getName() : "null";
        throw new IllegalArgumentException("Unsupported 'fmap' element type: " + fmapType);
    }

    /**
     * Reads the feature map from the file whose path is the scalar value of {@code fmap}.
     */
    private static FeatureMap loadFeatureMap(RStringVector fmap) throws IOException {
        File file = new File((String)fmap.asScalar());
        try (InputStream is = new FileInputStream(file)) {
            return XGBoostUtil.loadFeatureMap(is);
        }
    }

    /**
     * Builds the feature map from a generic vector whose first three elements are read as
     * ids (integer vector), names (factor vector) and types (factor vector), in that order.
     *
     * @throws IllegalArgumentException if an id differs from its zero-based row index.
     */
    private static FeatureMap loadFeatureMap(RGenericVector fmap) {
        RIntegerVector id = fmap.getIntegerValue(0);
        RFactorVector name = fmap.getFactorValue(1);
        RFactorVector type = fmap.getFactorValue(2);
        FeatureMap featureMap = new FeatureMap();
        for (int i = 0; i < id.size(); ++i) {
            // Ids must form the sequence 0, 1, 2, ... so that list position equals feature id.
            if (i != id.getValue(i)) {
                throw new IllegalArgumentException("Invalid 'fmap' element. Expected feature id " + i + ", got " + id.getValue(i));
            }
            featureMap.addEntry(name.getFactorValue(i), type.getFactorValue(i));
        }
        return featureMap;
    }

    /**
     * Deserializes a learner from the raw bytes, using the platform's native byte order,
     * a {@code null} third argument and the {@code "$.Model"} path string.
     */
    private static Learner loadLearner(RRaw raw) throws IOException {
        byte[] value = raw.getValue();
        try (InputStream is = new ByteArrayInputStream(value)) {
            return XGBoostUtil.loadLearner(is, ByteOrder.nativeOrder(), null, "$.Model");
        }
    }
}

