/*
 * Decompiled with CFR 0.152.
 */
package org.ohdsi.metaAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.hmc.CompoundDerivative;
import dr.inference.hmc.CompoundGradient;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Bounds;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.DesignMatrix;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.OperatorSchedule;
import dr.inference.operators.RandomWalkOperator;
import dr.inference.operators.ScaleOperator;
import dr.inference.operators.SimpleOperatorSchedule;
import dr.math.MathUtils;
import dr.math.distributions.Distribution;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.NormalDistribution;
import dr.util.Attribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ohdsi.likelihood.CachedModelLikelihood;
import org.ohdsi.mcmc.Analysis;
import org.ohdsi.mcmc.Runner;
import org.ohdsi.metaAnalysis.DataModel;
import org.ohdsi.metaAnalysis.ExtendingEmpiricalDataModel;
import org.ohdsi.metaAnalysis.GammaOnPrecisionPrior;
import org.ohdsi.metaAnalysis.ScalePrior;
import org.ohdsi.simpleDesign.SimpleLinearModel;

/**
 * Hierarchical meta-analysis model assembled from a list of {@link DataModel}s.
 *
 * <p>Each data model contributes a data likelihood over its compound parameter (its "betas").
 * All betas are linked to a vector of effects through a {@link SimpleLinearModel} whose design
 * matrix has, in order:
 * <ul>
 *   <li>one primary-effect column per data model;</li>
 *   <li>optionally, one secondary-effect column per identifier {@code 1..maxIdentifier};</li>
 *   <li>optionally, one exposure-effect column per trailing "effect" data model.</li>
 * </ul>
 * Primary and secondary effects get hierarchical normal priors; exposure effects get fixed normal
 * priors. The joint density (data likelihood + priors) and the operator schedule are exposed
 * through the {@link Analysis} interface.
 */
public class HierarchicalMetaAnalysis implements Analysis {
    /** Compound of the per-data-model likelihoods. */
    private final Likelihood likelihood;
    /** Compound of all prior terms, including the linear model linking betas to effects. */
    private final Likelihood prior;
    /** Compound of {@link #likelihood} and {@link #prior}; returned by {@link #getJoint()}. */
    private final Likelihood joint;
    /** Parameters reported by {@link #getLoggerColumns()}, in column order. */
    private final List<Parameter> parameters;
    /** Operator schedule returned by {@link #getSchedule()}. */
    private final OperatorSchedule schedule;

    /** Builds the scale prior (with its parameter, operator and prior density) for a hierarchical effect. */
    private ScalePrior makeHierarchicalScalePrior(double hyper1, double hyper2) {
        return new GammaOnPrecisionPrior(hyper1, hyper2);
    }

