/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Federated data partitioning scheme that subsamples every federated
 * partition down to the row count of the smallest partition.
 *
 * <p>The features and labels matrices are sliced per federated worker, the
 * minimum number of feature rows over all slices is determined, and a UDF is
 * executed for each slice that subsamples features and labels whenever the
 * slice has more rows than that minimum. The data characteristics of the
 * sliced matrix objects are then updated to the new row count.
 */
public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
    /**
     * Partitions the federated features and labels so that every partition
     * ends up with the same number of rows as the smallest one.
     *
     * @param features federated features matrix
     * @param labels   federated labels matrix
     * @param seed     seed forwarded to the subsample matrix generation
     * @return the partitioned features and labels, the number of partitions,
     *         the balance metrics computed after the row counts were updated,
     *         and the weighting factors computed before subsampling
     * @throws DMLRuntimeException if the subsample UDF fails or reports an
     *                             unsuccessful response for any partition
     */
    @Override
    public DataPartitionFederatedScheme.Result partition(MatrixObject features, MatrixObject labels, int seed) {
        List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
        List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
        // Computed up front, i.e. from the balance metrics of the slices
        // before any row count has been changed.
        List<Double> weightingFactors = getWeightingFactors(pFeatures, getBalanceMetrics(pFeatures));

        // Smallest number of feature rows over all slices.
        int minRows = Integer.MAX_VALUE;
        for (MatrixObject pFeature : pFeatures) {
            long numRows = pFeature.getNumRows();
            if (numRows < (long) minRows) {
                minRows = Math.toIntExact(numRows);
            }
        }

        for (int i = 0; i < pFeatures.size(); ++i) {
            MatrixObject pFeature = pFeatures.get(i);
            MatrixObject pLabel = pLabels.get(i);
            // Only the first federated data entry of each slice is used.
            FederatedData featuresData = pFeature.getFedMapping().getFederatedData()[0];
            FederatedData labelsData = pLabel.getFedMapping().getFederatedData()[0];

            runSubsampleUdf(featuresData, labelsData, seed, minRows);

            // Reflect the new row count in the local metadata of both slices.
            DataCharacteristics update = pFeature.getDataCharacteristics().setRows(minRows);
            pFeature.updateDataCharacteristics(update);
            update = pLabel.getDataCharacteristics().setRows(minRows);
            pLabel.updateDataCharacteristics(update);
        }

        // Balance metrics are recomputed here, after the row counts were updated.
        return new DataPartitionFederatedScheme.Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weightingFactors);
    }

    /**
     * Executes the subsample UDF for one features/labels pair and waits for
     * its response.
     *
     * @param featuresData federated data handle of the features slice
     * @param labelsData   federated data handle of the labels slice
     * @param seed         seed forwarded to the UDF
     * @param minRows      target number of rows forwarded to the UDF
     * @throws DMLRuntimeException if waiting for the response fails or the
     *                             response is not successful
     */
    private static void runSubsampleUdf(FederatedData featuresData, FederatedData labelsData, int seed, int minRows) {
        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(
            new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, featuresData.getVarID(),
                new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, minRows)));

        FederatedResponse response;
        try {
            response = udfResponse.get();
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: executing subsample UDF failed: " + e.getMessage(), e);
        }
        catch (Exception e) {
            // Keep the original exception as cause so the stack trace is not lost.
            throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: executing subsample UDF failed: " + e.getMessage(), e);
        }

        // Checked outside the try block so this error is not re-wrapped by the
        // generic handler above.
        if (!response.isSuccessful()) {
            throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: subsample UDF returned fail");
        }
    }

    /**
     * UDF that subsamples a features matrix and its labels matrix to a given
     * number of rows, using the same subsample matrix for both.
     */
    private static class subsampleDataOnFederatedWorker extends FederatedUDF {
        private static final long serialVersionUID = 2213790859544004286L;
        private final int _seed;
        private final int _min_rows;

        /**
         * @param inIDs    input variable IDs passed to the superclass; the
         *                 caller supplies the features ID followed by the labels ID
         * @param seed     seed used when generating the subsample matrix
         * @param min_rows number of rows to subsample to
         */
        protected subsampleDataOnFederatedWorker(long[] inIDs, int seed, int min_rows) {
            super(inIDs);
            this._seed = seed;
            this._min_rows = min_rows;
        }

        /**
         * Subsamples features and labels if the features matrix has more rows
         * than the configured minimum; otherwise leaves both untouched.
         *
         * @param ec   execution context (not used by this method)
         * @param data input data objects; {@code data[0]} is cast to the
         *             features matrix and {@code data[1]} to the labels matrix
         * @return a response of type {@code SUCCESS}
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject features = (MatrixObject) data[0];
            MatrixObject labels = (MatrixObject) data[1];

            long numRows = features.getNumRows();
            if (numRows > (long) this._min_rows) {
                // One subsample matrix is generated and applied to both inputs.
                MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(this._min_rows, Math.toIntExact(numRows), this._seed);
                DataPartitionFederatedScheme.subsampleTo(features, subsampleMatrixBlock);
                DataPartitionFederatedScheme.subsampleTo(labels, subsampleMatrixBlock);
            }
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        /**
         * Returns no lineage item for this UDF.
         *
         * @param ec execution context (not used by this method)
         * @return always {@code null}
         */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

