Spark: sorting by multiple fields of arbitrary types

qq39713718 · 3 months ago

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row


class SortRow(mode: Array[Boolean]
              , index: Array[Int]
              , dataType: Array[DataType.DataEnum]) extends Serializable {


  /** Sort the RDD by the configured key columns, using the implicit Ordering[Row] defined below. */
  def sort(data: RDD[Row]): RDD[Row] = data.sortBy(x => getIndex(index, x))


  /** Project the configured key columns, in order, into a new Row that serves as the sort key. */
  private def getIndex(index: Array[Int], row: Row): Row = {
    val buffer: Array[Any] = new Array[Any](index.length)
    for (c <- 0 until index.length) {
      buffer(c) = row.get(index(c))
    }
    Row.fromSeq(buffer)
  }


  /**
    * Implicit Ordering[Row] used by sortBy; compares the projected key rows field by field.
    */
  implicit val rowOrdering: Ordering[Row] = new Ordering[Row] {
    override def compare(left: Row, right: Row): Int = compareRow(left, right, 0)
  }


  /**
    * Compare two key rows field by field: ascending or descending according to mode(index),
    * falling through to the next key column on ties.
    *
    * @return negative, zero or positive, as for Ordering.compare
    */
  private[this] def compareRow(left: Row, right: Row, index: Int): Int = {
    var result =
      if (mode(index)) compareData(left, right, index)
      else -compareData(left, right, index)
    // Tie on this column: compare the next key column, if any.
    if (result == 0 && index < mode.length - 1) {
      result = compareRow(left, right, index + 1)
    }
    result
  }

  /**
    * Compare a single key field according to its declared type (custom ordering rules).
    *
    * @return negative, zero or positive, as for Ordering.compare
    */
  private[this] def compareData(left: Row, right: Row, index: Int): Int = {
    val dtype = dataType(index)
    // Null handling: nulls sort before non-null values.
    if (left.isNullAt(index)) {
      if (right.isNullAt(index)) 0 else -1
    } else if (right.isNullAt(index)) {
      1
    } else if (dtype.equals(DataType.Double)) {
      left.getDouble(index).compare(right.getDouble(index))
    } else if (dtype.equals(DataType.Int)) {
      left.getInt(index).compare(right.getInt(index))
    } else {
      left.getString(index).compare(right.getString(index))
    }
  }
}

/**
  * Enumeration of supported field types. Add more as needed and implement the matching
  * comparison branch in compareData (see the sketch right after this object).
  */
object DataType extends Enumeration with Serializable {
  type DataEnum = Value
  val Int = Value("Int")
  val String = Value("String")
  val Double = Value("Double")
}
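// To support another field type, extend the enum above and add the matching branch in
// compareData. A minimal sketch, assuming a Long column were needed (the Long tag and the
// extra branch are illustrative additions, not part of the original code):
//
//   val Long = Value("Long")                                // extra tag in DataType
//
//   else if (dtype.equals(DataType.Long))                   // extra branch in compareData,
//     left.getLong(index).compare(right.getLong(index))     //   before the String fallback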

object SortRow extends App {
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
  val sparkConf = new SparkConf().setAppName("sort").setMaster("local[16]")
  val sc = new SparkContext(sparkConf)
  // Sample rows: four columns of (Int, String, Double, Int); one Double value is null.
  val data = Array(
    Row.fromSeq(Array(1, "e", 0.92, 3)),
    Row.fromSeq(Array(5, "n", null, 5)),
    Row.fromSeq(Array(3, "m", 0.32, 8)))
  val rdd = sc.parallelize(data)
  // Sort by column 2 (Double, ascending), then column 0 (Int, descending), then column 1 (String, ascending).
  val sorter = new SortRow(Array(true, false, true), Array(2, 0, 1), Array(DataType.Double, DataType.Int, DataType.String))
  val result = sorter.sort(rdd)
  result.collect().foreach(println)
}
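With the sample data and key configuration above, the demo should print the rows ordered on the Double column ascending, with the null row first (per the null-handling rule in compareData), roughly:

[5,n,null,5]
[3,m,0.32,8]
[1,e,0.92,3]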