Skip to content

Commit

Permalink
Await when adding account (#17664)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra authored Apr 27, 2023
1 parent b3fe176 commit a9ee4f6
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 67 deletions.
3 changes: 3 additions & 0 deletions localization/xliff/enu/constants/localizedConstants.enu.xlf
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@
<trans-unit id="msgPromptFirewallRuleCreated">
<source xml:lang="en">Firewall rule successfully created.</source>
</trans-unit>
<trans-unit id="msgAuthTypeNotFound">
<source xml:lang="en">Failed to get authentication method, please remove and re-add the account.</source>
</trans-unit>
<trans-unit id="msgAccountNotFound">
<source xml:lang="en">Account not found</source>
</trans-unit>
Expand Down
5 changes: 4 additions & 1 deletion src/azure/adal/adalAzureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { AzureUserInteraction } from './azureUserInteraction';
import { StorageService } from './storageService';

export class AdalAzureController extends AzureController {

private _authMappings = new Map<AzureAuthType, AzureAuth>();
private cacheProvider: SimpleTokenCache;
private storageService: StorageService;
Expand Down Expand Up @@ -54,6 +53,10 @@ export class AdalAzureController extends AzureController {
return response ? response as IAccount : undefined;
}

public isAccountInCache(account: IAccount): Promise<boolean> {
throw new Error('Method not implemented.');
}

public async getAccountSecurityToken(account: IAccount, tenantId: string, settings: IAADResource): Promise<IToken | undefined> {
let token: IToken | undefined;
let azureAuth = await this.getAzureAuthInstance(getAzureActiveDirectoryConfig());
Expand Down
2 changes: 2 additions & 0 deletions src/azure/azureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ export abstract class AzureController {
public abstract refreshAccessToken(account: IAccount, accountStore: AccountStore,
tenantId: string | undefined, settings: IAADResource): Promise<IToken | undefined>;

public abstract isAccountInCache(account: IAccount): Promise<boolean>;

public abstract removeAccount(account: IAccount): Promise<void>;

public abstract handleAuthMapping(): void;
Expand Down
47 changes: 26 additions & 21 deletions src/azure/msal/msalAzureAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,30 +166,35 @@ export abstract class MsalAzureAuth {
}

public async refreshAccessToken(account: IAccount, tenantId: string, settings: IAADResource): Promise<IAccount | undefined> {
try {
const tokenResult = await this.getToken(account, tenantId, settings);
if (!tokenResult) {
account.isStale = true;
return account;
}
if (account) {
try {
const tokenResult = await this.getToken(account, tenantId, settings);
if (!tokenResult) {
account.isStale = true;
return account;
}

const tokenClaims = this.getTokenClaims(tokenResult.accessToken);
if (!tokenClaims) {
account.isStale = true;
return account;
}
const tokenClaims = this.getTokenClaims(tokenResult.accessToken);
if (!tokenClaims) {
account.isStale = true;
return account;
}

const token: IToken = {
key: tokenResult.account!.homeAccountId,
token: tokenResult.accessToken,
tokenType: tokenResult.tokenType,
expiresOn: tokenResult.account!.idTokenClaims!.exp
};
const token: IToken = {
key: tokenResult.account!.homeAccountId,
token: tokenResult.accessToken,
tokenType: tokenResult.tokenType,
expiresOn: tokenResult.account!.idTokenClaims!.exp
};

return await this.hydrateAccount(token, tokenClaims);
} catch (ex) {
account.isStale = true;
throw ex;
return await this.hydrateAccount(token, tokenClaims);
} catch (ex) {
account.isStale = true;
throw ex;
}
} else {
this.logger.error(`refreshAccessToken: Account not received for refreshing access token.`);
throw Error(LocalizedConstants.msgAccountNotFound);
}
}

Expand Down
19 changes: 16 additions & 3 deletions src/azure/msal/msalAzureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ export class MsalAzureController extends AzureController {
return response ? response as IAccount : undefined;
}

public async isAccountInCache(account: IAccount): Promise<boolean> {
let authType = getAzureActiveDirectoryConfig();
let azureAuth = await this.getAzureAuthInstance(authType!);
await this.clearOldCacheIfExists();
let accountInfo = await azureAuth.getAccountFromMsalCache(account.key.id);
return accountInfo !== undefined;
}

private async getAzureAuthInstance(authType: AzureAuthType): Promise<MsalAzureAuth | undefined> {
if (!this._authMappings.has(authType)) {
await this.handleAuthMapping();
Expand All @@ -113,9 +121,14 @@ export class MsalAzureController extends AzureController {
return token;
}
} else {
account.isStale = true;
this.logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`);
throw Error('Failed to get authentication method, please remove and re-add the account');
if (account) {
account.isStale = true;
this.logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`);
throw Error(LocalizedConstants.msgAuthTypeNotFound);
} else {
this.logger.error(`_getAccountSecurityToken: Authentication method not found as account not available.`);
throw Error(LocalizedConstants.msgAccountNotFound);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/controllers/mainController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export default class MainController implements vscode.Disposable {
this.registerCommand(Constants.cmdAadRemoveAccount);
this._event.on(Constants.cmdAadRemoveAccount, () => this.removeAadAccount(this._prompter));
this.registerCommand(Constants.cmdAadAddAccount);
this._event.on(Constants.cmdAadAddAccount, () => this.addAddAccount());
this._event.on(Constants.cmdAadAddAccount, () => this.addAadAccount());

this.initializeObjectExplorer();

Expand Down Expand Up @@ -1241,7 +1241,7 @@ export default class MainController implements vscode.Disposable {
this.connectionManager.removeAccount(prompter);
}

public addAddAccount(): void {
public addAadAccount(): void {
this.connectionManager.addAccount();
}
}
38 changes: 21 additions & 17 deletions src/models/connectionProfile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,39 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
type: QuestionTypes.confirm,
name: LocalizedConstants.msgSavePassword,
message: LocalizedConstants.msgSavePassword,
shouldPrompt: (answers) => !profile.connectionString && ConnectionCredentials.isPasswordBasedCredential(profile),
shouldPrompt: () => !profile.connectionString && ConnectionCredentials.isPasswordBasedCredential(profile),
onAnswered: (value) => profile.savePassword = value
},
{
type: QuestionTypes.expand,
name: LocalizedConstants.aad,
message: LocalizedConstants.azureChooseAccount,
choices: azureAccountChoices,
shouldPrompt: (answers) => profile.isAzureActiveDirectory(),
onAnswered: (value) => {
shouldPrompt: () => profile.isAzureActiveDirectory(),
onAnswered: async (value) => {
accountAnswer = value;
if (value !== 'addAccount') {
let account: IAccount = value;
let account = value;
profile.accountId = account?.key.id;
tenantChoices.push(...account?.properties?.tenants.map(t => ({ name: t.displayName, value: t })));
tenantChoices.push(...account?.properties?.tenants!.map(t => ({ name: t.displayName, value: t })));
if (tenantChoices.length === 1) {
profile.tenantId = tenantChoices[0].value.id;
}
try {
profile = await azureController.refreshTokenWrapper(profile, accountStore, accountAnswer, providerSettings.resources.databaseResource);
} catch (error) {
console.log(`Refreshing tokens failed: ${error}`);
}
} else {
try {
profile = await azureController.populateAccountProperties(profile, accountStore, providerSettings.resources.databaseResource);
if (profile) {
vscode.window.showInformationMessage(utils.formatString(LocalizedConstants.accountAddedSuccessfully, profile.email));
}
} catch (e) {
console.error(`Could not add account: ${e}`);
vscode.window.showErrorMessage(e);
}
}
}
},
Expand All @@ -111,7 +126,7 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
default: defaultProfileValues ? defaultProfileValues.tenantId : undefined,
// Need not prompt for tenant question when 'Sql Authentication Provider' is enabled,
// since tenant information is received from Server with authority URI in the Login flow.
shouldPrompt: (answers) => profile.isAzureActiveDirectory() && tenantChoices.length > 1 && !getEnableSqlAuthenticationProviderConfig(),
shouldPrompt: () => profile.isAzureActiveDirectory() && tenantChoices.length > 1 && !getEnableSqlAuthenticationProviderConfig(),
onAnswered: (value: ITenant) => {
profile.tenantId = value.id;
}
Expand All @@ -130,17 +145,6 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
});

return prompter.prompt(questions, true).then(async answers => {
if (answers?.authenticationType === 'AzureMFA') {
if (answers.AAD === 'addAccount') {
profile = await azureController.populateAccountProperties(profile, accountStore, providerSettings.resources.databaseResource);
} else {
try {
profile = await azureController.refreshTokenWrapper(profile, accountStore, accountAnswer, providerSettings.resources.databaseResource);
} catch (error) {
console.log(`Refreshing tokens failed: ${error}`);
}
}
}
if (answers && profile.isValidProfile()) {
return profile;
}
Expand Down
7 changes: 6 additions & 1 deletion src/objectExplorer/objectExplorerService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,16 @@ export class ObjectExplorerService {
let azureController = this._connectionManager.azureController;
let account = this._connectionManager.accountStore.getAccount(connectionCredentials.accountId);
let profile = new ConnectionProfile(connectionCredentials);
let needsRefresh: boolean = false;
if (azureController.isSqlAuthProviderEnabled()) {
this._client.logger.verbose('SQL Authentication provider is enabled for Azure MFA connections, skipping token acquiry in extension.');
connectionCredentials.user = account.displayInfo.displayName;
connectionCredentials.email = account.displayInfo.email;
} else if (!connectionCredentials.azureAccountToken) {
if (!azureController.isAccountInCache(account)) {
needsRefresh = true;
}
}
if (!connectionCredentials.azureAccountToken && (!azureController.isSqlAuthProviderEnabled() || needsRefresh)) {
let azureAccountToken = await azureController.refreshAccessToken(
account, this._connectionManager.accountStore, connectionCredentials.tenantId, providerSettings.resources.databaseResource);
if (!azureAccountToken) {
Expand Down
4 changes: 2 additions & 2 deletions src/prompts/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ export default class CodeAdapter implements IPrompter {
// }

if (!question.shouldPrompt || question.shouldPrompt(answers) === true) {
return prompt.render().then(result => {
return prompt.render().then(async result => {
answers[question.name] = result;

if (question.onAnswered) {
question.onAnswered(result);
await question.onAnswered(result);
}
return answers;
});
Expand Down
2 changes: 1 addition & 1 deletion src/prompts/question.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export interface IQuestion {
// Optional pre-prompt function. Takes in set of answers so far, and returns true if prompt should occur
shouldPrompt?: (answers: { [id: string]: any }) => boolean;
// Optional action to take on the question being answered
onAnswered?: (value: any) => void;
onAnswered?: (value: any) => void | Promise<void>;
// Optional set of options to support matching choices.
matchOptions?: vscode.QuickPickOptions;
}
Expand Down
37 changes: 18 additions & 19 deletions src/views/connectionUI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -615,27 +615,26 @@ export class ConnectionUI {
}

private async createFirewallRule(serverName: string, ipAddress: string): Promise<boolean> {
return this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn,
LocalizedConstants.createFirewallRuleLabel).then(async (result) => {
if (result === LocalizedConstants.createFirewallRuleLabel) {
const firewallService = this.connectionManager.firewallService;
let ipRange = await this.promptForIpAddress(ipAddress);
if (ipRange) {
let firewallResult = await firewallService.createFirewallRule(serverName, ipRange.startIpAddress, ipRange.endIpAddress);
if (firewallResult.result) {
this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptFirewallRuleCreated);
return true;
} else {
Utils.showErrorMsg(firewallResult.errorMessage);
return false;
}
} else {
return false;
}
let result = await this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn,
LocalizedConstants.createFirewallRuleLabel);
if (result === LocalizedConstants.createFirewallRuleLabel) {
const firewallService = this.connectionManager.firewallService;
let ipRange = await this.promptForIpAddress(ipAddress);
if (ipRange) {
let firewallResult = await firewallService.createFirewallRule(serverName, ipRange.startIpAddress, ipRange.endIpAddress);
if (firewallResult.result) {
this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptFirewallRuleCreated);
return true;
} else {
Utils.showErrorMsg(firewallResult.errorMessage);
return false;
}
});
} else {
return false;
}
} else {
return false;
}
}

private promptForRetryConnectWithDifferentCredentials(): PromiseLike<boolean> {
Expand Down Expand Up @@ -672,7 +671,7 @@ export class ConnectionUI {
}

public async addNewAccount(): Promise<IAccount> {
return this.connectionManager.azureController.addAccount(this._accountStore);
return await this.connectionManager.azureController.addAccount(this._accountStore);
}

// Prompts the user to pick a profile for removal, then removes from the global saved state
Expand Down

0 comments on commit a9ee4f6

Please sign in to comment.