/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.Iterator;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark partition function that applies a {@link BinaryOperator} between every tensor block of an
 * input partition and the matching block of a broadcast (right-hand-side) tensor.
 *
 * <p>For each input block with key {@code ix}, the broadcast block is looked up per dimension
 * {@code i}:
 * <ul>
 *   <li>if {@code replicateDim[i]} is {@code true}, block index {@code 1} is used;</li>
 *   <li>otherwise {@code ix.getIndex(i)} is used.</li>
 * </ul>
 * Because of this rule, a single broadcast block along a replicated dimension is reused for all
 * input blocks. The output blocks keep the input keys unchanged. Blocks are produced lazily through
 * a {@link LazyIterableIterator}.
 */
public class TensorTensorBinaryOpPartitionFunction
implements PairFlatMapFunction<Iterator<Tuple2<TensorIndexes, TensorBlock>>, TensorIndexes, TensorBlock> {
    private static final long serialVersionUID = 8029096658247920867L;

    /** Operator passed to {@code in1.binaryOperations(op, in2, ...)} for every block pair. */
    private final BinaryOperator _op;
    /** Broadcast right-hand-side tensor, accessed block-wise. */
    private final PartitionedBroadcast<TensorBlock> _ptV;
    /** Per-dimension flags; {@code true} pins the broadcast block index of that dimension to 1. */
    private final boolean[] _replicateDim;

    /**
     * Creates the partition function.
     *
     * @param op the binary operator to apply
     * @param binput the broadcast right-hand-side tensor
     * @param replicateDim per-dimension replication flags; must have at least as many entries as
     *     {@code binput}'s data characteristics report dimensions, otherwise block lookup fails
     *     with an {@link ArrayIndexOutOfBoundsException}
     */
    public TensorTensorBinaryOpPartitionFunction(BinaryOperator op, PartitionedBroadcast<TensorBlock> binput, boolean[] replicateDim) {
        this._op = op;
        this._ptV = binput;
        this._replicateDim = replicateDim;
    }

    /**
     * Wraps the partition iterator so that each block is combined with its broadcast counterpart
     * lazily, on demand.
     *
     * @param arg0 iterator over the (index, block) pairs of one partition
     * @return a lazy iterator over the resulting (index, block) pairs
     */
    @Override
    public LazyIterableIterator<Tuple2<TensorIndexes, TensorBlock>> call(Iterator<Tuple2<TensorIndexes, TensorBlock>> arg0) throws Exception {
        return new MapBinaryPartitionIterator(arg0);
    }

    /** Lazy iterator that applies {@link #_op} to each incoming block and its broadcast block. */
    private class MapBinaryPartitionIterator
    extends LazyIterableIterator<Tuple2<TensorIndexes, TensorBlock>> {
        public MapBinaryPartitionIterator(Iterator<Tuple2<TensorIndexes, TensorBlock>> in) {
            super(in);
        }

        @Override
        protected Tuple2<TensorIndexes, TensorBlock> computeNext(Tuple2<TensorIndexes, TensorBlock> arg) {
            // unpack the partition key-value pair
            TensorIndexes ix = arg._1();
            TensorBlock in1 = arg._2();

            // build the lookup index of the rhs block; replicated dimensions always use block 1
            // (NOTE(review): presumably broadcast block indices are 1-based — confirm)
            DataCharacteristics dc = _ptV.getDataCharacteristics();
            int[] index = new int[dc.getNumDims()];
            for (int i = 0; i < index.length; ++i) {
                index[i] = _replicateDim[i] ? 1 : (int) ix.getIndex(i);
            }
            TensorBlock in2 = _ptV.getBlock(index);

            // execute the binary operation into a fresh output block, keeping the input key
            TensorBlock ret = in1.binaryOperations(_op, in2, new TensorBlock());
            return new Tuple2<>(ix, ret);
        }
    }
}