    /**
     * Assembles the model, priors and operator schedule.
     *
     * <p>The order in which parameters, operators and priors are created is significant: random
     * starting values are drawn from {@link MathUtils} (seeded with {@code cg.seed}) in that order,
     * and the order of {@link #parameters} defines the logger columns.
     *
     * @param allMetaAnalysisDataModels data models to combine; must not be empty. When exposure
     *     effects or {@code separateEffectPrior} are used, the last {@code cg.effectCount} models
     *     are the effect models.
     * @param cg configuration of priors, operators and which effects to include
     * @throws IllegalArgumentException if no data models are given, {@code cg.effectCount} is out
     *     of range, the exposure hyper-parameter lists do not cover every exposure effect, or the
     *     secondary effect is requested but no data model has identifiers
     * @throws IllegalStateException if the design matrix and the effect vector disagree in size
     */
    public HierarchicalMetaAnalysis(List<DataModel> allMetaAnalysisDataModels, HierarchicalMetaAnalysisConfiguration cg) {
        if (allMetaAnalysisDataModels == null || allMetaAnalysisDataModels.isEmpty()) {
            throw new IllegalArgumentException("At least one meta-analysis data model is required");
        }
        MathUtils.setSeed(cg.seed);

        List<Likelihood> allDataLikelihoods = new ArrayList<>();
        List<MCMCOperator> allOperators = new ArrayList<>();
        List<Parameter> allParameters = new ArrayList<>();

        // One data likelihood and one random-walk operator per data model; the operator weight is
        // proportional to the number of betas in that model.
        int maLabel = 0;
        for (DataModel metaAnalysis : allMetaAnalysisDataModels) {
            ++maLabel;
            allDataLikelihoods.add(new CachedModelLikelihood("likelihood" + maLabel, metaAnalysis));
            Parameter beta = metaAnalysis.getCompoundParameter();
            allParameters.add(beta);
            allOperators.add(new RandomWalkOperator(beta, null, 0.75, RandomWalkOperator.BoundaryCondition.reflecting, cg.operatorWeight * beta.getDimension(), cg.mode));
        }
        // At this point allParameters holds exactly the betas, in data-model order.
        CompoundParameter allBetas = new CompoundParameter("all.beta");
        for (Parameter beta : allParameters) {
            allBetas.addParameter(beta);
        }

        DesignMatrix designMatrix = new DesignMatrix("designMatrix", false);
        CompoundParameter allEffects = new CompoundParameter("allEffects");
        List<Likelihood> allPriors = new ArrayList<>();

        // Linear-model parameter tau, bounded to [0, +inf), with a gamma prior and a scale operator.
        Parameter.Default tau = new Parameter.Default("tau", cg.startingTau, 0.0, Double.POSITIVE_INFINITY);
        DistributionLikelihood tauPrior = new DistributionLikelihood((Distribution) new GammaDistribution(cg.tauShape, cg.tauScale));
        tauPrior.addData(tau);
        ScaleOperator tauOperator = new ScaleOperator((Variable) tau, 0.75, cg.mode, cg.operatorWeight);
        allPriors.add(tauPrior);
        allParameters.add(tau);
        allOperators.add(tauOperator);

        // Primary effect: one design column per data model, hierarchical normal prior.
        int primaryCount = addPrimaryDesign(designMatrix, allMetaAnalysisDataModels, cg.primaryEffectName, cg.separateEffectPrior, cg.effectCount);
        Parameter primaryEffect = randomize(cg.primaryEffectName, primaryCount, 0.0, 1.0);
        allEffects.addParameter(primaryEffect);
        HierarchicalNormalComponents primaryComponents = makeHierarchicalNormalComponents(cg.primaryEffectName, primaryEffect, cg.hierarchicalLocationPrimaryHyperStdDev, makeHierarchicalScalePrior(cg.gammaHyperPrimaryShape, cg.gammaHyperPrimaryScale), cg.operatorWeight, cg.mode);
        addEffectComponents(primaryEffect, primaryComponents, allParameters, allOperators, allPriors);

        // Secondary effect: one design column per identifier, hierarchical normal prior.
        if (cg.includeSecondary) {
            int secondaryCount = addSecondaryDesign(designMatrix, allMetaAnalysisDataModels, cg.secondaryEffectName, cg.separateEffectPrior, cg.effectCount);
            Parameter secondaryEffect = randomize(cg.secondaryEffectName, secondaryCount, 0.0, 1.0);
            allEffects.addParameter(secondaryEffect);
            HierarchicalNormalComponents secondaryComponents = makeHierarchicalNormalComponents(cg.secondaryEffectName, secondaryEffect, cg.hierarchicalLocationSecondaryHyperStdDev, makeHierarchicalScalePrior(cg.gammaHyperSecondaryShape, cg.gammaHyperSecondaryScale), cg.operatorWeight, cg.mode);
            addEffectComponents(secondaryEffect, secondaryComponents, allParameters, allOperators, allPriors);
        }

        // Exposure effects: one scalar effect per trailing effect model, each with a fixed normal prior.
        if (cg.includeExposure) {
            int effectCount = addEffectDesign(designMatrix, allMetaAnalysisDataModels, cg.exposureEffectName, cg.effectCount);
            for (int i = 0; i < effectCount; ++i) {
                Parameter exposureEffect = randomize(cg.exposureEffectName + (i + 1), 1, 0.0, 1.0);
                allEffects.addParameter(exposureEffect);
                double location = selectHyperValue(cg.exposureHyperLocation, i, "exposureHyperLocation");
                double stdDev = selectHyperValue(cg.exposureHyperStdDev, i, "exposureHyperStdDev");
                DistributionLikelihood exposureDistribution = new DistributionLikelihood((Distribution) new NormalDistribution(location, stdDev));
                exposureDistribution.addData(exposureEffect);
                RandomWalkOperator exposureOperator = new RandomWalkOperator(exposureEffect, null, 0.75, RandomWalkOperator.BoundaryCondition.reflecting, cg.operatorWeight, cg.mode);
                allParameters.add(exposureEffect);
                allOperators.add(exposureOperator);
                allPriors.add(exposureDistribution);
            }
        }

        int designColumns = designMatrix.getColumnDimension();
        int effectDimension = allEffects.getDimension();
        if (designColumns != effectDimension) {
            throw new IllegalStateException("Invalid parameter dimensions: design matrix has " + designColumns
                    + " columns but the effect vector has dimension " + effectDimension);
        }

        // Links all betas to the effects through the design matrix (the casts keep the original
        // constructor resolution of this project type).
        SimpleLinearModel allEffectDistribution = new SimpleLinearModel("linearModel", (Parameter) allBetas, designMatrix, (Parameter) allEffects, (Parameter) tau);
        allPriors.add(allEffectDistribution);

        this.prior = new CompoundLikelihood(allPriors);
        this.likelihood = new CompoundLikelihood(allDataLikelihoods);
        this.joint = new CompoundLikelihood(Arrays.asList(this.likelihood, this.prior));
        this.joint.setId("joint");
        this.parameters = allParameters;
        this.schedule = new SimpleOperatorSchedule(1000, 0.0);
        this.schedule.addOperators(allOperators);
    }

