From a57164fbbb07fcfb9846ce1fda501c46c37dbd0e Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 29 May 2026 14:06:10 +0000 Subject: [PATCH 1/6] Handle hash columns in transform decoders and tighten decode metadata flow Reworks transform decoders so that hash-encoded columns survive the inverse-transform path, and tightens how decoder metadata (column indices, value mappings) is propagated and initialized. - Decoder: pass column-id arrays through decode/decodeFromMap so each decoder knows its own output column range - DecoderRecode: skip recode for hash columns, keep encoded ints passthrough; init metadata from frame consistently - DecoderDummycode: handle hash columns when expanding categorical bits; parallel decode path; sparse-friendly init - DecoderPassThrough / DecoderBin / DecoderComposite / DecoderFactory: consume the new column-id arrays from the dispatch layer - ColumnEncoderFeatureHash: align hash bookkeeping with the decode-side changes - Frame columns (HashMapToInt, StringArray): small support changes consumed by the decoder path above --- .../frame/data/columns/StringArray.java | 4 +- .../runtime/transform/decode/Decoder.java | 32 ++++- .../runtime/transform/decode/DecoderBin.java | 68 ++++++++-- .../transform/decode/DecoderComposite.java | 32 +---- .../transform/decode/DecoderDummycode.java | 117 ++++++++++++------ .../transform/decode/DecoderFactory.java | 28 ++++- .../transform/decode/DecoderPassThrough.java | 23 ++-- .../transform/decode/DecoderRecode.java | 49 ++++---- .../encode/ColumnEncoderFeatureHash.java | 6 +- 9 files changed, 240 insertions(+), 119 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..292fcb52bf5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,7 +607,6 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { @@ -617,7 +616,8 @@ private static double getAsDouble(String s) { else if(ls.equals("false") || ls.equals("f")) return 0; else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw e; // for efficiency + // throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..c9fcc23990a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,15 +44,18 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; - public DecoderBin() { - super(null, null); - } + // public DecoderBin() { + // super(null, null); + // } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..dff85e72dc6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -50,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - public DecoderComposite() { super(null, null); } + // public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..debce027680 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,98 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - _clPos[j] + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +157,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..c2de3ec1df3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - public DecoderPassThrough() { super(null, null); } + // public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..1cf0b7c4b3f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,12 +47,11 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; - public DecoderRecode() { - super(null, null); - } + // public DecoderRecode() { + // super(null, null); + // } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -129,27 +128,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override From 89e7bdce82dc3cc7ed137eda60390af35ba4198f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 8 Jun 2026 14:57:57 +0000 Subject: [PATCH 2/6] Fix sparse dummycode decode and restore decoder no-arg constructors Fix two regressions in the transform decode rewrite that broke encode/decode roundtrips on dummycoded/recoded frames: - DecoderDummycode.decodeSparse compared 0-based sparse column indexes against the 1-based _clPos/_cuPos bounds used by the dense path (in.get(i, k-1)). This shifted every lookup by one column, so the first category was never matched (decoded as null) and all others decoded one code too low. Shift the sparse bounds and index to be 0-based, matching the dense path. - Restore the public no-arg constructors on DecoderComposite, DecoderBin, DecoderPassThrough, and DecoderRecode. Decoder is Externalizable, and Spark broadcasts the top-level decoder via Java serialization, which requires a public no-arg constructor; without it deserialization fails with InvalidClassException on executors. Restores passing of TransformFrameEncodeColmapTest, TransformFrameEncodeDecodeTest, TransformCSVFrameEncodeDecodeTest, and TransformFrameEncodeDecodeTokenTest in single-node and Spark modes. --- .../sysds/runtime/transform/decode/DecoderBin.java | 6 +++--- .../runtime/transform/decode/DecoderComposite.java | 2 +- .../runtime/transform/decode/DecoderDummycode.java | 10 ++++++---- .../runtime/transform/decode/DecoderPassThrough.java | 2 +- .../sysds/runtime/transform/decode/DecoderRecode.java | 6 +++--- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index c9fcc23990a..79d9b7f3a40 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -49,9 +49,9 @@ public class DecoderBin extends Decoder { private double[][] _binMins = null; private double[][] _binMaxs = null; - // public DecoderBin() { - // super(null, null); - // } + public DecoderBin() { + super(null, null); + } protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index dff85e72dc6..adfef7bbc6d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -47,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - // public DecoderComposite() { super(null, null); } + public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index debce027680..95d7f4fa4c9 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -91,9 +91,11 @@ private void decodeSparseRow(FrameBlock out, final SparseBlock sb, int i) { final int[] aix = sb.indexes(i); for(int j = 0; j < _colList.length; j++) { // for each decode column. - // find k, the index in aix, within the range of low and high - final int low = _clPos[j]; - final int high = _cuPos[j]; + // find k, the index in aix, within the range of low and high. + // _clPos/_cuPos are 1-based matrix positions (the dense path reads + // in.get(i, k-1)); the sparse indexes in aix are 0-based, so shift. + final int low = _clPos[j] - 1; + final int high = _cuPos[j] - 1; int h = Arrays.binarySearch(aix, apos, alen, low); // start h at column. if(h < 0) // search gt col index (see binary search) h = Math.abs(h + 1); @@ -101,7 +103,7 @@ private void decodeSparseRow(FrameBlock out, final SparseBlock sb, int i) { if(h < alen && aix[h] >= low && aix[h] < high) { int k = aix[h]; int col = _colList[j] - 1; - out.getColumn(col).set(i, k - _clPos[j] + 1); + out.getColumn(col).set(i, k - low + 1); } // limit the binary search. apos = h; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index c2de3ec1df3..d2e7d59e81f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - // public DecoderPassThrough() { super(null, null); } + public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 1cf0b7c4b3f..a48759493fa 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -49,9 +49,9 @@ public class DecoderRecode extends Decoder private HashMap[] _rcMaps = null; private boolean _onOut = false; - // public DecoderRecode() { - // super(null, null); - // } + public DecoderRecode() { + super(null, null); + } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); From 2e658369ef519bd2c57655e0bbf68a9760ccc7fd Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 9 Jun 2026 13:55:49 +0000 Subject: [PATCH 3/6] Add component tests for transform decoder hash and metadata handling Cover the decoder paths touched by the hash-column and decode-metadata changes: parallel block decode equals serial decode, the sparse and dense dummycode decode paths agree, feature-hash columns decode through dummycode via the magic domain-size metadata, and bin columns whose source position is shifted by dummycoding of another column. Add exact inverse round-trip checks for recode and dummycode to validate the sparse binary-search decode against ground truth. --- .../TransformDecodeRoundTripTest.java | 167 ++++++++++++++++ .../frame/transform/TransformDecodeTest.java | 186 ++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java new file mode 100644 index 00000000000..9e8b55df29b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +/** + * Exact inverse correctness tests for the transform decoders. Recode and dummycode are lossless category encodings, so a + * decode of the encoded matrix must reconstruct the original categorical frame. These tests assert exact reconstruction + * for the dense path, the sparse path, and the parallel path so that the dummycode sparse binary search and the parallel + * block split are validated against ground truth rather than only against each other. + */ +public class TransformDecodeRoundTripTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeRoundTripTest.class.getName()); + + @Before + public void setUp() { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + } + + private static FrameBlock categoricalFrame() { + final String[] values = new String[] { + "apple", "banana", "apple", "cherry", "banana", "date", "apple", "cherry", "date", "banana", "elderberry", + "apple", "fig", "banana", "cherry", "apple", "date", "fig", "elderberry", "banana"}; + final FrameBlock f = new FrameBlock(new ValueType[] {ValueType.STRING}); + f.ensureAllocatedColumns(values.length); + for(int i = 0; i < values.length; i++) + f.set(i, 0, values[i]); + return f; + } + + @Test + public void recodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1]}", false, 1); + } + + @Test + public void recodeReconstructsOriginalSparse() { + roundTrip("{ids:true, recode:[1]}", true, 1); + } + + @Test + public void recodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1]}", false, 4); + } + + @Test + public void dummycodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 1); + } + + @Test + public void dummycodeReconstructsOriginalSparse() { + // the one-hot encoded matrix is sparse, so this drives the dummycode sparse binary-search decode path + roundTrip("{ids:true, recode:[1], dummycode:[1]}", true, 1); + } + + @Test + public void dummycodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 4); + } + + /** + * Binning a column while a different column is dummycoded shifts the bin column's source position in the encoded + * matrix. The bin decoder must rebuild that source-column mapping from the dummycode domain sizes. This asserts the + * dense, sparse, and parallel decode paths agree for that layout (bin output is lossy, so exact reconstruction is + * not asserted, only cross-mode consistency and dimensions). + */ + @Test + public void binWithDummycodeOnOtherColumnConsistency() { + final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = new MatrixBlock(); + dense.copy(encoded); + if(dense.isInSparseFormat()) + dense.sparseToDense(); + + final MatrixBlock sparse = new MatrixBlock(); + sparse.copy(encoded); + if(!sparse.isInSparseFormat()) + sparse.denseToSparse(); + + final FrameBlock reference = decodeOnce(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decodeOnce(spec, colnames, meta, dense, 4); + final FrameBlock fromSparse = decodeOnce(spec, colnames, meta, sparse, 1); + + org.junit.Assert.assertEquals(original.getNumRows(), reference.getNumRows()); + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private void roundTrip(String spec, boolean sparse, int k) { + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + if(sparse && !encoded.isInSparseFormat()) + encoded.denseToSparse(); + else if(!sparse && encoded.isInSparseFormat()) + encoded.sparseToDense(); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), k); + + TestUtils.compareFrames(original, decoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " (sparse=" + sparse + ", k=" + k + ") : " + e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java new file mode 100644 index 00000000000..254937c20da --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Component tests for the transform decoders. These exercise the row-block and parallel decode paths, the sparse and + * dense dummycode decode paths, the binning source-column offset mapping, and feature-hash column handling end-to-end + * through an encode followed by decode round trip. + */ +@RunWith(value = Parameterized.class) +public class TransformDecodeTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeTest.class.getName()); + + private final FrameBlock data; + private final int k; + + public TransformDecodeTest(FrameBlock data, int k) { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + this.data = data; + this.k = k; + } + + @Parameters + public static Collection data() { + final ArrayList tests = new ArrayList<>(); + final int[] threads = new int[] {1, 4}; + try { + final FrameBlock[] blocks = new FrameBlock[] { + // single low-cardinality categorical column + TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231), + // single categorical column with nulls + TestUtils.generateRandomFrameBlock(64, new ValueType[] {ValueType.UINT4}, 99, 0.2), + // multi column: dummycode/bin on col1 must offset the trailing passthrough columns + TestUtils.generateRandomFrameBlock(120, + new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.FP32}, 17), + // large enough to split into multiple row blocks in the parallel decode path + TestUtils.generateRandomFrameBlock(2500, new ValueType[] {ValueType.UINT4}, 7)}; + + for(FrameBlock block : blocks) + for(int k : threads) + tests.add(new Object[] {block, k}); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + return tests; + } + + @Test + public void testPassThrough() { + decodeConsistency("{ids:true}"); + } + + @Test + public void testRecode() { + decodeConsistency("{ids:true, recode:[1]}"); + } + + @Test + public void testDummycode() { + decodeConsistency("{ids:true, recode:[1], dummycode:[1]}"); + } + + @Test + public void testBinWidth() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"); + } + + @Test + public void testBinHeight() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}"); + } + + @Test + public void testBinSingleBin() { + // numbins:1 forces the key==0 branch in the bin decoder + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:1}]}"); + } + + @Test + public void testHashToDummy() { + // feature-hash columns carry their domain size as the magic "¿K" metadata value, which the dummycode decoder + // must parse to reconstruct the one-hot column ranges + decodeConsistency("{ids:true, hash:[1], K:8, dummycode:[1]}"); + } + + @Test + public void testHashToDummyDomain1() { + decodeConsistency("{ids:true, hash:[1], K:1, dummycode:[1]}"); + } + + /** + * Encode the data, then decode the encoded matrix in three ways: serial dense, parallel dense, and serial sparse. + * All three must produce identical frames. This jointly exercises the parallel block-decode path in + * {@link Decoder#decode(MatrixBlock, FrameBlock, int)} and the separate sparse / dense dummycode decode paths. + */ + private void decodeConsistency(String spec) { + try { + final String[] colnames = data.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, data.getNumColumns(), null); + final MatrixBlock encoded = encoder.encode(data, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = forceDense(encoded); + final MatrixBlock sparse = forceSparse(encoded); + + final FrameBlock reference = decode(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decode(spec, colnames, meta, dense, k); + final FrameBlock fromSparse = decode(spec, colnames, meta, sparse, 1); + + assertEquals("decoded rows must match input rows", data.getNumRows(), reference.getNumRows()); + + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decode(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private static MatrixBlock forceDense(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(out.isInSparseFormat()) + out.sparseToDense(); + return out; + } + + private static MatrixBlock forceSparse(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(!out.isInSparseFormat()) + out.denseToSparse(); + return out; + } +} From 7e357a59d2539ed5a2d099d89c2c906529d37f4c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 9 Jun 2026 22:15:23 +0000 Subject: [PATCH 4/6] Speed up boolean-token fallback in StringArray.getAsDouble Replace the toLowerCase plus equals chain in the parse fallback with a length-based dispatch: a single char compare for the 1-char "t"/"f" tokens and compareToIgnoreCase for "true"/"false", matching the idiom already used in DoubleArray.parseDouble. This avoids allocating a lower-cased copy and rejects non-boolean strings immediately. Restore throwing DMLRuntimeException on unparseable input. The previous re-throw of the raw NumberFormatException changed the exception type and broke callers such as Array.extractDouble that expect DMLRuntimeException; the throw path is the genuinely-exceptional case, so the wrapping cost is irrelevant there. --- .../frame/data/columns/StringArray.java | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 292fcb52bf5..8156a98cd35 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -610,14 +610,22 @@ private static double getAsDouble(String s) { return DoubleArray.parseDouble(s); } catch(Exception e) { - String ls = s.toLowerCase(); - if(ls.equals("true") || ls.equals("t")) + // Fallback for boolean-like tokens. Dispatch on length first so non-boolean strings are + // rejected immediately, and avoid allocating a lower-cased copy by comparing case-insensitively + // (single char compare for the 1-char tokens). + final int len = s.length(); + if(len == 1) { + final char c = s.charAt(0); + if(c == 't' || c == 'T') + return 1; + else if(c == 'f' || c == 'F') + return 0; + } + else if(len == 4 && s.compareToIgnoreCase("true") == 0) return 1; - else if(ls.equals("false") || ls.equals("f")) + else if(len == 5 && s.compareToIgnoreCase("false") == 0) return 0; - else - throw e; // for efficiency - // throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw new DMLRuntimeException("Unable to change to double: " + s, e); } } From d75166dd78a50a1cb3c870b83d3d44f720390220 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 16 Jun 2026 22:14:02 +0000 Subject: [PATCH 5/6] Fix DecoderBin serialization and remove dead _numBins field DecoderBin gained derived decode state (_srcCols/_dcCols) that was not captured by writeExternal/readExternal, so a decoder broadcast to Spark executors (which do not re-run initMetaData) would NPE on every binned transformdecode. Serialize _srcCols/_dcCols alongside the bin boundaries, mirroring DecoderPassThrough. Register DecoderBin in DecoderFactory.getDecoderType/createInstance so a composite decoder containing a bin decoder can be serialized at all; previously it threw "Unsupported decoder type" before reaching DecoderBin.writeExternal. Remove the _numBins field: it was allocated but never populated, so writeExternal wrote a zero length for every column and dropped all bin boundaries on deserialization. The per-column bin count is recovered from the boundary array length instead, and readExternal now allocates the boundary arrays it reads into. Add serialization round-trip tests (plain bin and bin-with-dummycode) plus tests for the bin source-column offset mapping, the key==0 bin branch, and the StringArray boolean-token getAsDouble fallback. --- .../runtime/transform/decode/DecoderBin.java | 25 +++- .../transform/decode/DecoderFactory.java | 3 + .../frame/array/CustomArrayTests.java | 25 ++++ .../TransformDecodeRoundTripTest.java | 132 +++++++++++++++++- .../frame/transform/TransformDecodeTest.java | 2 +- 5 files changed, 178 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index 79d9b7f3a40..54d1b86a2c5 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -43,7 +43,6 @@ public class DecoderBin extends Decoder { private static final long serialVersionUID = -3784249774608228805L; // a) column bin boundaries - private int[] _numBins; private int[] _dcCols = null; private int[] _srcCols = null; private double[][] _binMins = null; @@ -108,7 +107,6 @@ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { @Override public void initMetaData(FrameBlock meta) { //initialize bin boundaries - _numBins = new int[_colList.length]; _binMins = new double[_colList.length][]; _binMaxs = new double[_colList.length][]; @@ -162,29 +160,46 @@ public void initMetaData(FrameBlock meta) { @Override public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); + // bin boundaries; the per-column bin count is the length of the boundary arrays for( int i=0; i<_colList.length; i++ ) { - int len = _numBins[i]; + int len = _binMins[i].length; out.writeInt(len); for(int j=0; j t = ArrayFactory.create(truthy); + for(int i = 0; i < t.size(); i++) + assertEquals(1.0, t.getAsDouble(i), 0.0); + Array f = ArrayFactory.create(falsy); + for(int i = 0; i < f.size(); i++) + assertEquals(0.0, f.getAsDouble(i), 0.0); + } + + @Test(expected = DMLRuntimeException.class) + public void stringArrayGetDoubleInvalidThrows() { + // a token that is neither numeric nor a boolean word/char must throw + ArrayFactory.create(new String[] {"notabool"}).getAsDouble(0); + } + + @Test(expected = DMLRuntimeException.class) + public void stringArrayGetDoubleAmbiguousLengthThrows() { + // length matches neither 1, 4, nor 5 boolean tokens -> reject + ArrayFactory.create(new String[] {"tru"}).getAsDouble(0); + } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java index 9e8b55df29b..5156b78e2e9 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -21,6 +21,11 @@ import static org.junit.Assert.fail; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; @@ -99,10 +104,44 @@ public void dummycodeReconstructsOriginalParallel() { */ @Test public void binWithDummycodeOnOtherColumnConsistency() { - final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}"; + // bin column (1) precedes the dummycode column (2): the bin decoder takes the direct + // source-column path because no expanded column sits before it + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); + binConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}", original); + } + + /** + * Dummycode on an earlier column (1) shifts the bin column (2) to the right in the encoded matrix. The bin decoder + * must walk the dummycode domain sizes to recover the bin column's true source position. This drives the + * non-magic offset branch of the bin source-column mapping. + */ + @Test + public void binAfterDummycodeOnEarlierColumnConsistency() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.UINT4, ValueType.FP32, ValueType.UINT8}, 4242); + binConsistency("{ids:true, recode:[1], dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", original); + } + + /** + * Same right-shift as above, but the earlier column is feature-hashed before being dummycoded. The hash domain + * size is stored as the magic "¿K" metadata value, so the bin source-column mapping must take the magic-value + * branch to compute the offset. + */ + @Test + public void binAfterHashDummycodeOnEarlierColumnConsistency() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.UINT4, ValueType.FP32, ValueType.UINT8}, 4242); + binConsistency("{ids:true, hash:[1], K:6, dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", + original); + } + + /** + * Encode then decode the dense, parallel and sparse paths and assert they agree. Bin output is lossy, so only + * cross-mode consistency and row count are asserted (not exact reconstruction). + */ + private void binConsistency(String spec, FrameBlock original) { try { - final FrameBlock original = TestUtils.generateRandomFrameBlock(150, - new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); final String[] colnames = original.getColumnNames(); final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), @@ -134,6 +173,93 @@ public void binWithDummycodeOnOtherColumnConsistency() { } } + /** + * The bin encoder always emits codes >= 1, but the decoder defensively handles a 0 code by mapping it to the + * first bin's lower boundary. Inject a 0 into an otherwise validly encoded matrix to exercise that branch. + */ + @Test + public void binDecodeZeroCodeUsesFirstBinBoundary() { + final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(50, new ValueType[] {ValueType.FP32}, 13); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + encoded.set(0, 0, 0); // force a 0 bin code + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + final double first = Double.parseDouble(decoded.get(0, 0).toString()); + final double second = Double.parseDouble(decoded.get(1, 0).toString()); + // the 0-coded row decodes to the first bin lower bound, which is <= any properly binned center + org.junit.Assert.assertTrue("0-code must map to the lowest bin boundary", first <= second); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * Spark broadcasts the decoder to executors via Java serialization without re-running initMetaData, so the + * decoder must round-trip all of its decode state through writeExternal/readExternal. Decode with a freshly + * deserialized decoder and assert it matches the in-memory decode. Covers plain bin and bin-with-dummycode + * (the latter exercises the serialized _srcCols/_dcCols source-column mapping). + */ + @Test + public void binDecoderSurvivesSerialization() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(80, new ValueType[] {ValueType.FP32}, 21); + serializeRoundTrip("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}", original); + } + + @Test + public void binWithDummycodeDecoderSurvivesSerialization() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(80, + new ValueType[] {ValueType.UINT4, ValueType.FP32}, 21); + serializeRoundTrip("{ids:true, recode:[1], dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", + original); + } + + private void serializeRoundTrip(String spec, FrameBlock original) { + try { + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock expected = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + final Decoder restored = serializeDeserialize(decoder); + final FrameBlock actual = restored.decode(encoded, new FrameBlock(restored.getSchema()), 1); + + TestUtils.compareFrames(expected, actual, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static Decoder serializeDeserialize(Decoder decoder) throws Exception { + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try(ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject(decoder); + } + try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) { + return (Decoder) ois.readObject(); + } + } + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java index 254937c20da..1c29e4f6a77 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java @@ -117,7 +117,7 @@ public void testBinHeight() { @Test public void testBinSingleBin() { - // numbins:1 forces the key==0 branch in the bin decoder + // numbins:1 collapses every value into a single bin, exercising the degenerate boundary handling decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:1}]}"); } From f82580098c60ab132b2e38a10f55b555d40d6417 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 17 Jun 2026 13:14:32 +0000 Subject: [PATCH 6/6] Recover hash domain size from meta cell instead of numDistinct For feature-hashed columns the hash domain size K is stored in the single transform meta cell, not materialized as numDistinct rows like recode/bin. Stop overloading numDistinct=K on the hash meta column and instead pass the set of dummycoded hash columns from DecoderFactory to the dummycode/bin/ passthrough decoders, which read K from the cell when sizing the dummycode expansion. This keeps numDistinct semantically the meta column's own cardinality and avoids any sentinel in the cell value. Also trim verbose comments introduced in the transform decoders and remove the dead commented-out _rcMapsDirect block (and its unused max accumulator) in DecoderRecode. --- .../frame/data/columns/StringArray.java | 4 +-- .../runtime/transform/decode/Decoder.java | 30 ++++++++++++++++ .../runtime/transform/decode/DecoderBin.java | 19 +++++----- .../transform/decode/DecoderDummycode.java | 24 +++++-------- .../transform/decode/DecoderFactory.java | 17 ++++++--- .../transform/decode/DecoderPassThrough.java | 17 ++++----- .../transform/decode/DecoderRecode.java | 13 ------- .../encode/ColumnEncoderFeatureHash.java | 8 ++--- .../TransformDecodeRoundTripTest.java | 35 +++++++++++++++++++ 9 files changed, 105 insertions(+), 62 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 8156a98cd35..1541f16c96d 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -610,9 +610,7 @@ private static double getAsDouble(String s) { return DoubleArray.parseDouble(s); } catch(Exception e) { - // Fallback for boolean-like tokens. Dispatch on length first so non-boolean strings are - // rejected immediately, and avoid allocating a lower-cased copy by comparing case-insensitively - // (single char compare for the 1-char tokens). + // fallback for boolean-like tokens, without allocating a lower-cased copy final int len = s.length(); if(len == 1) { final char c = s.charAt(0); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 70834675ded..e8e277cf1fb 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -28,13 +28,16 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.UtilFunctions; /** * Base class for all transform decoders providing both a row and block @@ -48,11 +51,38 @@ public abstract class Decoder implements Externalizable{ protected ValueType[] _schema; protected int[] _colList; protected String[] _colnames = null; + // dummycoded columns that were feature-hashed: domain size K is read from the meta cell, not + // numDistinct. Only used during initMetaData (driver side), so not serialized. + protected transient int[] _dcHashCols = null; + protected Decoder(ValueType[] schema, int[] colList) { _schema = schema; _colList = colList; } + protected boolean isHashCol(int colID) { + return ArrayUtils.contains(_dcHashCols, colID); + } + + /** + * Domain size of a dummycoded source column: the hash domain K from the meta cell for + * feature-hashed columns, otherwise the column's {@code numDistinct} (0 when unset). + * + * @param meta transform meta frame + * @param colID 1-based column id of the dummycoded source column + * @param isHash whether the column was feature-hashed + * @return the domain size, never negative + */ + protected static int getNumDummycodeDistinct(FrameBlock meta, int colID, boolean isHash) { + if(isHash) { + Object o = meta.get(0, colID - 1); + return (o == null) ? 0 : (int) UtilFunctions.parseToLong(o.toString()); + } + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + int ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + return Math.max(ndist, 0); + } + public ValueType[] getSchema() { return _schema; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index 54d1b86a2c5..cd127d64945 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,7 +28,6 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; -import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -53,8 +52,13 @@ public DecoderBin() { } protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { + this(schema, binCols, dcCols, null); + } + + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols, int[] hashCols) { super(schema, binCols); _dcCols = dcCols; + _dcHashCols = hashCols; } @Override @@ -139,14 +143,8 @@ public void initMetaData(FrameBlock meta) { ix1 ++; } else { //_colList[ix1] > _dcCols[ix2] - ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - String v = meta.getString(0, _dcCols[ix2]-1); - if(v.length() > 1 && v.charAt(0) == '¿'){ - off += UtilFunctions.parseToLong(v.substring(1)) -1; - } - else { - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; - } + int dcCol = _dcCols[ix2]; + off += getNumDummycodeDistinct(meta, dcCol, isHashCol(dcCol)) - 1; ix2 ++; } } @@ -169,8 +167,7 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeDouble(_binMaxs[i][j]); } } - // source-column mapping derived from dummycode/hash domain sizes (rebuilt in initMetaData, - // but persisted here because Spark broadcasts the decoder without re-running initMetaData) + // source-column mapping (rebuilt in initMetaData, but persisted for Spark broadcast) out.writeInt(_srcCols.length); for(int i = 0; i < _srcCols.length; i++) out.writeInt(_srcCols[i]); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 95d7f4fa4c9..95400d9944e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -29,9 +29,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.UtilFunctions; /** * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given @@ -45,8 +43,13 @@ public class DecoderDummycode extends Decoder { private int[] _cuPos = null; protected DecoderDummycode(ValueType[] schema, int[] dcCols) { + this(schema, dcCols, null); + } + + protected DecoderDummycode(ValueType[] schema, int[] dcCols, int[] hashCols) { // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); + _dcHashCols = hashCols; } @Override @@ -91,9 +94,7 @@ private void decodeSparseRow(FrameBlock out, final SparseBlock sb, int i) { final int[] aix = sb.indexes(i); for(int j = 0; j < _colList.length; j++) { // for each decode column. - // find k, the index in aix, within the range of low and high. - // _clPos/_cuPos are 1-based matrix positions (the dense path reads - // in.get(i, k-1)); the sparse indexes in aix are 0-based, so shift. + // find the set bit in [low, high); _clPos/_cuPos are 1-based, aix is 0-based final int low = _clPos[j] - 1; final int high = _cuPos[j] - 1; int h = Arrays.binarySearch(aix, apos, alen, low); // start h at column. @@ -166,17 +167,8 @@ public void initMetaData(FrameBlock meta) { _cuPos = new int[_colList.length]; // col upper pos for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; - String v = meta.getString(0, colID - 1); - int ndist; - if(v.length() > 1 && v.charAt(0) == '¿') { - ndist = UtilFunctions.parseToInt(v.substring(1)); - } - else { - ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); - } - - ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + // hash columns store the domain size K in the meta cell; others use numDistinct + int ndist = getNumDummycodeDistinct(meta, colID, isHashCol(colID)); _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 1fc97385479..04ffca7c3d0 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -78,6 +78,12 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] // remove hash recoded. // todo potentially wrong and remove? rcIDs = except(rcIDs, hcIDs); + // dummycoded hash columns: domain size K lives in the meta cell, so the decoders + // need to know which dummycoded columns to read it from + List hcdcIDs = new ArrayList<>(dcIDs); + hcdcIDs.retainAll(hcIDs); + int[] hashCols = ArrayUtils.toPrimitive(hcdcIDs.toArray(new Integer[0])); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); // set the remaining columns to passthrough. @@ -86,8 +92,9 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] ptIDs = except(ptIDs, rcIDs); // binned columns ptIDs = except(ptIDs, binIDs); - // hashed columns - ptIDs = except(ptIDs, hcIDs); // remove hashed columns + // dummycoded columns (incl. dummycoded hash) are rebuilt by the dummycode decoder; + // hash columns without dummycode stay in passthrough so their bucket code survives + ptIDs = except(ptIDs, dcIDs); //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { @@ -102,11 +109,11 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } if( !rcIDs.isEmpty() ) { // todo figure out if we need to handle rc columns with regards to dictionary offsets. @@ -116,7 +123,7 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] if( !ptIDs.isEmpty() ) { ldecoders.add(new DecoderPassThrough(schema, ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } //create composite decoder of all created decoders diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index d2e7d59e81f..bf86e392ef2 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -28,9 +28,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.UtilFunctions; /** * Simple atomic decoder for passing through numeric columns to the output. @@ -45,8 +43,13 @@ public class DecoderPassThrough extends Decoder private int[] _srcCols = null; protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { + this(schema, ptCols, dcCols, null); + } + + protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols, int[] hashCols) { super(schema, ptCols); _dcCols = dcCols; + _dcHashCols = hashCols; } public DecoderPassThrough() { super(null, null); } @@ -112,14 +115,8 @@ public void initMetaData(FrameBlock meta) { ix1 ++; } else { //_colList[ix1] > _dcCols[ix2] - ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - String v = meta.getString( 0,_dcCols[ix2]-1); - if(v.length() > 1 && v.charAt(0) == '¿'){ - off += UtilFunctions.parseToLong(v.substring(1)) -1; - } - else { - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; - } + int dcCol = _dcCols[ix2]; + off += getNumDummycodeDistinct(meta, dcCol, isHashCol(dcCol)) - 1; ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index a48759493fa..a73631a0abf 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -124,7 +124,6 @@ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { public void initMetaData(FrameBlock meta) { //initialize recode maps according to schema _rcMaps = new HashMap[_colList.length]; - long[] max = new long[_colList.length]; for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - // _rcMapsDirect = new Object[_rcMaps.length][]; - // for( int i=0; i<_rcMaps.length; i++ ) { - // Object[] arr = new Object[(int)max[i]]; - // for(Entry e1 : _rcMaps[i].entrySet()) - // arr[e1.getKey().intValue()-1] = e1.getValue(); - // _rcMapsDirect[i] = arr; - // } - // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 361c9c52135..cd9a583d60f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,9 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - // set metadata of hash columns to magical hash value + k - meta.set(0, _colID - 1, String.format("¿%d" , _K)); - + // store the hash domain size K in the single meta cell + meta.set(0, _colID - 1, String.valueOf(_K)); + return meta; } @@ -156,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); + _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); } @Override diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java index 5156b78e2e9..4ca7160341c 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -260,6 +260,41 @@ private static Decoder serializeDeserialize(Decoder decoder) throws Exception { } } + /** + * Feature hashing is non-invertible, so the decode contract for a hash column that is NOT dummycoded is that the + * encoded bucket code passes through unchanged. Regression test: a hash-only column must not be dropped from the + * decoded frame (it previously was, because hash columns were excluded from passthrough). + */ + @Test + public void hashWithoutDummycodeDecodesToBucketCode() { + final String spec = "{ids:true, hash:[1], K:8}"; + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + org.junit.Assert.assertEquals(1, decoded.getNumColumns()); + for(int i = 0; i < original.getNumRows(); i++) { + final Object v = decoded.get(i, 0); + org.junit.Assert.assertNotNull("hash column must survive decode at row " + i, v); + org.junit.Assert.assertEquals("hash bucket code must pass through at row " + i, encoded.get(i, 0), + Double.parseDouble(v.toString()), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); return decoder.decode(in, new FrameBlock(decoder.getSchema()), k);