style: rename functions in root.go (#623)
* replace isFlagSet with Changed * style: rename mustGetStringViperFlag and getStringViperFlag, use getParamB to read noauth * style * style: move * fix build error * rename getServerWithViper to getRunParams Former-commit-id: 2d6d9535247e9de01ca0741726665f7dfef1b1e6 [formerly 8fa7b76b92545e0b91bd06fd7b21247e921bbb2a] [formerly 3b7eefee2ac5120796c2a13583ea0b0b2d1ccb89 [formerly 2c526077c05b298a838b4468881e2c460369d5bc]] Former-commit-id: 57912c62749c5890285657c1b4b637925d480ccd [formerly 03b4cd9551d77589f25881cbf693c6bf8390db4e] Former-commit-id: 77b9083e39607aedcba9ef9b4af9a94535661ffd
This commit is contained in:
		
							parent
							
								
									cc428d3cd6
								
							
						
					
					
						commit
						da7d1db06c
					
				
							
								
								
									
										116
									
								
								cmd/root.go
								
								
								
								
							
							
						
						
									
										116
									
								
								cmd/root.go
								
								
								
								
							| 
						 | 
					@ -54,43 +54,6 @@ func addServerFlags(flags *pflag.FlagSet) {
 | 
				
			||||||
	flags.StringP("baseurl", "b", "", "base url")
 | 
						flags.StringP("baseurl", "b", "", "base url")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func isFlagSet(flags *pflag.FlagSet, key string) bool {
 | 
					 | 
				
			||||||
	set:= false
 | 
					 | 
				
			||||||
	flags.Visit(func(flag *pflag.Flag) {
 | 
					 | 
				
			||||||
		if flag.Name == key {
 | 
					 | 
				
			||||||
			set = true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	return set
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// NOTE: we could simply bind the flags to viper and use IsSet.
 | 
					 | 
				
			||||||
// Although there is a bug on Viper that always returns true on IsSet
 | 
					 | 
				
			||||||
// if a flag is binded. Our alternative way is to manually check
 | 
					 | 
				
			||||||
// the flag and then the value from env/config/gotten by viper.
 | 
					 | 
				
			||||||
// https://github.com/spf13/viper/pull/331
 | 
					 | 
				
			||||||
func getStringViperFlag(flags *pflag.FlagSet, key string) (string, bool) {
 | 
					 | 
				
			||||||
	value, _ := flags.GetString(key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If set on Flags, use it.
 | 
					 | 
				
			||||||
	if isFlagSet(flags, key) {
 | 
					 | 
				
			||||||
		return value, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If set through viper (env, config), return it.
 | 
					 | 
				
			||||||
	if v.IsSet(key) {
 | 
					 | 
				
			||||||
		return v.GetString(key), true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Otherwise use default value on flags.
 | 
					 | 
				
			||||||
	return value, false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func mustGetStringViperFlag(flags *pflag.FlagSet, key string) string {
 | 
					 | 
				
			||||||
	val, _ := getStringViperFlag(flags, key)
 | 
					 | 
				
			||||||
	return val
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var rootCmd = &cobra.Command{
 | 
					var rootCmd = &cobra.Command{
 | 
				
			||||||
	Use:     "filebrowser",
 | 
						Use:     "filebrowser",
 | 
				
			||||||
	Version: version.Version,
 | 
						Version: version.Version,
 | 
				
			||||||
| 
						 | 
					@ -137,7 +100,7 @@ user created with the credentials from options "username" and "password".`,
 | 
				
			||||||
			quickSetup(cmd.Flags(), d)
 | 
								quickSetup(cmd.Flags(), d)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		server := getServerWithViper(cmd.Flags(), d.store)
 | 
							server := getRunParams(cmd.Flags(), d.store)
 | 
				
			||||||
		setupLog(server.Log)
 | 
							setupLog(server.Log)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		root, err := filepath.Abs(server.Root)
 | 
							root, err := filepath.Abs(server.Root)
 | 
				
			||||||
| 
						 | 
					@ -168,41 +131,70 @@ user created with the credentials from options "username" and "password".`,
 | 
				
			||||||
	}, pythonConfig{allowNoDB: true}),
 | 
						}, pythonConfig{allowNoDB: true}),
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getServerWithViper(flags *pflag.FlagSet, st *storage.Storage) *settings.Server {
 | 
					func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server {
 | 
				
			||||||
	server, err := st.Settings.GetServer()
 | 
						server, err := st.Settings.GetServer()
 | 
				
			||||||
	checkErr(err)
 | 
						checkErr(err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "root"); set {
 | 
						if val, set := getParamB(flags, "root"); set {
 | 
				
			||||||
		server.Root = val
 | 
							server.Root = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "baseurl"); set {
 | 
						if val, set := getParamB(flags, "baseurl"); set {
 | 
				
			||||||
		server.BaseURL = val
 | 
							server.BaseURL = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "address"); set {
 | 
						if val, set := getParamB(flags, "address"); set {
 | 
				
			||||||
		server.Address = val
 | 
							server.Address = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "port"); set {
 | 
						if val, set := getParamB(flags, "port"); set {
 | 
				
			||||||
		server.Port = val
 | 
							server.Port = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "log"); set {
 | 
						if val, set := getParamB(flags, "log"); set {
 | 
				
			||||||
		server.Log = val
 | 
							server.Log = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "key"); set {
 | 
						if val, set := getParamB(flags, "key"); set {
 | 
				
			||||||
		server.TLSKey = val
 | 
							server.TLSKey = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, set := getStringViperFlag(flags, "cert"); set {
 | 
						if val, set := getParamB(flags, "cert"); set {
 | 
				
			||||||
		server.TLSCert = val
 | 
							server.TLSCert = val
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return server
 | 
						return server
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getParamB returns a parameter as a string and a boolean to tell if it is different from the default
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// NOTE: we could simply bind the flags to viper and use IsSet.
 | 
				
			||||||
 | 
					// Although there is a bug on Viper that always returns true on IsSet
 | 
				
			||||||
 | 
					// if a flag is binded. Our alternative way is to manually check
 | 
				
			||||||
 | 
					// the flag and then the value from env/config/gotten by viper.
 | 
				
			||||||
 | 
					// https://github.com/spf13/viper/pull/331
 | 
				
			||||||
 | 
					func getParamB(flags *pflag.FlagSet, key string) (string, bool) {
 | 
				
			||||||
 | 
						value, _ := flags.GetString(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If set on Flags, use it.
 | 
				
			||||||
 | 
						if flags.Changed(key) {
 | 
				
			||||||
 | 
							return value, true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If set through viper (env, config), return it.
 | 
				
			||||||
 | 
						if v.IsSet(key) {
 | 
				
			||||||
 | 
							return v.GetString(key), true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Otherwise use default value on flags.
 | 
				
			||||||
 | 
						return value, false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getParam(flags *pflag.FlagSet, key string) string {
 | 
				
			||||||
 | 
						val, _ := getParamB(flags, key)
 | 
				
			||||||
 | 
						return val
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func setupLog(logMethod string) {
 | 
					func setupLog(logMethod string) {
 | 
				
			||||||
	switch logMethod {
 | 
						switch logMethod {
 | 
				
			||||||
	case "stdout":
 | 
						case "stdout":
 | 
				
			||||||
| 
						 | 
					@ -223,8 +215,8 @@ func setupLog(logMethod string) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func quickSetup(flags *pflag.FlagSet, d pythonData) {
 | 
					func quickSetup(flags *pflag.FlagSet, d pythonData) {
 | 
				
			||||||
	set := &settings.Settings{
 | 
						set := &settings.Settings{
 | 
				
			||||||
		Key:        generateRandomBytes(64), // 256 bit
 | 
							Key:    generateRandomBytes(64), // 256 bit
 | 
				
			||||||
		Signup:     false,
 | 
							Signup: false,
 | 
				
			||||||
		Defaults: settings.UserDefaults{
 | 
							Defaults: settings.UserDefaults{
 | 
				
			||||||
			Scope:  ".",
 | 
								Scope:  ".",
 | 
				
			||||||
			Locale: "en",
 | 
								Locale: "en",
 | 
				
			||||||
| 
						 | 
					@ -241,14 +233,8 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) {
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	noauth, err := flags.GetBool("noauth")
 | 
						var err error
 | 
				
			||||||
	checkErr(err)
 | 
						if _, noauth := getParamB(flags, "noauth"); noauth {
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !isFlagSet(flags, "noauth") && v.IsSet("noauth") {
 | 
					 | 
				
			||||||
		noauth = v.GetBool("noauth")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if noauth {
 | 
					 | 
				
			||||||
		set.AuthMethod = auth.MethodNoAuth
 | 
							set.AuthMethod = auth.MethodNoAuth
 | 
				
			||||||
		err = d.store.Auth.Save(&auth.NoAuth{})
 | 
							err = d.store.Auth.Save(&auth.NoAuth{})
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
| 
						 | 
					@ -261,20 +247,20 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) {
 | 
				
			||||||
	checkErr(err)
 | 
						checkErr(err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ser := &settings.Server{
 | 
						ser := &settings.Server{
 | 
				
			||||||
		BaseURL: mustGetStringViperFlag(flags, "baseurl"),
 | 
							BaseURL: getParam(flags, "baseurl"),
 | 
				
			||||||
		Port:    mustGetStringViperFlag(flags, "port"),
 | 
							Port:    getParam(flags, "port"),
 | 
				
			||||||
		Log:     mustGetStringViperFlag(flags, "log"),
 | 
							Log:     getParam(flags, "log"),
 | 
				
			||||||
		TLSKey:  mustGetStringViperFlag(flags, "key"),
 | 
							TLSKey:  getParam(flags, "key"),
 | 
				
			||||||
		TLSCert: mustGetStringViperFlag(flags, "cert"),
 | 
							TLSCert: getParam(flags, "cert"),
 | 
				
			||||||
		Address: mustGetStringViperFlag(flags, "address"),
 | 
							Address: getParam(flags, "address"),
 | 
				
			||||||
		Root:    mustGetStringViperFlag(flags, "root"),
 | 
							Root:    getParam(flags, "root"),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = d.store.Settings.SaveServer(ser)
 | 
						err = d.store.Settings.SaveServer(ser)
 | 
				
			||||||
	checkErr(err)
 | 
						checkErr(err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	username := mustGetStringViperFlag(flags, "username")
 | 
						username := getParam(flags, "username")
 | 
				
			||||||
	password := mustGetStringViperFlag(flags, "password")
 | 
						password := getParam(flags, "password")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if password == "" {
 | 
						if password == "" {
 | 
				
			||||||
		password, err = users.HashPwd("admin")
 | 
							password, err = users.HashPwd("admin")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,7 +25,7 @@ this version.`,
 | 
				
			||||||
		flags := cmd.Flags()
 | 
							flags := cmd.Flags()
 | 
				
			||||||
		oldDB := mustGetString(flags, "old.database")
 | 
							oldDB := mustGetString(flags, "old.database")
 | 
				
			||||||
		oldConf := mustGetString(flags, "old.config")
 | 
							oldConf := mustGetString(flags, "old.config")
 | 
				
			||||||
		err := importer.Import(oldDB, oldConf, mustGetStringViperFlag(flags, "database"))
 | 
							err := importer.Import(oldDB, oldConf, getParam(flags, "database"))
 | 
				
			||||||
		checkErr(err)
 | 
							checkErr(err)
 | 
				
			||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -66,7 +66,7 @@ func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
 | 
				
			||||||
	return func(cmd *cobra.Command, args []string) {
 | 
						return func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
		data := pythonData{hadDB: true}
 | 
							data := pythonData{hadDB: true}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		path := mustGetStringViperFlag(cmd.Flags(), "database")
 | 
							path := getParam(cmd.Flags(), "database")
 | 
				
			||||||
		_, err := os.Stat(path)
 | 
							_, err := os.Stat(path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if os.IsNotExist(err) {
 | 
							if os.IsNotExist(err) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue