diff --git a/mlsh_code/master.py b/mlsh_code/master.py index 9249f76..23ef4d3 100644 --- a/mlsh_code/master.py +++ b/mlsh_code/master.py @@ -41,7 +41,8 @@ def start(callback, args, workerseed, rank, comm): rollout = rollouts.traj_segment_generator(policy, sub_policies, env, macro_duration, num_rollouts, stochastic=True, args=args) - for x in range(10000): + niter = 10000 + for x in range(niter): callback(x) if x == 0: learner.syncSubpolicies() @@ -56,7 +57,12 @@ def start(callback, args, workerseed, rank, comm): env.env.realgoal = shared_goal print("It is iteration %d so i'm changing the goal to %s" % (x, env.env.realgoal)) - mini_ep = 0 if x > 0 else -1 * (rank % 10)*int(warmup_time+train_time / 10) + if x == 0: + mini_ep = -1 * (rank % 10)*int((warmup_time+train_time) / 10) + elif x == niter - 1: + mini_ep = (rank % 10)*int((warmup_time+train_time) / 10) + else: + mini_ep = 0 # mini_ep = 0 totalmeans = []