    /**
     * Appends an effect parameter and its hierarchical prior components to the model collections,
     * preserving the order parameters, operators and priors were produced in.
     */
    private static void addEffectComponents(Parameter effect, HierarchicalNormalComponents components,
                                            List<Parameter> parameters, List<MCMCOperator> operators,
                                            List<Likelihood> priors) {
        parameters.add(effect);
        parameters.addAll(components.parameters);
        operators.addAll(components.operators);
        priors.addAll(components.likelihoods);
    }

    /**
     * Picks the hyper-parameter for exposure effect {@code index}: a single-element list is shared
     * by all effects, otherwise the list is indexed per effect.
     *
     * @throws IllegalArgumentException if the list is empty or has several values but fewer than
     *     {@code index + 1}
     */
    private static double selectHyperValue(List<Double> values, int index, String name) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException(name + " must contain at least one value");
        }
        if (values.size() == 1) {
            return values.get(0);
        }
        if (index >= values.size()) {
            throw new IllegalArgumentException(name + " has " + values.size()
                    + " values but exposure effect " + (index + 1) + " needs its own value");
        }
        return values.get(index);
    }

    /** Returns the data likelihood, the prior and then every model parameter, in that order. */
    @Override
    public List<Loggable> getLoggerColumns() {
        List<Loggable> columns = new ArrayList<>();
        columns.add((Loggable) this.likelihood);
        columns.add((Loggable) this.prior);
        columns.addAll(this.parameters);
        return columns;
    }

    @Override
    public Likelihood getJoint() {
        return this.joint;
    }

    @Override
    public OperatorSchedule getSchedule() {
        return this.schedule;
    }

    /**
     * Creates an unbounded parameter whose {@code dim} entries are drawn as
     * {@code center + scale * N(0, 1)} from {@link MathUtils#nextGaussian()}.
     *
     * @param name  parameter id
     * @param dim   number of entries
     * @param center mean of the starting values
     * @param scale  standard deviation of the starting values
     * @return the new parameter, with bounds (-inf, +inf) on every entry
     */
    public static Parameter randomize(String name, int dim, double center, double scale) {
        double[] effect = new double[dim];
        for (int i = 0; i < dim; ++i) {
            effect[i] = center + scale * MathUtils.nextGaussian();
        }
        Parameter.Default parameter = new Parameter.Default(name, effect);
        parameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, dim));
        return parameter;
    }

    /**
     * Builds a hierarchical normal prior {@code effects ~ Normal(mean, scale)} together with its
     * hyper-priors and operators.
     *
     * <p>The mean gets a {@code Normal(0, locationHyperStdDev)} prior; the scale parameter, its prior
     * and its operator come from {@code scalePrior} (which also decides whether the scale is
     * treated as a precision).
     *
     * @param name                prefix for the ids {@code name.mean} and {@code name.scale}
     * @param effects             the effects receiving the hierarchical prior
     * @param locationHyperStdDev standard deviation of the normal hyper-prior on the mean
     * @param scalePrior          source of the scale parameter, its prior and its operator
     * @param weight              base operator weight
     * @param mode                adaptation mode for the operators
     * @return parameters {@code [mean, scale]}; operators {@code [effects RW (weight × dim), mean RW,
     *     scale operator]}; likelihoods {@code [effects prior, mean prior, scale prior]}; and the
     *     effects gradient when the normal model provides one
     */
    public static HierarchicalNormalComponents makeHierarchicalNormalComponents(String name, Parameter effects, double locationHyperStdDev, ScalePrior scalePrior, double weight, AdaptationMode mode) {
        Parameter mean = randomize(name + ".mean", 1, 0.0, 1.0);
        Parameter scale = scalePrior.getParameter();
        scale.setId(name + ".scale");

        DistributionLikelihood distribution = new DistributionLikelihood((ParametricDistributionModel) new NormalDistributionModel(mean, scale, scalePrior.isPrecision()));
        distribution.addData(effects);
        DistributionLikelihood meanHyperDistribution = new DistributionLikelihood((Distribution) new NormalDistribution(0.0, locationHyperStdDev));
        meanHyperDistribution.addData(mean);

        List<Parameter> parameters = new ArrayList<>();
        parameters.add(mean);
        parameters.add(scale);

        List<MCMCOperator> operators = new ArrayList<>();
        operators.add(new RandomWalkOperator(effects, null, 0.75, RandomWalkOperator.BoundaryCondition.reflecting, weight * effects.getDimension(), mode));
        operators.add(new RandomWalkOperator(mean, null, 0.75, RandomWalkOperator.BoundaryCondition.reflecting, weight, mode));
        operators.add(scalePrior.getOperator(distribution, weight, mode));

        List<Likelihood> likelihoods = new ArrayList<>();
        likelihoods.add(distribution);
        likelihoods.add(meanHyperDistribution);
        likelihoods.add(scalePrior.getPrior());

        List<GradientWrtParameterProvider> gradients = new ArrayList<>();
        if (distribution.getDistribution() instanceof GradientProvider) {
            gradients.add(new GradientWrtParameterProvider.ParameterWrapper((GradientProvider) distribution.getDistribution(), effects, distribution));
        }
        return new HierarchicalNormalComponents(parameters, operators, likelihoods, gradients);
    }

    /**
     * Adds one primary-effect design column per data model. Column {@code k} has ones over the
     * betas of data model {@code k}. When {@code separateEffectPrior} is set, the trailing
     * {@code effectCount} models still get a column, but it is all zeros.
     *
     * @return the number of columns added (the number of data models)
     */
    private int addPrimaryDesign(DesignMatrix designMatrix, List<DataModel> dataModels, String effectName, boolean separateEffectPrior, int effectCount) {
        int totalLength = getTotalNumberOfProfiles(dataModels);
        int effectOffset = separateEffectPrior
                ? totalLength - getEffectDimensions(dataModels, effectCount)
                : totalLength;
        int offset = 0;
        int label = 0;
        for (DataModel dataModel : dataModels) {
            int length = dataModel.getCompoundParameter().getDimension();
            double[] effect = new double[totalLength];
            if (offset < effectOffset) {
                Arrays.fill(effect, offset, offset + length, 1.0);
            }
            ++label;
            designMatrix.addParameter(new Parameter.Default("dm." + effectName + label, effect));
            offset += length;
        }
        return label;
    }

    /**
     * Adds one secondary-effect design column per identifier {@code 1..maxIdentifier}. In column
     * {@code id}, each data model that lists {@code id} among its identifiers gets a one at the
     * beta whose position matches that identifier's index.
     * NOTE(review): assumes each model's identifier list is aligned with its betas — confirm.
     *
     * @return the number of columns added (the maximum identifier)
     */
    private int addSecondaryDesign(DesignMatrix designMatrix, List<DataModel> dataModels, String effectName, boolean separateEffectPrior, int effectCount) {
        int totalLength = getTotalNumberOfProfiles(dataModels);
        int maxIdentifier = getMaxIdentifier(dataModels);
        int effectOffset = separateEffectPrior
                ? totalLength - getEffectDimensions(dataModels, effectCount)
                : totalLength;
        for (int id = 1; id <= maxIdentifier; ++id) {
            double[] effect = new double[totalLength];
            int offset = 0;
            for (DataModel dataModel : dataModels) {
                if (offset >= effectOffset) {
                    // Offsets only grow, so every remaining model lies in the excluded tail.
                    break;
                }
                int whichIndex = findIdentifier(dataModel, id);
                if (whichIndex >= 0) {
                    effect[offset + whichIndex] = 1.0;
                }
                offset += dataModel.getCompoundParameter().getDimension();
            }
            designMatrix.addParameter(new Parameter.Default("dm." + effectName + id, effect));
        }
        return maxIdentifier;
    }

    /**
     * Adds one exposure-effect design column for each of the last {@code effectCount} data models,
     * with ones over that model's betas.
     *
     * @return {@code effectCount}
     * @throws IllegalArgumentException if {@code effectCount} is negative or exceeds the number of
     *     data models
     */
    public int addEffectDesign(DesignMatrix designMatrix, List<DataModel> dataModels, String effectName, int effectCount) {
        int totalLength = getTotalNumberOfProfiles(dataModels);
        int offset = totalLength - getEffectDimensions(dataModels, effectCount);
        int firstEffectModel = dataModels.size() - effectCount;
        for (int label = 1; label <= effectCount; ++label) {
            int length = dataModels.get(firstEffectModel + label - 1).getCompoundParameter().getDimension();
            double[] effect = new double[totalLength];
            Arrays.fill(effect, offset, offset + length, 1.0);
            designMatrix.addParameter(new Parameter.Default("dm." + effectName + label, effect));
            offset += length;
        }
        return effectCount;
    }

    /**
     * Sums the beta dimensions of the last {@code effectCount} data models.
     *
     * @throws IllegalArgumentException if {@code effectCount} is negative or exceeds the number of
     *     data models
     */
    private int getEffectDimensions(List<DataModel> dataModels, int effectCount) {
        if (effectCount < 0 || effectCount > dataModels.size()) {
            throw new IllegalArgumentException("effectCount must be between 0 and the number of data models ("
                    + dataModels.size() + "), but was " + effectCount);
        }
        int effectDimensions = 0;
        for (DataModel dataModel : dataModels.subList(dataModels.size() - effectCount, dataModels.size())) {
            effectDimensions += dataModel.getCompoundParameter().getDimension();
        }
        return effectDimensions;
    }

    /** Sums the beta dimensions over all data models (the design matrix row count). */
    private int getTotalNumberOfProfiles(List<DataModel> dataModels) {
        int length = 0;
        for (DataModel dataModel : dataModels) {
            length += dataModel.getCompoundParameter().getDimension();
        }
        return length;
    }

    /**
     * Returns the largest identifier across all data models. Models with no identifiers are
     * skipped.
     *
     * @throws IllegalArgumentException if no data model has any identifier
     */
    private int getMaxIdentifier(List<DataModel> dataModels) {
        boolean found = false;
        int max = Integer.MIN_VALUE;
        for (DataModel dataModel : dataModels) {
            for (int identifier : dataModel.getIdentifiers()) {
                max = Math.max(max, identifier);
                found = true;
            }
        }
        if (!found) {
            throw new IllegalArgumentException("No identifiers found in any data model; cannot build the secondary design");
        }
        return max;
    }

    /** Returns the position of {@code id} in the model's identifier list, or -1 if absent. */
    private int findIdentifier(DataModel dataModel, int id) {
        List<Integer> identifiers = dataModel.getIdentifiers();
        return identifiers.indexOf(id);
    }

    /**
     * Combines the gradients of every data model's likelihood with respect to its compound
     * parameter.
     *
     * @throws IllegalArgumentException if a data model's likelihood is not a {@link GradientProvider}
     */
    public static CompoundGradient makeDataModelCompoundGradient(List<DataModel> dataModels) {
        List<GradientWrtParameterProvider> gpp = new ArrayList<>();
        for (DataModel dm : dataModels) {
            Likelihood dataLikelihood = dm.getLikelihood();
            if (!(dataLikelihood instanceof GradientProvider)) {
                throw new IllegalArgumentException("Data model likelihood does not provide gradients: "
                        + (dataLikelihood == null ? "null" : dataLikelihood.getClass().getName()));
            }
            gpp.add(new GradientWrtParameterProvider.ParameterWrapper((GradientProvider) dataLikelihood, dm.getCompoundParameter(), dataLikelihood));
        }
        return new CompoundDerivative(gpp);
    }

    /** Example run on four example grid files, with two exposure-effect models. */
    public static void main(String[] args) {
        int chainLength = 1100000;
        int burnIn = 100000;
        int subSampleFrequency = 1000;
        List<DataModel> allDataModels = new ArrayList<>();
        allDataModels.add(new ExtendingEmpiricalDataModel("ForDavid/grids_example_1.csv"));
        allDataModels.add(new ExtendingEmpiricalDataModel("ForDavid/grids_example_2.csv"));
        allDataModels.add(new ExtendingEmpiricalDataModel("ForDavid/grids_example_3.csv"));
        allDataModels.add(new ExtendingEmpiricalDataModel("ForDavid/grids_example_4.csv"));
        HierarchicalMetaAnalysisConfiguration cg = new HierarchicalMetaAnalysisConfiguration();
        cg.effectCount = 2;
        // With two std. devs the exposure effects are indexed individually: [2.0 (default), 10.0].
        cg.exposureHyperStdDev.add(10.0);
        HierarchicalMetaAnalysis analysis = new HierarchicalMetaAnalysis(allDataModels, cg);
        Runner runner = new Runner(analysis, chainLength, burnIn, subSampleFrequency, cg.seed);
        runner.run();
        runner.processSamples();
    }

    /** Mutable configuration for {@link HierarchicalMetaAnalysis}; all fields have defaults. */
    public static class HierarchicalMetaAnalysisConfiguration {
        /** Std. dev. of the normal hyper-prior on the primary effect's mean. */
        public double hierarchicalLocationPrimaryHyperStdDev = 1.0;
        /** Std. dev. of the normal hyper-prior on the secondary effect's mean. */
        public double hierarchicalLocationSecondaryHyperStdDev = 1.0;
        /** First and second hyper-parameters of the GammaOnPrecisionPrior on the primary effect's scale. */
        public double gammaHyperPrimaryShape = 1.0;
        public double gammaHyperPrimaryScale = 1.0;
        /** First and second hyper-parameters of the GammaOnPrecisionPrior on the secondary effect's scale. */
        public double gammaHyperSecondaryShape = 1.0;
        public double gammaHyperSecondaryScale = 1.0;
        /** Number of trailing data models treated as exposure-effect models. */
        public int effectCount = 1;
        /** Normal-prior location per exposure effect; one value is shared, several are indexed per effect. */
        public List<Double> exposureHyperLocation = new ArrayList<Double>(Arrays.asList(0.0));
        /** Normal-prior std. dev. per exposure effect; one value is shared, several are indexed per effect. */
        public List<Double> exposureHyperStdDev = new ArrayList<Double>(Arrays.asList(2.0));
        /** Arguments passed to GammaDistribution for the prior on tau. */
        public double tauShape = 1.0;
        public double tauScale = 1.0;
        /** Starting value of tau. */
        public double startingTau = 1.0;
        /** Adaptation mode for every operator. */
        public AdaptationMode mode = AdaptationMode.ADAPTATION_ON;
        /** Base operator weight (multiplied by dimension for vector-valued random walks). */
        public double operatorWeight = 1.0;
        /** Seed for MathUtils and the runner. */
        public long seed = 666L;
        public String primaryEffectName = "outcome";
        public String secondaryEffectName = "source";
        public String exposureEffectName = "exposure";
        public boolean includeSecondary = true;
        public boolean includeExposure = true;
        /** When true, the trailing effectCount models get all-zero primary/secondary design entries. */
        public boolean separateEffectPrior = false;
    }

    /** Parameters, operators, prior terms and gradients making up one hierarchical normal effect. */
    static class HierarchicalNormalComponents {
        final List<Parameter> parameters;
        final List<MCMCOperator> operators;
        final List<Likelihood> likelihoods;
        final List<GradientWrtParameterProvider> gradients;

        HierarchicalNormalComponents(List<Parameter> parameters, List<MCMCOperator> operators, List<Likelihood> likelihoods, List<GradientWrtParameterProvider> gradients) {
            this.parameters = parameters;
            this.operators = operators;
            this.likelihoods = likelihoods;
            this.gradients = gradients;
        }
    }
}

