
Spark Core 2.0: TaskSchedulerImpl Source Code Analysis

2016-11-08 11:30
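TaskSchedulerImpl extends the TaskScheduler trait. TaskScheduler is a low-level task scheduling interface whose only current implementation is TaskSchedulerImpl. The interface allows different task schedulers to be plugged in; each scheduler schedules tasks for a single SparkContext. A scheduler receives the task sets that the DAGScheduler submits for each stage and is responsible for sending those tasks to the cluster, running them, retrying them on failure, and mitigating stragglers. It reports events back to the DAGScheduler.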

/**
* Low-level task scheduler interface, currently implemented exclusively by
* [[org.apache.spark.scheduler.TaskSchedulerImpl]].
* This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
* them, retrying if there are failures, and mitigating stragglers. They return events to the
* DAGScheduler.
*/
private[spark] trait TaskScheduler {

private val appId = "spark-application-" + System.currentTimeMillis

def rootPool: Pool

def schedulingMode: SchedulingMode

def start(): Unit

// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
// wait for slave registrations, etc.
def postStartHook() { }

// Disconnect from the cluster.
def stop(): Unit

// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit

// Cancel a stage.
def cancelTasks(stageId: Int, interruptThread: Boolean): Unit

// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit

// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
def executorHeartbeatReceived(
execId: String,
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
blockManagerId: BlockManagerId): Boolean

/**
* Get an application ID associated with the job.
*
* @return An application ID
*/
def applicationId(): String = appId

/**
* Process a lost executor
*/
def executorLost(executorId: String, reason: ExecutorLossReason): Unit

/**
* Get an application's attempt ID associated with the job.
*
* @return An application's Attempt ID
*/
def applicationAttemptId(): Option[String]

}

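TaskSchedulerImpl schedules tasks for multiple types of clusters by acting through a SchedulerBackend. It can also run locally by using a LocalSchedulerBackend and setting isLocal to true. It handles the logic common to all backends, such as determining the scheduling order across jobs and waking up to launch speculative tasks.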
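
The scheduling order across jobs depends on spark.scheduler.mode, which the class listed further down parses with SchedulingMode.withName and rejects with a SparkException when the name is unknown. The following is a minimal, self-contained sketch of that parsing step only. The Mode enumeration and its three values are a stand-in assumed for illustration, and parseMode is a hypothetical helper that returns an Either instead of throwing.

```scala
object SchedulingModeSketch {
  // Stand-in for Spark's SchedulingMode enumeration; the values are an assumption.
  object Mode extends Enumeration {
    val FAIR, FIFO, NONE = Value
  }

  // Hypothetical helper: parse a configured mode name case-insensitively.
  // Enumeration.withName throws NoSuchElementException for unknown names,
  // which is turned into an error message that quotes the offending value.
  def parseMode(raw: String): Either[String, Mode.Value] =
    try Right(Mode.withName(raw.toUpperCase))
    catch {
      case _: NoSuchElementException =>
        Left(s"Unrecognized spark.scheduler.mode: $raw")
    }

  def main(args: Array[String]): Unit = {
    println(parseMode("fair"))        // accepted regardless of case
    println(parseMode("FIFO"))        // the default mode in the listing
    println(parseMode("round-robin")) // rejected with a descriptive message
  }
}
```

Reporting the raw configured string rather than the upper-cased one keeps the error message consistent with what the user actually wrote, which is also what the listing does.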

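Clients should first call initialize() and start(), then submit task sets. The class comment names a runTasks method for this, but the method defined in this version is submitTasks.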

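Threading: SchedulerBackends and task-submitting clients may call this class from multiple threads, so its public API methods take a lock to protect its state. In addition, some SchedulerBackends synchronize on themselves when they send events to this class and then acquire a lock on it, so the class must never try to lock the backend while it is holding its own lock.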
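
The locking rule can be illustrated with a small sketch: change the scheduler's own state while holding its lock, record what needs to happen next, and call the backend only after the lock has been released. ToyBackend and ToyScheduler are hypothetical classes invented for this example, not Spark types. The shape mirrors what executorLost and statusUpdate do in the listing below.

```scala
import scala.collection.mutable

object LockOrderingSketch {
  // Hypothetical backend that synchronizes on itself, as the note describes.
  class ToyBackend {
    private var revived = 0
    def reviveOffers(): Unit = synchronized { revived += 1 }
    def reviveCount: Int = synchronized { revived }
  }

  // Hypothetical scheduler that guards its own state with its own lock.
  class ToyScheduler(backend: ToyBackend) {
    private val tasksPerExecutor = mutable.HashMap[String, Int]()

    def executorAdded(execId: String): Unit = synchronized {
      tasksPerExecutor.getOrElseUpdate(execId, 0)
    }

    def isAlive(execId: String): Boolean = synchronized {
      tasksPerExecutor.contains(execId)
    }

    def executorLost(execId: String): Unit = {
      var failedExecutor: Option[String] = None
      // Step 1: update our own state under our own lock and remember the outcome.
      synchronized {
        if (tasksPerExecutor.contains(execId)) {
          tasksPerExecutor -= execId
          failedExecutor = Some(execId)
        }
      }
      // Step 2: our lock is released here, so taking the backend's lock cannot
      // deadlock against a backend thread that holds its lock and waits for ours.
      if (failedExecutor.isDefined) backend.reviveOffers()
    }
  }

  def main(args: Array[String]): Unit = {
    val backend = new ToyBackend
    val scheduler = new ToyScheduler(backend)
    scheduler.executorAdded("exec-1")
    scheduler.executorLost("exec-1")
    scheduler.executorLost("exec-1") // already removed, so the backend is not called again
    println(s"alive=${scheduler.isAlive("exec-1")} revives=${backend.reviveCount}")
  }
}
```

If reviveOffers were called inside the synchronized block instead, a backend thread holding its own lock and waiting for the scheduler's lock could block forever against a scheduler thread doing the opposite.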

/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
* It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting
* isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking
* up to launch speculative tasks, etc.
*
* Clients should first call initialize() and start(), then submit task sets through the
* runTasks method.
*
* THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
* [[SchedulerBackend]]s synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
private[spark] class TaskSchedulerImpl(
val sc: SparkContext,
val maxTaskFailures: Int,
isLocal: Boolean = false)
extends TaskScheduler with Logging
{
def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4))

val conf = sc.conf

// How often to check for speculative tasks
val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms")

private val speculationScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation")

// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s")

// CPUs to request per task
val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)

// Incrementing task IDs
val nextTaskId = new AtomicLong(0)

// Number of tasks running on each executor
private val executorIdToTaskCount = new HashMap[String, Int]

def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap

// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
protected val executorsByHost = new HashMap[String, HashSet[String]]

protected val hostsByRack = new HashMap[String, HashSet[String]]

protected val executorIdToHost = new HashMap[String, String]

// Listener object to pass upcalls into
var dagScheduler: DAGScheduler = null

var backend: SchedulerBackend = null

val mapOutputTracker = SparkEnv.get.mapOutputTracker

var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
// default scheduler is FIFO
private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO")
val schedulingMode: SchedulingMode = try {
SchedulingMode.withName(schedulingModeConf.toUpperCase)
} catch {
case e: java.util.NoSuchElementException =>
throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf")
}

// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)

override def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}

def initialize(backend: SchedulerBackend) {
this.backend = backend
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool, conf)
case _ =>
throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode")
}
}
schedulableBuilder.buildPools()
}

def newTaskId(): Long = nextTaskId.getAndIncrement()

override def start() {
backend.start()

if (!isLocal && conf.getBoolean("spark.speculation", false)) {
logInfo("Starting speculative execution thread")
speculationScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
checkSpeculatableTasks()
}
}, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
}
}

override def postStartHook() {
waitBackendReady()
}

override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
"and have sufficient resources")
} else {
this.cancel()
}
}
}, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
}
hasReceivedTask = true
}
backend.reviveOffers()
}

// Label as private[scheduler] to allow tests to swap in different task set managers if necessary
private[scheduler] def createTaskSetManager(
taskSet: TaskSet,
maxTaskFailures: Int): TaskSetManager = {
new TaskSetManager(this, taskSet, maxTaskFailures)
}

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
//    In this case, send a kill signal to the executors to kill the task and then abort
//    the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
//    simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
}
}

/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.stageAttemptId
if (taskSetsForStage.isEmpty) {
taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
}
}
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
}

private def resourceOfferSingleTaskSet(
taskSet: TaskSetManager,
maxLocality: TaskLocality,
shuffledOffers: Seq[WorkerOffer],
availableCpus: Array[Int],
tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
var launchedTask = false
for (i <- 0 until shuffledOffers.size) {
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
if (availableCpus(i) >= CPUS_PER_TASK) {
try {
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToTaskCount(execId) += 1
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
} catch {
case e: TaskNotSerializableException =>
logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
// Do not offer resources for this task, but don't throw an error to allow other
// task sets to be submitted.
return launchedTask
}
}
}
return launchedTask
}

/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
// Mark each slave as alive and remember its hostname
// Also track if new executor is added
var newExecAvail = false
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
executorIdToTaskCount.getOrElseUpdate(o.executorId, 0)
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorAdded(o.executorId, o.host)
newExecAvail = true
}
for (rack <- getRackForHost(o.host)) {
hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
}
}

// Randomly shuffle offers to avoid always placing tasks on the same set of workers.
val shuffledOffers = Random.shuffle(offers)
// Build a list of tasks to assign to each worker.
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
taskSet.parent.name, taskSet.name, taskSet.runningTasks))
if (newExecAvail) {
taskSet.executorAdded()
}
}

// Take each TaskSet in our scheduling order, and then offer it each node in increasing order
// of locality levels so that it gets a chance to launch local tasks on all of them.
// NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
do {
launchedTask = resourceOfferSingleTaskSet(
taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
} while (launchedTask)
}

if (tasks.size > 0) {
hasLaunchedTask = true
}
return tasks
}

def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
var reason: Option[ExecutorLossReason] = None
synchronized {
try {
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)

if (executorIdToTaskCount.contains(execId)) {
reason = Some(
SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
removeExecutor(execId, reason.get)
failedExecutor = Some(execId)
}
}
taskIdToTaskSetManager.get(tid) match {
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid).foreach { execId =>
if (executorIdToTaskCount.contains(execId)) {
executorIdToTaskCount(execId) -= 1
}
}
}
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor.isDefined) {
assert(reason.isDefined)
dagScheduler.executorLost(failedExecutor.get, reason.get)
backend.reviveOffers()
}
}

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
override def executorHeartbeatReceived(
execId: String,
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
blockManagerId: BlockManagerId): Boolean = {
// (taskId, stageId, stageAttemptId, accumUpdates)
val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
accumUpdates.flatMap { case (id, updates) =>
val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None))
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos)
}
}
}
dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId)
}

def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {
taskSetManager.handleTaskGettingResult(tid)
}

def handleSuccessfulTask(
taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]): Unit = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}

def handleFailedTask(
taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: TaskEndReason): Unit = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
// Need to revive offers again now that the task set manager state has been updated to
// reflect failed tasks that need to be re-run.
backend.reviveOffers()
}
}

def error(message: String) {
synchronized {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
for {
attempts <- taskSetsByStageIdAndAttempt.values
manager <- attempts.values
} {
try {
manager.abort(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
// No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
throw new SparkException(s"Exiting due to error from cluster scheduler: $message")
}
}
}

override def stop() {
speculationScheduler.shutdown()
if (backend != null) {
backend.stop()
}
if (taskResultGetter != null) {
taskResultGetter.stop()
}
starvationTimer.cancel()
}

override def defaultParallelism(): Int = backend.defaultParallelism()

// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
shouldRevive = rootPool.checkSpeculatableTasks()
}
if (shouldRevive) {
backend.reviveOffers()
}
}

override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {
var failedExecutor: Option[String] = None

synchronized {
if (executorIdToTaskCount.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
logExecutorLoss(executorId, hostPort, reason)
removeExecutor(executorId, reason)
failedExecutor = Some(executorId)
} else {
executorIdToHost.get(executorId) match {
case Some(hostPort) =>
// If the host mapping still exists, it means we don't know the loss reason for the
// executor. So call removeExecutor() to update tasks running on that executor when
// the real loss reason is finally known.
logExecutorLoss(executorId, hostPort, reason)
removeExecutor(executorId, reason)

case None =>
// We may get multiple executorLost() calls with different loss reasons. For example,
// one may be triggered by a dropped connection from the slave while another may be a
// report of executor termination from Mesos. We produce log messages for both so we
// eventually report the termination reason.
logError(s"Lost an executor $executorId (already removed): $reason")
}
}
}
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor.isDefined) {
dagScheduler.executorLost(failedExecutor.get, reason)
backend.reviveOffers()
}
}

private def logExecutorLoss(
executorId: String,
hostPort: String,
reason: ExecutorLossReason): Unit = reason match {
case LossReasonPending =>
logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.")
case ExecutorKilled =>
logInfo(s"Executor $executorId on $hostPort killed by driver.")
case _ =>
logError(s"Lost executor $executorId on $hostPort: $reason")
}

/**
* Remove an executor from all our data structures and mark it as lost. If the executor's loss
* reason is not yet known, do not yet remove its association with its host nor update the status
* of any running tasks, since the loss reason defines whether we'll fail those tasks.
*/
private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
executorIdToTaskCount -= executorId

val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) {
hosts -= host
if (hosts.isEmpty) {
hostsByRack -= rack
}
}
}

if (reason != LossReasonPending) {
executorIdToHost -= executorId
rootPool.executorLost(executorId, host, reason)
}
}

def executorAdded(execId: String, host: String) {
dagScheduler.executorAdded(execId, host)
}

def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
executorsByHost.get(host).map(_.toSet)
}

def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
executorsByHost.contains(host)
}

def hasHostAliveOnRack(rack: String): Boolean = synchronized {
hostsByRack.contains(rack)
}

def isExecutorAlive(execId: String): Boolean = synchronized {
executorIdToTaskCount.contains(execId)
}

def isExecutorBusy(execId: String): Boolean = synchronized {
executorIdToTaskCount.getOrElse(execId, -1) > 0
}

// By default, rack is unknown
def getRackForHost(value: String): Option[String] = None

private def waitBackendReady(): Unit = {
if (backend.isReady) {
return
}
while (!backend.isReady) {
// Might take a while for backend to be ready if it is waiting on resources.
if (sc.stopped.get) {
// For example: the master removes the application for some reason
throw new IllegalStateException("Spark context stopped while waiting for backend")
}
synchronized {
this.wait(100)
}
}
}

override def applicationId(): String = backend.applicationId()

override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
stageAttemptId: Int): Option[TaskSetManager] = {
for {
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
manager <- attempts.get(stageAttemptId)
} yield {
manager
}
}

}

private[spark] object TaskSchedulerImpl {
/**
* Used to balance containers across hosts.
*
* Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
* resource offers representing the order in which the offers should be used.  The resource
* offers are ordered such that we'll allocate one container on each host before allocating a
* second container on any host, and so on, in order to reduce the damage if a host fails.
*
* For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
* [o1, o5, o4, o2, o6, o3]
*/
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
val _keyList = new ArrayBuffer[K](map.size)
_keyList ++= map.keys

// order keyList based on population of value in map
val keyList = _keyList.sortWith(
(left, right) => map(left).size > map(right).size
)

val retval = new ArrayBuffer[T](keyList.size * 2)
var index = 0
var found = true

while (found) {
found = false
for (key <- keyList) {
val containerList: ArrayBuffer[T] = map.getOrElse(key, null)
assert(containerList != null)
// Get the index'th entry for this host - if present
if (index < containerList.size) {
retval += containerList.apply(index)
found = true
}
}
index += 1
}

retval.toList
}

}
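
To make the ordering produced by prioritizeContainers concrete, here is a small standalone sketch that can be run without Spark. The prioritize function is a hypothetical re-statement of the same round-robin idea, written with a bounded loop instead of the while loop above, and the host and offer names are made up for the example.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

object PrioritizeContainersSketch {
  // Visit hosts from most to fewest offers and, on each pass, take the
  // index'th offer from every host that still has one.
  def prioritize[K, T](map: HashMap[K, ArrayBuffer[T]]): List[T] = {
    val keys = map.keys.toSeq.sortWith((l, r) => map(l).size > map(r).size)
    // Number of passes needed: the length of the longest offer list.
    val longest = if (map.isEmpty) 0 else map.values.map(_.size).max
    val result = new ArrayBuffer[T]
    for (index <- 0 until longest; key <- keys) {
      val offers = map(key)
      if (index < offers.size) result += offers(index)
    }
    result.toList
  }

  def main(args: Array[String]): Unit = {
    val offers = HashMap(
      "h1" -> ArrayBuffer("o1", "o2", "o3"),
      "h2" -> ArrayBuffer("o4"),
      "h3" -> ArrayBuffer("o5", "o6"))
    println(prioritize(offers)) // List(o1, o5, o4, o2, o6, o3)
  }
}
```

Each host contributes one offer per pass, so losing any single host costs at most one container from each pass rather than a contiguous run of allocations.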