# caffe2/contrib/playground/AnyExpOnTerm.py

import argparse
import json
import os

import caffe2.contrib.playground.AnyExp as AnyExp
import caffe2.contrib.playground.checkpoint as checkpoint

import logging
logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)


def runShardedTrainLoop(opts, myTrainFun):
    start_epoch = 0
    pretrained_model = opts['model_param']['pretrained_model']
    if pretrained_model != '' and os.path.exists(pretrained_model):
        # Only start_epoch is needed here; prev_checkpointed_lr and
        # best_metric are returned but left unused.
        start_epoch, prev_checkpointed_lr, best_metric = \
            checkpoint.initialize_params_from_file(
                model=None,
                weights_file=pretrained_model,
                num_xpus=1,
                opts=opts,
                broadcast_computed_param=True,
                reset_epoch=opts['model_param']['reset_epoch'],
            )
    log.info('start epoch: {}'.format(start_epoch))
    pretrained_model = None if pretrained_model == '' else pretrained_model
    ret = None

    for epoch in range(start_epoch,
                       opts['epoch_iter']['num_epochs'],
                       opts['epoch_iter']['num_epochs_per_flow_schedule']):
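        # Note: range() steps by num_epochs_per_flow_schedule, so e.g. with
        # num_epochs=16 and num_epochs_per_flow_schedule=4 the schedules
        # start at epochs 0, 4, 8 and 12, each resuming from the previous
        # schedule's checkpoint.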
        # Checkpointing must be supported here; otherwise every schedule
        # would restart training from the initial state.
        checkpoint_model = None if epoch == start_epoch else ret['model']
        pretrained_model = None if epoch > start_epoch else pretrained_model
        shard_results = []
        # with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
        for shard_id in range(opts['distributed']['num_shards']):
            opts['temp_var']['shard_id'] = shard_id
            opts['temp_var']['pretrained_model'] = pretrained_model
            opts['temp_var']['checkpoint_model'] = checkpoint_model
            opts['temp_var']['epoch'] = epoch
            opts['temp_var']['start_epoch'] = start_epoch
            shard_ret = myTrainFun(opts)
            shard_results.append(shard_ret)

        ret = None
        # Take the first non-None shard result (in practice shard 0's).
        for shard_ret in shard_results:
            if shard_ret is not None:
                ret = shard_ret
                opts['temp_var']['metrics_output'] = ret['metrics']
                break
        log.info('ret is: {}'.format(str(ret)))

    return ret


def trainFun():
    # Factory returning the default training function: builds the trainer
    # class from opts, applies the additional method overrides, and runs
    # buildModelAndTrain.
    def simpleTrainFun(opts):
        trainerClass = AnyExp.createTrainerClass(opts)
        trainerClass = AnyExp.overrideAdditionalMethods(trainerClass, opts)
        trainer = trainerClass(opts)
        return trainer.buildModelAndTrain(opts)
    return simpleTrainFun
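

# A minimal sketch of a customized training function that could replace
# trainFun(), illustrating the contract runShardedTrainLoop expects: a
# factory returning a callable of opts that yields a dict with 'model' and
# 'metrics' keys. loggingTrainFun and its extra logging are hypothetical
# additions, not part of the original flow.
def loggingTrainFun():
    def wrappedTrainFun(opts):
        log.info('shard {} entering epoch {}'.format(
            opts['temp_var']['shard_id'], opts['temp_var']['epoch']))
        return trainFun()(opts)
    return wrappedTrainFun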


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Any Experiment training.')
    parser.add_argument("--parameters-json", type=json.loads,
                        help='model options in JSON format', dest="params")

    args = parser.parse_args()
    opts = args.params['opts']
    opts = AnyExp.initOpts(opts)
    log.info('opts is: {}'.format(str(opts)))

    AnyExp.initDefaultModuleMap()

    opts['input']['datasets'] = AnyExp.aquireDatasets(opts)

    # Defined this way so that trainFun() can be replaced with some other
    # customized training function (e.g. the loggingTrainFun sketch above).
    ret = runShardedTrainLoop(opts, trainFun())

    log.info('ret is: {}'.format(str(ret)))
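
# Example invocation; the opts payload below is a minimal sketch inferred
# from the fields this script reads, and a real experiment needs the full
# schema that AnyExp.initOpts expects:
#
#   python AnyExpOnTerm.py --parameters-json \
#       '{"opts": {"model_param": {"pretrained_model": "", "reset_epoch": false},
#                  "epoch_iter": {"num_epochs": 16,
#                                 "num_epochs_per_flow_schedule": 4},
#                  "distributed": {"num_shards": 1},
#                  "temp_var": {}, "input": {}}}'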